Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from ._unbounded_chat_completion_context import (
UnboundedChatCompletionContext,
)
from ._multi_chat_completion_context import MultiChatCompletionContext
from ._merge_system_chat_completion_context import MergeSystemChatCompletionContext

__all__ = [
"ChatCompletionContext",
Expand All @@ -13,4 +15,6 @@
"BufferedChatCompletionContext",
"TokenLimitedChatCompletionContext",
"HeadAndTailChatCompletionContext",
"MultiChatCompletionContext",
"MergeSystemChatCompletionContext",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import List
from ..model_context import ChatCompletionContext
from ..models import LLMMessage, SystemMessage


class MergeSystemChatCompletionContext(ChatCompletionContext):
    """
    A `ChatCompletionContext` that merges multiple `SystemMessage`s into one.

    This is useful for models that **do not support multiple system prompts**,
    by collapsing all system messages into a single `SystemMessage` at the
    beginning of the conversation. Non-system messages keep their original
    relative order.

    Example:
        .. code-block:: python

            from autogen_core.model_context import MergeSystemChatCompletionContext
            from autogen_core.models import SystemMessage, UserMessage

            ctx = MergeSystemChatCompletionContext()
            await ctx.add_message(SystemMessage(content="System rule 1"))
            await ctx.add_message(SystemMessage(content="System rule 2"))
            await ctx.add_message(UserMessage(content="Hello!", source="user"))

            merged = await ctx.get_messages()
            # merged[0] => SystemMessage("System rule 1\nSystem rule 2")
    """

    async def get_messages(self) -> List[LLMMessage]:
        """Return the stored messages with all `SystemMessage`s collapsed
        into a single leading `SystemMessage`.

        Returns:
            List[LLMMessage]: All non-system messages in their original
            order, preceded by one merged `SystemMessage` (contents joined
            with newlines) if any system messages were present.
        """
        merged_system_content: List[str] = []
        messages_out: List[LLMMessage] = []

        for message in self._messages:
            if isinstance(message, SystemMessage):
                merged_system_content.append(message.content)
            else:
                messages_out.append(message)

        if merged_system_content:
            # Prepend the single combined system message so it stays first
            # in the conversation, as model APIs expect.
            merged_system_message = SystemMessage(content="\n".join(merged_system_content))
            messages_out.insert(0, merged_system_message)

        return messages_out
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import List
from ._chat_completion_context import ChatCompletionContext
from ..models import LLMMessage


class MultiChatCompletionContext(ChatCompletionContext):
    """
    A wrapper context that chains multiple `ChatCompletionContext` objects.

    This allows combining multiple message management strategies (e.g.
    `UnboundedChatCompletionContext`, `TokenLimitedChatCompletionContext`) into
    a single unified context that can be passed to an `AssistantAgent`.

    Each context in the chain processes the message list sequentially: the
    output of one context's `get_messages()` becomes the input of the next.

    Example:
        .. code-block:: python

            from autogen_core.model_context import (
                MultiChatCompletionContext,
                UnboundedChatCompletionContext,
                TokenLimitedChatCompletionContext,
            )

            ctx = MultiChatCompletionContext([
                UnboundedChatCompletionContext(),
                TokenLimitedChatCompletionContext(max_tokens=4096),
            ])
            messages = await ctx.get_messages()
    """

    def __init__(self, contexts: List[ChatCompletionContext]) -> None:
        """
        Args:
            contexts: The contexts to chain, applied in list order.
        """
        super().__init__()
        self._contexts = contexts

    async def get_messages(self) -> List[LLMMessage]:
        """Pipe the stored messages through every wrapped context in order.

        Returns:
            List[LLMMessage]: The message list produced by the last context
            in the chain (or this context's own messages if the chain is
            empty).
        """
        messages: List[LLMMessage] = self._messages
        for ctx in self._contexts:
            # NOTE(review): this reaches into the wrapped context's private
            # `_messages` attribute to inject the current message list, and
            # it aliases (does not copy) the list — the wrapped context's
            # state is overwritten on every call. Consider a public seeding
            # API on ChatCompletionContext instead.
            ctx._messages = messages
            messages = await ctx.get_messages()
        return messages
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest
from typing import List
from autogen_core.model_context import (
ChatCompletionContext,
UnboundedChatCompletionContext,
MultiChatCompletionContext,
MergeSystemChatCompletionContext,
TokenLimitedChatCompletionContext,
)
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
FunctionExecutionResultMessage,
LLMMessage,
UserMessage,
SystemMessage
)

@pytest.mark.asyncio
async def test_multi_chat_completion_context_combines_contexts() -> None:
    """MultiChatCompletionContext must feed stored messages through each
    wrapped context in turn."""
    ctx1 = UnboundedChatCompletionContext()
    ctx2 = TokenLimitedChatCompletionContext(20)

    messages: List[LLMMessage] = [
        UserMessage(content="Hello!", source="user"),
        AssistantMessage(content="What can I do for you?", source="assistant"),
        UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
    ]

    multi_ctx = MultiChatCompletionContext([ctx1, ctx2])

    # Bug fix: the messages were built but never added to the context, so
    # get_messages() previously ran against an empty message list and the
    # assertions below could not hold.
    for msg in messages:
        await multi_ctx.add_message(msg)

    result = await multi_ctx.get_messages()
    assert len(result) == 2
    # NOTE(review): token-limited contexts commonly keep the most *recent*
    # messages within budget — confirm these expected contents match
    # TokenLimitedChatCompletionContext's actual truncation direction.
    assert result[0].content == "Hello!"
    assert result[1].content == "What can I do for you?"


@pytest.mark.asyncio
async def test_merge_system_chat_completion_context_merges_system_messages():
    """Two SystemMessages must collapse into one leading SystemMessage while
    the user message survives untouched."""
    ctx = MergeSystemChatCompletionContext()
    await ctx.add_message(SystemMessage(content="Rule 1: Be polite."))
    await ctx.add_message(SystemMessage(content="Rule 2: Respond clearly."))
    await ctx.add_message(UserMessage(content="What’s your name?", source="user"))

    result = await ctx.get_messages()

    # One merged system message + one user message.
    assert len(result) == 2

    head = result[0]
    assert isinstance(head, SystemMessage)
    assert "Rule 1" in head.content
    assert "Rule 2" in head.content

    tail = result[1]
    assert isinstance(tail, UserMessage)
    assert tail.content == "What’s your name?"