8 changes: 8 additions & 0 deletions src/mcp_agent/config.py
@@ -113,6 +113,14 @@ class AnthropicSettings(BaseModel):

base_url: str | None = None

cache_mode: Literal["off", "prompt", "auto"] = "off"
"""
Controls how caching is applied for Anthropic models when prompt_caching is enabled globally.
- "off": No caching, even if global prompt_caching is true.
- "prompt": Caches the initial system/user prompt. Useful for large, static prompts.
- "auto": Caches the last three messages. Default behavior if prompt_caching is true.
"""

model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)


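For illustration, here is a minimal, self-contained sketch of how the new `cache_mode` field validates. It re-declares only this one attribute; the real `AnthropicSettings` carries additional fields (such as `base_url` above), and the class name below is invented for the example:

```python
# Standalone sketch: mirrors only the cache_mode field added in this hunk.
from typing import Literal

from pydantic import BaseModel, ValidationError


class AnthropicSettingsSketch(BaseModel):
    cache_mode: Literal["off", "prompt", "auto"] = "off"


print(AnthropicSettingsSketch().cache_mode)                   # "off" (the default)
print(AnthropicSettingsSketch(cache_mode="auto").cache_mode)  # "auto"

try:
    AnthropicSettingsSketch(cache_mode="aggressive")  # not a permitted literal
except ValidationError:
    print("rejected: cache_mode must be 'off', 'prompt', or 'auto'")
```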
19 changes: 16 additions & 3 deletions src/mcp_agent/llm/augmented_llm.py
@@ -95,6 +95,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
PARAM_USE_HISTORY = "use_history"
PARAM_MAX_ITERATIONS = "max_iterations"
PARAM_TEMPLATE_VARS = "template_vars"

# Base set of fields that should always be excluded
BASE_EXCLUDE_FIELDS = {PARAM_METADATA}

@@ -361,16 +362,28 @@ def prepare_provider_arguments(
# Start with base arguments
arguments = base_args.copy()

# Use provided exclude_fields or fall back to base exclusions
exclude_fields = exclude_fields or self.BASE_EXCLUDE_FIELDS.copy()
# Combine base exclusions with provider-specific exclusions
final_exclude_fields = self.BASE_EXCLUDE_FIELDS.copy()
if exclude_fields:
final_exclude_fields.update(exclude_fields)

# Add all fields from params that aren't explicitly excluded
params_dict = request_params.model_dump(exclude=exclude_fields)
# Note: Pydantic v2 model_dump defaults to exclude_unset=False; pass
# exclude_unset=True if only explicitly-set fields should be forwarded.
params_dict = request_params.model_dump(exclude=final_exclude_fields)

for key, value in params_dict.items():
# Only add the value if it is not None and the key is not already in base_args
# (base_args take precedence). If None is a meaningful value for a provider,
# this logic may need adjustment.
if value is not None and key not in arguments:
arguments[key] = value
elif value is not None and key in arguments and arguments[key] is None:
# Allow overriding a None in base_args with a set value from params
arguments[key] = value

# Finally, add any metadata fields as a last layer of overrides
# This ensures metadata can override anything previously set if keys conflict.
if request_params.metadata:
arguments.update(request_params.metadata)

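To make the precedence rules above easier to follow, here is a standalone sketch of the same merge logic (not the actual method; the function name and sample values are invented): base_args wins over request params unless its value is None, and metadata is applied last as an override.

```python
# Sketch of the argument-merging precedence used in prepare_provider_arguments.
from typing import Any


def merge_provider_arguments(
    base_args: dict[str, Any],
    params_dict: dict[str, Any],
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    arguments = base_args.copy()
    for key, value in params_dict.items():
        # Skip unset values; let params fill gaps or replace a None placeholder.
        if value is not None and (key not in arguments or arguments[key] is None):
            arguments[key] = value
    if metadata:
        # Metadata is the final layer of overrides.
        arguments.update(metadata)
    return arguments


merged = merge_provider_arguments(
    base_args={"model": "claude-3-haiku", "max_tokens": None},
    params_dict={"max_tokens": 2048, "temperature": None},
    metadata={"top_p": 0.9},
)
print(merged)  # {'model': 'claude-3-haiku', 'max_tokens': 2048, 'top_p': 0.9}
```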
102 changes: 87 additions & 15 deletions src/mcp_agent/llm/providers/augmented_llm_anthropic.py
@@ -117,7 +117,68 @@ async def _anthropic_completion(
# if use_history is True
messages.extend(self.history.get(include_completion_history=params.use_history))

messages.append(message_param)
messages.append(message_param) # message_param is the current user turn

# Prepare for caching based on provider-specific cache_mode
apply_cache_to_system_prompt = False
messages_to_cache_indices: List[int] = []

if self.context.config and self.context.config.anthropic:
cache_mode = self.context.config.anthropic.cache_mode
self.logger.debug(f"Anthropic cache_mode: {cache_mode}")

if cache_mode == "auto":
apply_cache_to_system_prompt = True # Cache system prompt
if messages: # If there are any messages
# Cache the last 3 messages
messages_to_cache_indices.extend(
range(max(0, len(messages) - 3), len(messages))
)
self.logger.debug(
f"Auto mode: Caching system prompt (if present) and last three messages at indices: {messages_to_cache_indices}"
)
elif cache_mode == "prompt":
# Find the first user message in the fully constructed messages list
for idx, msg in enumerate(messages):
if isinstance(msg, dict) and msg.get("role") == "user":
messages_to_cache_indices.append(idx)
self.logger.debug(
f"Prompt mode: Caching first user message in constructed prompt at index: {idx}"
)
break
elif cache_mode == "off":
self.logger.debug("Anthropic cache_mode is 'off'. No caching will be applied.")
else: # Should not happen due to Literal validation
self.logger.warning(
f"Unknown Anthropic cache_mode: {cache_mode}. No caching will be applied."
)
else:
self.logger.debug("Anthropic settings not found. No caching will be applied.")

# Apply cache_control to selected messages
for msg_idx in messages_to_cache_indices:
message_to_cache = messages[msg_idx]
if (
isinstance(message_to_cache, dict)
and "content" in message_to_cache
and isinstance(message_to_cache["content"], list)
and message_to_cache["content"]
):
# Apply to the last content block of the message
last_content_block = message_to_cache["content"][-1]
if isinstance(last_content_block, dict):
self.logger.debug(
f"Applying cache_control to last content block of message at index {msg_idx}."
)
last_content_block["cache_control"] = {"type": "ephemeral"}
else:
self.logger.warning(
f"Could not apply cache_control to message at index {msg_idx}: Last content block is not a dictionary."
)
else:
self.logger.warning(
f"Could not apply cache_control to message at index {msg_idx}: Invalid message structure or no content."
)

tool_list: ListToolsResult = await self.aggregator.list_tools()
available_tools: List[ToolParam] = [
@@ -144,6 +205,20 @@ async def _anthropic_completion(
"tools": available_tools,
}

# Apply cache_control to system prompt for "auto" mode
if apply_cache_to_system_prompt and base_args["system"]:
if isinstance(base_args["system"], str):
base_args["system"] = [
{
"type": "text",
"text": base_args["system"],
"cache_control": {"type": "ephemeral"},
}
]
self.logger.debug(
"Applying cache_control to system prompt by wrapping it in a list of content blocks."
)

if params.maxTokens is not None:
base_args["max_tokens"] = params.maxTokens

@@ -178,13 +253,13 @@ async def _anthropic_completion(
# Convert other errors to text response
error_message = f"Error during generation: {error_details}"
response = Message(
id="error", # Required field
model="error", # Required field
id="error",
model="error",
role="assistant",
type="message",
content=[TextBlock(type="text", text=error_message)],
stop_reason="end_turn", # Must be one of the allowed values
usage=Usage(input_tokens=0, output_tokens=0), # Required field
stop_reason="end_turn",
usage=Usage(input_tokens=0, output_tokens=0),
)

self.logger.debug(
@@ -194,7 +269,7 @@ async def _anthropic_completion(

response_as_message = self.convert_message_to_message_param(response)
messages.append(response_as_message)
if response.content[0].type == "text":
if response.content and response.content[0].type == "text":
responses.append(TextContent(type="text", text=response.content[0].text))

if response.stop_reason == "end_turn":
@@ -254,12 +329,13 @@ async def _anthropic_completion(

# Process all tool calls and collect results
tool_results = []
for i, content in enumerate(tool_uses):
tool_name = content.name
tool_args = content.input
tool_use_id = content.id
# Use a separate loop variable for tool enumeration, since 'i' is the outer iteration counter
for tool_idx, content_block in enumerate(tool_uses):
tool_name = content_block.name
tool_args = content_block.input
tool_use_id = content_block.id

if i == 0: # Only show message for first tool use
if tool_idx == 0: # Only show message for first tool use
await self.show_assistant_message(message_text, tool_name)

self.show_tool_call(available_tools, tool_name, tool_args)
@@ -284,11 +360,7 @@ async def _anthropic_completion(
if params.use_history:
# Get current prompt messages
prompt_messages = self.history.get(include_completion_history=False)

# Calculate new conversation messages (excluding prompts)
new_messages = messages[len(prompt_messages) :]

# Update conversation history
self.history.set(new_messages)

self._log_chat_finished(model=model)
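Finally, a standalone sketch of the "auto" selection and tagging performed earlier in this method, isolated for clarity. This is not the production code; it assumes the same content-block message shape used in the diff above.

```python
# Sketch: pick the last three message indices and tag each message's final
# content block with cache_control, mirroring the "auto" logic above.
from typing import Any


def apply_auto_cache(messages: list[dict[str, Any]]) -> list[int]:
    cached_indices = list(range(max(0, len(messages) - 3), len(messages)))
    for idx in cached_indices:
        content = messages[idx].get("content")
        if isinstance(content, list) and content and isinstance(content[-1], dict):
            content[-1]["cache_control"] = {"type": "ephemeral"}
    return cached_indices


msgs = [
    {"role": "user", "content": [{"type": "text", "text": f"turn {i}"}]}
    for i in range(5)
]
print(apply_auto_cache(msgs))  # [2, 3, 4]
```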