Skip to content

Commit 80e8cf0

Browse files
sunayanagsginji
authored andcommitted
fix: ensure Nova and Anthropic models start with user message in chat
- Added _is_nova_model method to detect Nova models - Updated _format_messages_for_provider to ensure conversations start with user message - Added unit tests for Nova model detection and message formatting - Fixes #3298
1 parent 0f1b764 commit 80e8cf0

File tree

2 files changed

+57
-9
lines changed

2 files changed

+57
-9
lines changed

src/crewai/llm.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ def __init__(
339339
self.reasoning_effort = reasoning_effort
340340
self.additional_params = kwargs
341341
self.is_anthropic = self._is_anthropic_model(model)
342+
self.is_nova = self._is_nova_model(model)
342343
self.stream = stream
343344

344345
litellm.drop_params = True
@@ -354,6 +355,17 @@ def __init__(
354355
self.set_callbacks(callbacks)
355356
self.set_env_callbacks()
356357

358+
def _is_nova_model(self, model: str) -> bool:
359+
"""Determine if the model is an Amazon Nova model.
360+
361+
Args:
362+
model: The model identifier string.
363+
364+
Returns:
365+
bool: True if the model is a Nova model, False otherwise.
366+
"""
367+
return "amazon.nova-" in model.lower()
368+
357369
def _is_anthropic_model(self, model: str) -> bool:
358370
"""Determine if the model is from Anthropic provider.
359371
@@ -366,6 +378,20 @@ def _is_anthropic_model(self, model: str) -> bool:
366378
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
367379
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
368380

381+
def _ensure_starts_with_user_message(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
382+
"""Ensure messages list starts with a user message.
383+
384+
Args:
385+
messages: List of message dictionaries
386+
387+
Returns:
388+
List of messages with a user message at the start if needed
389+
"""
390+
# Check if first message is system (or empty list)
391+
if not messages or messages[0]["role"] == "system":
392+
return [{"role": "user", "content": "."}, *messages]
393+
return messages
394+
369395
def _prepare_completion_params(
370396
self,
371397
messages: Union[str, List[Dict[str, str]]],
@@ -1157,15 +1183,9 @@ def _format_messages_for_provider(
11571183
):
11581184
return messages + [{"role": "user", "content": ""}]
11591185

1160-
# Handle Anthropic models
1161-
if not self.is_anthropic:
1162-
return messages
1163-
1164-
# Anthropic requires messages to start with 'user' role
1165-
if not messages or messages[0]["role"] == "system":
1166-
# If first message is system or empty, add a placeholder user message
1167-
return [{"role": "user", "content": "."}, *messages]
1168-
1186+
# Both Nova and Anthropic require the conversation to start with a user message
1187+
if self.is_nova or self.is_anthropic:
1188+
return self._ensure_starts_with_user_message(messages)
11691189
return messages
11701190

11711191
def _get_custom_llm_provider(self) -> Optional[str]:

tests/test_llm.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,3 +712,31 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):
712712
formatted = ollama_llm._format_messages_for_provider(original_messages)
713713

714714
assert formatted == original_messages
715+
716+
@pytest.fixture
717+
def nova_llm():
718+
return LLM(model="bedrock/us.amazon.nova-pro-v1:0")
719+
720+
def test_nova_model_detection(nova_llm):
721+
assert nova_llm.is_nova
722+
assert LLM(model="bedrock/amazon.nova-lite-v1:0").is_nova
723+
assert not LLM(model="gpt-4").is_nova
724+
725+
def test_nova_message_formatting(nova_llm):
726+
# Should add user message at start if only system messages
727+
messages = [
728+
{"role": "system", "content": "System message"},
729+
{"role": "assistant", "content": "Assistant message"}
730+
]
731+
formatted = nova_llm._format_messages_for_provider(messages)
732+
assert formatted[0]["role"] == "user"
733+
assert len(formatted) == len(messages) + 1
734+
735+
# Should not modify if already has user message at start
736+
messages = [
737+
{"role": "user", "content": "User message"},
738+
{"role": "system", "content": "System message"},
739+
{"role": "assistant", "content": "Assistant message"}
740+
]
741+
formatted = nova_llm._format_messages_for_provider(messages)
742+
assert formatted == messages

0 commit comments

Comments
 (0)