Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions ms_agent/agent/agent.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ llm:
modelscope_base_url: https://api-inference.modelscope.cn/v1

generation_config:
top_p: 0.6
temperature: 0.2
temperature: 0.3
top_k: 20
stream: true
extra_body:
Expand Down
103 changes: 53 additions & 50 deletions ms_agent/agent/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from ms_agent.callbacks import Callback, callbacks_mapping
from ms_agent.llm.llm import LLM
from ms_agent.llm.utils import Message
from ms_agent.memory import Memory, memory_mapping
from ms_agent.memory.mem0ai import Mem0Memory, SharedMemoryManager
from ms_agent.memory import Memory, get_memory_meta_safe, memory_mapping
from ms_agent.memory.memory_manager import SharedMemoryManager
from ms_agent.rag.base import RAG
from ms_agent.rag.utils import rag_mapping
from ms_agent.tools import ToolManager
Expand Down Expand Up @@ -331,36 +331,36 @@ async def load_memory(self):
self.config: DictConfig
if hasattr(self.config, 'memory'):
for _memory in (self.config.memory or []):
memory_type = getattr(_memory, 'name', 'default_memory')
assert memory_type in memory_mapping, (
f'{memory_type} not in memory_mapping, '
mem_instance_type = getattr(_memory, 'name', None)
if mem_instance_type is None:
mem_instance_type = 'default_memory'
setattr(_memory, 'name', 'default_memory')
assert mem_instance_type in memory_mapping, (
f'{mem_instance_type} not in memory_mapping, '
f'which supports: {list(memory_mapping.keys())}')

# Use LLM config if no special configuration is specified
# Use LLM config if no memory llm configuration is specified
llm_config = getattr(_memory, 'llm', None)
if llm_config is None:
service = self.config.llm.service
config_dict = {
'model':
_memory.summary_model,
'provider':
'openai',
self.config.llm.model,
'service':
service,
'openai_base_url':
getattr(self.config.llm, f'{service}_base_url', None),
'openai_api_key':
getattr(self.config.llm, f'{service}_api_key', None),
'max_tokens':
_memory.max_tokens,
getattr(self.config.generation_config, f'max_tokens',
None),
}
llm_config_obj = OmegaConf.create(config_dict)
setattr(_memory, 'llm', llm_config_obj)
if memory_type == 'mem0':
shared_memory = SharedMemoryManager.get_shared_memory(
_memory)
self.memory_tools.append(shared_memory)
else:
self.memory_tools.append(
memory_mapping[memory_type](_memory))

shared_memory = SharedMemoryManager.get_shared_memory(_memory)
self.memory_tools.append(shared_memory)

async def prepare_rag(self):
"""Load and initialize the RAG component from the config."""
Expand Down Expand Up @@ -547,6 +547,40 @@ def get_user_id(self, default_user_id=DEFAULT_USER) -> Optional[str]:
break
return user_id

def _get_step_memory_info(self, memory_config: DictConfig):
    """Resolve memory metadata for a per-step ('add_after_step') save.

    Falls back to ``memory_config.user_id`` when the meta lookup yields
    no user id.

    Returns:
        Tuple of ``(user_id, agent_id, run_id, memory_type)``.
    """
    meta = get_memory_meta_safe(memory_config, 'add_after_step')
    user_id, agent_id, run_id, memory_type = meta
    if not user_id:
        user_id = getattr(memory_config, 'user_id', None)
    return user_id, agent_id, run_id, memory_type

def _get_run_memory_info(self, memory_config: DictConfig):
    """Resolve memory metadata for an end-of-task ('add_after_task') save.

    Defaults applied when the meta lookup yields falsy values: user id
    from ``memory_config.user_id``, agent id from this agent's tag, and
    memory type ``'procedural_memory'``.

    Returns:
        Tuple of ``(user_id, agent_id, run_id, memory_type)``.
    """
    meta = get_memory_meta_safe(memory_config, 'add_after_task')
    user_id, agent_id, run_id, memory_type = meta
    user_id = user_id if user_id else getattr(memory_config, 'user_id', None)
    agent_id = agent_id if agent_id else self.tag
    memory_type = memory_type if memory_type else 'procedural_memory'
    return user_id, agent_id, run_id, memory_type

async def add_memory(self, messages: List[Message], **kwargs):
    """Persist ``messages`` into every configured memory backend.

    For each entry in ``self.config.memory``, picks end-of-task metadata
    when the runtime is stopping, otherwise per-step metadata, then
    forwards the messages to the memory tool at the matching index.
    """
    if not (hasattr(self.config, 'memory') and self.config.memory):
        return
    # Guard index access: there may be fewer tools than memory configs.
    available = len(self.memory_tools) if self.memory_tools else 0
    for idx, memory_config in enumerate(self.config.memory):
        if self.runtime.should_stop:
            info = self._get_run_memory_info(memory_config)
        else:
            info = self._get_step_memory_info(memory_config)
        user_id, agent_id, run_id, memory_type = info
        if idx >= available:
            continue
        await self.memory_tools[idx].add(
            messages,
            user_id=user_id,
            agent_id=agent_id,
            run_id=run_id,
            memory_type=memory_type)

def save_history(self, messages: List[Message], **kwargs):
"""
Save current chat history to disk for future resuming.
Expand All @@ -567,37 +601,6 @@ def save_history(self, messages: List[Message], **kwargs):
save_history(
self.output_dir, task=self.tag, config=config, messages=messages)

def save_memory(self, messages: List[Message]):
    """Persist the chat history into every Mem0-backed memory tool.

    Works on a deep copy of ``messages`` and replaces any tool-call
    arguments that are not valid JSON with ``'{}'`` before saving.

    Args:
        messages (List[Message]): Current message history to save.
    """
    snapshot = deepcopy(messages)
    # Sanitize tool-call arguments that would not round-trip as JSON.
    for message in snapshot:
        for tool_call in (message.tool_calls or []):
            try:
                if tool_call['arguments']:
                    json.loads(tool_call['arguments'])
            except Exception:
                tool_call['arguments'] = '{}'

    if not self.memory_tools:
        return
    mem0_tools = [
        tool for tool in self.memory_tools if isinstance(tool, Mem0Memory)
    ]
    if self.runtime.should_stop:
        for tool in mem0_tools:
            tool.add_memories_from_procedural(snapshot, self.get_user_id(),
                                              self.tag, 'procedural_memory')
    else:
        for tool in mem0_tools:
            tool.add_memories_from_conversation(snapshot,
                                                self.get_user_id())

async def run_loop(self, messages: Union[List[Message], str],
**kwargs) -> AsyncGenerator[Any, Any]:
"""Run the agent, mainly contains a llm calling and tool calling loop.
Expand Down Expand Up @@ -639,7 +642,7 @@ async def run_loop(self, messages: Union[List[Message], str],
yield messages
self.runtime.round += 1
# save memory and history
self.save_memory(messages)
await self.add_memory(messages, **kwargs)
self.save_history(messages)

# +1 means the next round the assistant may give a conclusion
Expand All @@ -655,7 +658,7 @@ async def run_loop(self, messages: Union[List[Message], str],
yield messages

# save memory
self.save_memory(messages)
await self.add_memory(messages, **kwargs)

await self.on_task_end(messages)
await self.cleanup_tools()
Expand Down
4 changes: 3 additions & 1 deletion ms_agent/llm/anthropic_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ms_agent.llm import LLM
from ms_agent.llm.utils import Message, Tool, ToolCall
from ms_agent.utils import assert_package_exist, retry
from ms_agent.utils.constants import get_service_config
from omegaconf import DictConfig, OmegaConf


Expand All @@ -22,7 +23,8 @@ def __init__(

self.model: str = config.llm.model

base_url = base_url or config.llm.get('anthropic_base_url')
base_url = base_url or config.llm.get(
'anthropic_base_url') or get_service_config('anthropic').base_url
api_key = api_key or config.llm.get('anthropic_api_key')

if not api_key:
Expand Down
4 changes: 3 additions & 1 deletion ms_agent/llm/dashscope_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from ms_agent.llm.openai_llm import OpenAI
from ms_agent.llm.utils import Message, Tool
from ms_agent.utils.constants import get_service_config
from omegaconf import DictConfig


Expand All @@ -11,7 +12,8 @@ class DashScope(OpenAI):
def __init__(self, config: DictConfig):
super().__init__(
config,
base_url=config.llm.dashscope_base_url,
base_url=config.llm.modelscope_base_url
or get_service_config('dashscope').base_url,
api_key=config.llm.dashscope_api_key)

def _call_llm_for_continue_gen(self,
Expand Down
8 changes: 5 additions & 3 deletions ms_agent/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def from_config(cls, config: DictConfig) -> Any:
Returns:
The LLM instance.
"""
from .model_mapping import all_services_mapping
assert config.llm.service in all_services_mapping
return all_services_mapping[config.llm.service](config)
from .model_mapping import all_services_mapping, OpenAI
if config.llm.get('service') in all_services_mapping:
return all_services_mapping[config.llm.service](config)
else:
return OpenAI(config)
2 changes: 2 additions & 0 deletions ms_agent/llm/model_mapping.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from ms_agent.llm.anthropic_llm import Anthropic
from ms_agent.llm.dashscope_llm import DashScope
from ms_agent.llm.modelscope_llm import ModelScope
from ms_agent.llm.openai_llm import OpenAI

# Maps the `llm.service` config value to its client implementation class.
all_services_mapping = {
    'modelscope': ModelScope,
    'openai': OpenAI,
    'anthropic': Anthropic,
    'dashscope': DashScope,
}
4 changes: 3 additions & 1 deletion ms_agent/llm/modelscope_llm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from ms_agent.llm.openai_llm import OpenAI
from ms_agent.utils.constants import get_service_config
from omegaconf import DictConfig


Expand All @@ -11,5 +12,6 @@ def __init__(self, config: DictConfig):
) and config.llm.modelscope_api_key is not None, 'Please provide `modelscope_api_key` in env or cmd.'
super().__init__(
config,
base_url=config.llm.modelscope_base_url,
base_url=config.llm.modelscope_base_url
or get_service_config('modelscope').base_url,
api_key=config.llm.modelscope_api_key)
23 changes: 4 additions & 19 deletions ms_agent/llm/openai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ms_agent.llm.utils import Message, Tool, ToolCall
from ms_agent.utils import (MAX_CONTINUE_RUNS, assert_package_exist,
get_logger, retry)
from ms_agent.utils.constants import get_service_config
from omegaconf import DictConfig, OmegaConf
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall, Function)
Expand Down Expand Up @@ -41,7 +42,8 @@ def __init__(
self.model: str = config.llm.model
self.max_continue_runs = getattr(config.llm, 'max_continue_runs',
None) or MAX_CONTINUE_RUNS
base_url = base_url or config.llm.openai_base_url
base_url = base_url or config.llm.openai_base_url or get_service_config(
'openai').base_url
api_key = api_key or config.llm.openai_api_key

self.client = openai.OpenAI(
Expand Down Expand Up @@ -427,30 +429,13 @@ def _format_input_message(self,
for message in messages:
if isinstance(message, Message):
message.content = message.content.strip()
message = message.to_dict()

if message.get('tool_calls'):
tool_calls = list()
for tool_call in message['tool_calls']:
function_data: Function = {
'name': tool_call['tool_name'],
'arguments': tool_call['arguments']
}
tool_call: ChatCompletionMessageToolCall = {
'id': tool_call['id'],
'function': function_data,
'type': tool_call['type'],
}
tool_calls.append(tool_call)
message['tool_calls'] = tool_calls
message = message.to_dict_clean()

message = {
key: value.strip() if isinstance(value, str) else value
for key, value in message.items()
if key in self.input_msg and value
}
if 'content' not in message:
message['content'] = ''

openai_messages.append(message)

Expand Down
30 changes: 28 additions & 2 deletions ms_agent/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional

import json
from typing_extensions import Literal, Required, TypedDict


class ToolCall(TypedDict, total=False):
id: str = 'default_id'
index: int = 0
type: str = 'function'
tool_name: Required[str]
arguments: str = ''
tool_name: str = ''
arguments: str = '{}'


class Tool(TypedDict, total=False):
Expand Down Expand Up @@ -52,3 +53,28 @@ class Message:

def to_dict(self):
    """Return this message as a plain dict via ``dataclasses.asdict``."""
    return asdict(self)

def to_dict_clean(self):
    """Serialize to an OpenAI-style dict, dropping empty/internal fields.

    Each tool call is reshaped into the
    ``{'id', 'type', 'function': {'name', 'arguments'}}`` form, with
    arguments that fail ``json.loads`` replaced by ``'{}'``. Keys with
    falsy values are removed unless required ('content', 'role');
    token/usage counters are always removed.
    """
    data = asdict(self)
    calls = data.get('tool_calls')
    if calls:
        for idx, call in enumerate(calls):
            args = call['arguments']
            try:
                if args:
                    json.loads(args)
            except Exception:
                args = '{}'
            calls[idx] = {
                'id': call['id'],
                'type': call['type'],
                'function': {
                    'name': call['tool_name'],
                    'arguments': args,
                },
            }
    keep_always = ('content', 'role')
    drop_always = ('completion_tokens', 'prompt_tokens', 'api_calls')
    cleaned = {}
    for key, value in data.items():
        if key in drop_always:
            continue
        if value or key in keep_always:
            cleaned[key] = value
    return cleaned
2 changes: 1 addition & 1 deletion ms_agent/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .base import Memory
from .utils import DefaultMemory, memory_mapping
from .utils import DefaultMemory, get_memory_meta_safe, memory_mapping
Loading
Loading