diff --git a/ms_agent/llm/anthropic_llm.py b/ms_agent/llm/anthropic_llm.py
index 00e518ea3..f1a3c23ba 100644
--- a/ms_agent/llm/anthropic_llm.py
+++ b/ms_agent/llm/anthropic_llm.py
@@ -1,7 +1,6 @@
 import inspect
 from typing import Any, Dict, Generator, Iterator, List, Optional, Union
 
-import json5
 from ms_agent.llm import LLM
 from ms_agent.llm.utils import Message, Tool, ToolCall
 from ms_agent.utils import assert_package_exist, retry
@@ -109,10 +108,22 @@ def _call_llm(self,
         if formatted_messages[0]['role'] == 'system':
             system = formatted_messages[0]['content']
             formatted_messages = formatted_messages[1:]
+
+        max_tokens = kwargs.pop('max_tokens', 16000)
+        extra_body = kwargs.get('extra_body', {})
+        enable_thinking = extra_body.get('enable_thinking', False)
+        thinking_budget = extra_body.get('thinking_budget', max_tokens)
+
         params = {
             'model': self.model,
             'messages': formatted_messages,
-            'max_tokens': kwargs.pop('max_tokens', 1024),
+            'max_tokens': max_tokens,
         }
+        if enable_thinking:
+            # API constraint: 1024 <= budget_tokens < max_tokens; omit the block when disabled
+            params['thinking'] = {
+                'type': 'enabled',
+                'budget_tokens': min(thinking_budget, max_tokens - 1)
+            }
 
         if system:
@@ -163,15 +174,23 @@ def _stream_format_output_message(self,
             )
         tool_call_id_map = {}  # index -> tool_call_id (used to deduplicate yields)
         with stream_manager as stream:
+            # running snapshots: every yielded Message carries the full text so far
+            full_content = ''
+            full_thinking = ''
             for event in stream:
                 event_type = getattr(event, 'type')
                 if event_type == 'message_start':
                     msg = getattr(event, 'message')
                     current_message.id = msg.id
                     tool_call_id_map = {}
                     yield current_message
-                elif event_type == 'text':
-                    current_message.content = event.snapshot
+                elif event_type == 'content_block_delta':
+                    if event.delta.type == 'thinking_delta':
+                        full_thinking += event.delta.thinking
+                        current_message.reasoning_content = full_thinking
+                    elif event.delta.type == 'text_delta':
+                        full_content += event.delta.text
+                        current_message.content = full_content
                     yield current_message
                 elif event_type == 'message_stop':
                     final_msg = getattr(event, 'message')
diff --git a/tests/llm/test_anthropic.py b/tests/llm/test_anthropic.py
index ba434f4b9..cec6a381d 100644
--- a/tests/llm/test_anthropic.py
+++ b/tests/llm/test_anthropic.py
@@ -9,7 +9,7 @@
 from modelscope.utils.test_utils import test_level
 
-API_CALL_MAX_TOKEN = 50
+API_CALL_MAX_TOKEN = 500
 
 
 class OpenaiLLM(unittest.TestCase):
@@ -124,34 +124,24 @@ def test_tool_no_stream(self):
         print(res)
         assert (len(res.tool_calls))
 
-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_agent_multi_round(self):
-        import asyncio
-
-        async def main():
-            agent = LLMAgent(config=self.conf, mcp_config=self.mcp_config)
-            if hasattr(agent.config, 'callbacks'):
-                agent.config.callbacks.remove('input_callback')  # noqa
-            res = await agent.run('访问www.baidu.com')
-            print(res)
-            assert ('robots.txt' in res[-1].content)
-
-        asyncio.run(main())
-
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
-    def test_stream_agent_multi_round(self):
+    def test_stream_agent_multi_round_with_thinking(self):
         import asyncio
         from copy import deepcopy
 
         async def main():
             conf2 = deepcopy(self.conf)
-            conf2.generation_config.stream = True
+            conf2.llm.model = 'Qwen/Qwen3-235B-A22B'
+            conf2.generation_config.extra_body.enable_thinking = True
             agent = LLMAgent(config=conf2, mcp_config=self.mcp_config)
             if hasattr(agent.config, 'callbacks'):
                 agent.config.callbacks.remove('input_callback')  # noqa
-            res = await agent.run('访问www.baidu.com')
-            print('res:', res)
-            assert ('robots.txt' in res[-1].content)
+            res = await agent.run('访问www.baidu.com', stream=True)
+            async for chunk in res:
+                print('res: ', chunk)
+            # assert on the last chunk only, after the stream has finished
+            assert ('robots.txt' in chunk[-1].content)
+            assert (chunk[-1].reasoning_content)
 
         asyncio.run(main())
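
For reviewers, here is a minimal standalone sketch of the raw Anthropic stream that the reworked `_stream_format_output_message` consumes. It is illustrative only: the model id is a placeholder and the budget value is arbitrary, not something this PR pins down.

```python
import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

# With extended thinking enabled, the stream interleaves `thinking_delta` and
# `text_delta` payloads inside `content_block_delta` events, which is exactly
# what the new handler branches on. budget_tokens must satisfy
# 1024 <= budget_tokens < max_tokens.
with client.messages.stream(
        model='claude-sonnet-4-20250514',  # placeholder model id
        max_tokens=16000,
        thinking={'type': 'enabled', 'budget_tokens': 4096},
        messages=[{'role': 'user', 'content': 'Why is the sky blue?'}],
) as stream:
    for event in stream:
        if event.type == 'content_block_delta':
            if event.delta.type == 'thinking_delta':
                print(event.delta.thinking, end='', flush=True)
            elif event.delta.type == 'text_delta':
                print(event.delta.text, end='', flush=True)
```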