Update OpenAIClient to Support the deepseek-reasoner Model #634

Merged: 12 commits, Jan 23, 2025
48 changes: 48 additions & 0 deletions autogen/oai/client.py
@@ -295,10 +295,58 @@ def _is_agent_name_error_message(message: str) -> bool:
pattern = re.compile(r"Invalid 'messages\[\d+\]\.name': string does not match pattern.")
return True if pattern.match(message) else False

@staticmethod
def _move_system_message_to_beginning(messages: list[dict[str, Any]]) -> None:
for msg in messages:
if msg["role"] == "system":
messages.insert(0, messages.pop(messages.index(msg)))
break

@staticmethod
def _patch_messages_for_deepseek_reasoner(**kwargs: Any) -> Any:
if (
"model" not in kwargs
or kwargs["model"] != "deepseek-reasoner"
or "messages" not in kwargs
or len(kwargs["messages"]) == 0
):
return kwargs

# The system message for deepseek-reasoner must be placed at the beginning of the message sequence.
OpenAIClient._move_system_message_to_beginning(kwargs["messages"])

new_messages = []
previous_role = None
for message in kwargs["messages"]:
if "role" in message:
current_role = message["role"]

# This model requires alternating roles
if current_role == previous_role:
# Swap the role
if current_role == "user":
message["role"] = "assistant"
elif current_role == "assistant":
message["role"] = "user"

previous_role = message["role"]

new_messages.append(message)

# The last message sent to deepseek-reasoner must be a user message,
# or an assistant message with prefix mode on (supported only by the beta API)
if new_messages[-1]["role"] != "user":
new_messages.append({"role": "user", "content": "continue"})

kwargs["messages"] = new_messages

return kwargs

@staticmethod
def _handle_openai_bad_request_error(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: Any, **kwargs: Any):
try:
kwargs = OpenAIClient._patch_messages_for_deepseek_reasoner(**kwargs)
return func(*args, **kwargs)
except openai.BadRequestError as e:
response_json = e.response.json()
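For reference, a minimal sketch (not part of the diff) of what the new patching logic does to a message sequence; the import path and helper name are taken from the change above, while the example messages are purely illustrative:

from autogen.oai.client import OpenAIClient

# Illustrative input: a misplaced system message followed by two consecutive user turns.
kwargs = {
    "model": "deepseek-reasoner",
    "messages": [
        {"role": "user", "content": "Summarize the report."},
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Keep it short."},
    ],
}

patched = OpenAIClient._patch_messages_for_deepseek_reasoner(**kwargs)
# Expected result:
# - the system message is moved to the front,
# - the second consecutive user turn is flipped to "assistant" so roles alternate,
# - {"role": "user", "content": "continue"} is appended because the last message
#   is no longer from the user.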
20 changes: 20 additions & 0 deletions test/agentchat/test_conversable_agent.py
@@ -1608,6 +1608,26 @@ def login(
mock.assert_called_once()


@pytest.mark.deepseek
def test_conversable_agent_with_deepseek_reasoner(
credentials_deepseek_reasoner: Credentials,
) -> None:
agent = ConversableAgent(
name="agent",
llm_config=credentials_deepseek_reasoner.llm_config,
)

user_proxy = UserProxyAgent(
name="user_proxy_1",
human_input_mode="NEVER",
)

result = user_proxy.initiate_chat(
agent, message="Hello, how are you?", summary_method="reflection_with_llm", max_turns=2
)
assert isinstance(result.summary, str)


if __name__ == "__main__":
# test_trigger()
# test_context()
48 changes: 48 additions & 0 deletions test/agentchat/test_groupchat.py
@@ -10,6 +10,7 @@
import io
import json
import logging
import tempfile
from types import SimpleNamespace
from typing import Any, Optional
from unittest import mock
@@ -21,6 +22,8 @@
from autogen.agentchat.contrib.capabilities import transform_messages, transforms
from autogen.exception_utils import AgentNameConflict, UndefinedNextAgent

from ..conftest import Credentials


def test_func_call_groupchat():
agent1 = autogen.ConversableAgent(
@@ -2181,6 +2184,51 @@ def test_manager_resume_message_assignment():
assert list(agent_a.chat_messages.values())[0] == prev_messages[:-1]


@pytest.mark.deepseek
def test_groupchat_with_deepseek_reasoner(
credentials_gpt_4o_mini: Credentials,
credentials_deepseek_reasoner: Credentials,
) -> None:
with tempfile.TemporaryDirectory() as tmp_dir:
user_proxy = autogen.UserProxyAgent(
"user_proxy",
human_input_mode="NEVER",
code_execution_config={"work_dir": tmp_dir, "use_docker": False},
)

supervisor = autogen.AssistantAgent(
"supervisor",
llm_config={
"config_list": credentials_deepseek_reasoner.config_list,
},
)

assistant = autogen.AssistantAgent(
"assistant",
llm_config={
"config_list": credentials_deepseek_reasoner.config_list,
},
)

groupchat = autogen.GroupChat(
agents=[user_proxy, supervisor, assistant],
messages=["A group chat"],
max_round=5,
)

manager = autogen.GroupChatManager(
groupchat=groupchat,
llm_config={
"config_list": credentials_gpt_4o_mini.config_list,
},
)

result = user_proxy.initiate_chat(
manager, message="""Give me some info about the stock market""", summary_method="reflection_with_llm"
)
assert isinstance(result.summary, str)


if __name__ == "__main__":
# test_func_call_groupchat()
# test_broadcast()
32 changes: 28 additions & 4 deletions test/conftest.py
@@ -178,16 +178,32 @@ def get_credentials(


def get_config_list_from_env(
env_var_name: str, model: str, api_type: str, filter_dict: Optional[dict[str, Any]] = None, temperature: float = 0.0
env_var_name: str,
model: str,
api_type: str,
*,
base_url: Optional[str] = None,
filter_dict: Optional[dict[str, Any]] = None,
temperature: float = 0.0,
) -> list[dict[str, Any]]:
if env_var_name in os.environ:
api_key = os.environ[env_var_name]
return [{"api_key": api_key, "model": model, **filter_dict, "api_type": api_type}] # type: ignore[dict-item]
config_list = [{"api_key": api_key, "model": model, **filter_dict, "api_type": api_type}] # type: ignore[dict-item]
if base_url:
config_list[0]["base_url"] = base_url
return config_list

return []


def get_llm_credentials(
env_var_name: str, model: str, api_type: str, filter_dict: Optional[dict[str, Any]] = None, temperature: float = 0.0
env_var_name: str,
model: str,
api_type: str,
*,
base_url: Optional[str] = None,
filter_dict: Optional[dict[str, Any]] = None,
temperature: float = 0.0,
) -> Credentials:
credentials = get_credentials(filter_dict, temperature, fail_if_empty=False)
config_list = credentials.config_list if credentials else []
@@ -198,7 +214,14 @@ def get_llm_credentials(

# If no config found, try to get it from the environment
if config_list == []:
config_list = get_config_list_from_env(env_var_name, model, api_type, filter_dict, temperature)
config_list = get_config_list_from_env(
env_var_name,
model,
api_type,
base_url=base_url,
filter_dict=filter_dict,
temperature=temperature,
)

# If still no config found, raise an error
assert config_list, f"No {api_type} config list found and could not be created from an env var {env_var_name}"
@@ -301,6 +324,7 @@ def credentials_deepseek_reasoner() -> Credentials:
"DEEPSEEK_API_KEY",
model="deepseek-reasoner",
api_type="deepseek",
base_url="https://api.deepseek.com/v1",
filter_dict={"tags": ["deepseek-reasoner"]},
)

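For context, a sketch (not part of the diff) of the config entry the env-var fallback now builds for the deepseek-reasoner fixture; the key/value layout follows get_config_list_from_env above, with a placeholder API key:

# Roughly what get_config_list_from_env("DEEPSEEK_API_KEY", model="deepseek-reasoner",
# api_type="deepseek", base_url="https://api.deepseek.com/v1",
# filter_dict={"tags": ["deepseek-reasoner"]}) returns when only the env var is set:
[
    {
        "api_key": "<value of DEEPSEEK_API_KEY>",
        "model": "deepseek-reasoner",
        "tags": ["deepseek-reasoner"],
        "api_type": "deepseek",
        "base_url": "https://api.deepseek.com/v1",
    }
]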
90 changes: 90 additions & 0 deletions test/oai/test_client.py
@@ -6,6 +6,7 @@
# SPDX-License-Identifier: MIT
#!/usr/bin/env python3 -m pytest

import copy
import os
import shutil
import time
@@ -342,6 +343,95 @@ def raise_bad_request_error(error_message: str) -> None:
wrapped_raise_bad_request_error(error_message=error_message)


class TestDeepSeekPatch:
@pytest.mark.parametrize(
"messages, expected_messages",
[
(
[
{"role": "system", "content": "You are an AG2 Agent."},
{"role": "user", "content": "Help me with my problem."},
],
[
{"role": "system", "content": "You are an AG2 Agent."},
{"role": "user", "content": "Help me with my problem."},
],
),
(
[
{"role": "user", "content": "You are an AG2 Agent."},
{"role": "user", "content": "Help me with my problem."},
],
[
{"role": "user", "content": "You are an AG2 Agent."},
{"role": "user", "content": "Help me with my problem."},
],
),
(
[
{"role": "assistant", "content": "Help me with my problem."},
{"role": "system", "content": "You are an AG2 Agent."},
],
[
{"role": "system", "content": "You are an AG2 Agent."},
{"role": "assistant", "content": "Help me with my problem."},
],
),
(
[
{"role": "assistant", "content": "Help me with my problem."},
{"role": "system", "content": "You are an AG2 Agent."},
{"role": "user", "content": "Help me with my problem."},
],
[
{"role": "system", "content": "You are an AG2 Agent."},
{"role": "assistant", "content": "Help me with my problem."},
{"role": "user", "content": "Help me with my problem."},
],
),
],
)
def test_move_system_message_to_beginning(
self, messages: list[dict[str, str]], expected_messages: list[dict[str, str]]
) -> None:
OpenAIClient._move_system_message_to_beginning(messages)
assert messages == expected_messages

@pytest.mark.parametrize(
"model, should_patch",
[
("deepseek-reasoner", True),
("deepseek", False),
("something-else", False),
],
)
def test_patch_messages_for_deepseek_reasoner(self, model: str, should_patch: bool) -> None:
kwargs = {
"messages": [
{"role": "user", "content": "You are an AG2 Agent."},
{"role": "system", "content": "You are an AG2 Agent System."},
{"role": "user", "content": "Help me with my problem."},
],
"model": model,
}

if should_patch:
expected_kwargs = {
"messages": [
{"role": "system", "content": "You are an AG2 Agent System."},
{"role": "user", "content": "You are an AG2 Agent."},
{"role": "assistant", "content": "Help me with my problem."},
{"role": "user", "content": "continue"},
],
"model": "deepseek-reasoner",
}
else:
expected_kwargs = copy.deepcopy(kwargs)

kwargs = OpenAIClient._patch_messages_for_deepseek_reasoner(**kwargs)
assert kwargs == expected_kwargs


class TestO1:
@pytest.fixture
def mock_oai_client(self, mock_credentials: Credentials) -> OpenAIClient:
6 changes: 0 additions & 6 deletions test/test_conftest.py
@@ -32,9 +32,3 @@ def test_credentials_from_test_param_fixture(
assert first_config["api_type"] == "anthropic"
else:
assert False, f"Unknown LLM fixture: {current_llm}"


@pytest.mark.deepseek
def test_credentials_deepseek_reasoner_api_key_is_set(credentials_deepseek_reasoner: Credentials) -> None:
assert len(credentials_deepseek_reasoner.config_list) > 0
assert credentials_deepseek_reasoner.config_list[0]["api_key"] is not None