from typing import List

from ._chat_completion_context import ChatCompletionContext
from ..models import LLMMessage, SystemMessage


class MergeSystemChatCompletionContext(ChatCompletionContext):
    """A :class:`ChatCompletionContext` that merges all `SystemMessage`s into one.

    Useful for models that **do not support multiple system prompts**: every
    `SystemMessage` in the context is collapsed into a single `SystemMessage`
    placed at the beginning of the returned message list. The contents are
    joined with newlines in their original order; all non-system messages keep
    their relative order after the merged system message.

    Example:
        .. code-block:: python

            from autogen_core.model_context import MergeSystemChatCompletionContext
            from autogen_core.models import SystemMessage, UserMessage

            ctx = MergeSystemChatCompletionContext()
            await ctx.add_message(SystemMessage(content="System rule 1"))
            await ctx.add_message(SystemMessage(content="System rule 2"))
            await ctx.add_message(UserMessage(content="Hello!", source="user"))

            merged = await ctx.get_messages()
            # merged[0] => SystemMessage("System rule 1\\nSystem rule 2")
    """

    async def get_messages(self) -> List[LLMMessage]:
        """Return the stored messages with all system messages merged into one.

        Returns:
            List[LLMMessage]: the non-system messages in order, prefixed by a
            single merged `SystemMessage` (omitted entirely when there are no
            system messages).
        """
        merged_system_content: List[str] = []
        messages_out: List[LLMMessage] = []

        # Split the stored messages: collect system contents, pass the rest through.
        for message in self._messages:
            if isinstance(message, SystemMessage):
                merged_system_content.append(message.content)
            else:
                messages_out.append(message)

        # Only emit a system message if there was at least one to merge.
        if merged_system_content:
            merged_system_message = SystemMessage(content="\n".join(merged_system_content))
            messages_out.insert(0, merged_system_message)

        return messages_out
from typing import List

from ._chat_completion_context import ChatCompletionContext
from ..models import LLMMessage


class MultiChatCompletionContext(ChatCompletionContext):
    """A wrapper context that chains multiple :class:`ChatCompletionContext` objects.

    This allows combining multiple message management strategies (e.g.
    `UnboundedChatCompletionContext`, `TokenLimitedChatCompletionContext`) into
    a single unified context that can be passed to an `AssistantAgent`.

    Each context in the chain processes the message list sequentially: the
    output of one context's `get_messages` becomes the input of the next.

    Example:
        .. code-block:: python

            from autogen_core.model_context import (
                MultiChatCompletionContext,
                UnboundedChatCompletionContext,
                TokenLimitedChatCompletionContext,
            )

            ctx = MultiChatCompletionContext(
                [
                    UnboundedChatCompletionContext(),
                    TokenLimitedChatCompletionContext(model_client, token_limit=4096),
                ]
            )
            messages = await ctx.get_messages()
    """

    def __init__(self, contexts: List[ChatCompletionContext]):
        """
        Args:
            contexts: the contexts to chain, applied in list order.
        """
        super().__init__()
        self._contexts = contexts

    async def get_messages(self) -> List[LLMMessage]:
        """Run the stored messages through every chained context in order.

        Returns:
            List[LLMMessage]: the messages after all contexts have filtered them.
        """
        messages: List[LLMMessage] = self._messages
        for ctx in self._contexts:
            # NOTE: reaches into the sub-context's private ``_messages`` — the
            # base class exposes no setter. A copy is assigned so the chained
            # contexts never share (and cannot mutate) one another's list.
            ctx._messages = list(messages)
            messages = await ctx.get_messages()
        return messages
from typing import List

import pytest

from autogen_core.model_context import (
    BufferedChatCompletionContext,
    MergeSystemChatCompletionContext,
    MultiChatCompletionContext,
    UnboundedChatCompletionContext,
)
from autogen_core.models import (
    AssistantMessage,
    LLMMessage,
    SystemMessage,
    UserMessage,
)


@pytest.mark.asyncio
async def test_multi_chat_completion_context_combines_contexts() -> None:
    """The chained context must apply each sub-context's strategy in order.

    BufferedChatCompletionContext(buffer_size=2) keeps only the last two
    messages, so after chaining the unbounded context with it the first
    message is dropped. (A buffered context is used instead of a
    token-limited one so no model client is required.)
    """
    ctx1 = UnboundedChatCompletionContext()
    ctx2 = BufferedChatCompletionContext(buffer_size=2)

    messages: List[LLMMessage] = [
        UserMessage(content="Hello!", source="user"),
        AssistantMessage(content="What can I do for you?", source="assistant"),
        UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
    ]

    multi_ctx = MultiChatCompletionContext([ctx1, ctx2])
    # The messages must actually be added to the wrapper context before
    # get_messages() can return anything.
    for msg in messages:
        await multi_ctx.add_message(msg)

    result = await multi_ctx.get_messages()
    # Buffered context keeps the *last* two messages.
    assert len(result) == 2
    assert result[0].content == "What can I do for you?"
    assert result[1].content == "Tell what are some fun things to do in seattle."


@pytest.mark.asyncio
async def test_merge_system_chat_completion_context_merges_system_messages() -> None:
    """All system messages collapse into one leading SystemMessage."""
    merge_ctx = MergeSystemChatCompletionContext()
    messages: List[LLMMessage] = [
        SystemMessage(content="Rule 1: Be polite."),
        SystemMessage(content="Rule 2: Respond clearly."),
        UserMessage(content="What’s your name?", source="user"),
    ]

    for msg in messages:
        await merge_ctx.add_message(msg)

    merged = await merge_ctx.get_messages()

    assert len(merged) == 2
    assert isinstance(merged[0], SystemMessage)
    assert "Rule 1" in merged[0].content and "Rule 2" in merged[0].content

    # The user message should still be present, after the merged system message.
    assert isinstance(merged[1], UserMessage)
    assert merged[1].content == "What’s your name?"