diff --git a/src/mcp_agent/config.py b/src/mcp_agent/config.py
index 9a5f842b..2d7bad25 100644
--- a/src/mcp_agent/config.py
+++ b/src/mcp_agent/config.py
@@ -113,6 +113,14 @@ class AnthropicSettings(BaseModel):
     base_url: str | None = None

+    cache_mode: Literal["off", "prompt", "auto"] = "off"
+    """
+    Controls how caching is applied for Anthropic models when prompt_caching is enabled globally.
+    - "off": No caching, even if global prompt_caching is true.
+    - "prompt": Caches the first user message of the constructed prompt. Useful for large, static prompts.
+    - "auto": Caches the system prompt and the last three messages. Default behavior if prompt_caching is true.
+    """
+
     model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

diff --git a/src/mcp_agent/llm/augmented_llm.py b/src/mcp_agent/llm/augmented_llm.py
index 6ae9b646..7258af29 100644
--- a/src/mcp_agent/llm/augmented_llm.py
+++ b/src/mcp_agent/llm/augmented_llm.py
@@ -95,6 +95,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
     PARAM_USE_HISTORY = "use_history"
     PARAM_MAX_ITERATIONS = "max_iterations"
     PARAM_TEMPLATE_VARS = "template_vars"
+
     # Base set of fields that should always be excluded
     BASE_EXCLUDE_FIELDS = {PARAM_METADATA}

@@ -361,16 +362,28 @@ def prepare_provider_arguments(
         # Start with base arguments
         arguments = base_args.copy()

-        # Use provided exclude_fields or fall back to base exclusions
-        exclude_fields = exclude_fields or self.BASE_EXCLUDE_FIELDS.copy()
+        # Combine base exclusions with provider-specific exclusions
+        final_exclude_fields = self.BASE_EXCLUDE_FIELDS.copy()
+        if exclude_fields:
+            final_exclude_fields.update(exclude_fields)

         # Add all fields from params that aren't explicitly excluded
-        params_dict = request_params.model_dump(exclude=exclude_fields)
+        # Note: Pydantic v2 model_dump() defaults to exclude_unset=False, so unset
+        # fields are included here; pass exclude_unset=True if only explicitly
+        # set fields should be forwarded.
+        params_dict = request_params.model_dump(exclude=final_exclude_fields)
+
         for key, value in params_dict.items():
+            # base_args take precedence; only set (non-None) values from params are added.
+            # If a provider treats None as a meaningful value, this check needs adjustment.
             if value is not None and key not in arguments:
                 arguments[key] = value
+            elif value is not None and key in arguments and arguments[key] is None:
+                # Allow overriding a None in base_args with a set value from params
+                arguments[key] = value

         # Finally, add any metadata fields as a last layer of overrides
+        # This ensures metadata can override anything previously set if keys conflict.
        if request_params.metadata:
            arguments.update(request_params.metadata)

diff --git a/src/mcp_agent/llm/providers/augmented_llm_anthropic.py b/src/mcp_agent/llm/providers/augmented_llm_anthropic.py
index b719a6cf..133111fd 100644
--- a/src/mcp_agent/llm/providers/augmented_llm_anthropic.py
+++ b/src/mcp_agent/llm/providers/augmented_llm_anthropic.py
@@ -117,7 +117,68 @@ async def _anthropic_completion(
         # if use_history is True
         messages.extend(self.history.get(include_completion_history=params.use_history))

-        messages.append(message_param)
+        messages.append(message_param)  # message_param is the current user turn
+
+        # Prepare for caching based on provider-specific cache_mode
+        apply_cache_to_system_prompt = False
+        messages_to_cache_indices: List[int] = []
+
+        if self.context.config and self.context.config.anthropic:
+            cache_mode = self.context.config.anthropic.cache_mode
+            self.logger.debug(f"Anthropic cache_mode: {cache_mode}")
+
+            if cache_mode == "auto":
+                apply_cache_to_system_prompt = True  # Cache system prompt
+                if messages:  # If there are any messages
+                    # Cache the last 3 messages
+                    messages_to_cache_indices.extend(
+                        range(max(0, len(messages) - 3), len(messages))
+                    )
+                    self.logger.debug(
+                        f"Auto mode: Caching system prompt (if present) and last three messages at indices: {messages_to_cache_indices}"
+                    )
+            elif cache_mode == "prompt":
+                # Find the first user message in the fully constructed messages list
+                for idx, msg in enumerate(messages):
+                    if isinstance(msg, dict) and msg.get("role") == "user":
+                        messages_to_cache_indices.append(idx)
+                        self.logger.debug(
+                            f"Prompt mode: Caching first user message in constructed prompt at index: {idx}"
+                        )
+                        break
+            elif cache_mode == "off":
+                self.logger.debug("Anthropic cache_mode is 'off'. No caching will be applied.")
+            else:  # Should not happen due to Literal validation
+                self.logger.warning(
+                    f"Unknown Anthropic cache_mode: {cache_mode}. No caching will be applied."
+                )
+        else:
+            self.logger.debug("Anthropic settings not found. No caching will be applied.")
+
+        # Apply cache_control to selected messages
+        for msg_idx in messages_to_cache_indices:
+            message_to_cache = messages[msg_idx]
+            if (
+                isinstance(message_to_cache, dict)
+                and "content" in message_to_cache
+                and isinstance(message_to_cache["content"], list)
+                and message_to_cache["content"]
+            ):
+                # Apply to the last content block of the message
+                last_content_block = message_to_cache["content"][-1]
+                if isinstance(last_content_block, dict):
+                    self.logger.debug(
+                        f"Applying cache_control to last content block of message at index {msg_idx}."
+                    )
+                    last_content_block["cache_control"] = {"type": "ephemeral"}
+                else:
+                    self.logger.warning(
+                        f"Could not apply cache_control to message at index {msg_idx}: Last content block is not a dictionary."
+                    )
+            else:
+                self.logger.warning(
+                    f"Could not apply cache_control to message at index {msg_idx}: Invalid message structure or no content."
+                )

         tool_list: ListToolsResult = await self.aggregator.list_tools()
         available_tools: List[ToolParam] = [
@@ -144,6 +205,20 @@ async def _anthropic_completion(
             "tools": available_tools,
         }

+        # Apply cache_control to system prompt for "auto" mode
+        if apply_cache_to_system_prompt and base_args["system"]:
+            if isinstance(base_args["system"], str):
+                base_args["system"] = [
+                    {
+                        "type": "text",
+                        "text": base_args["system"],
+                        "cache_control": {"type": "ephemeral"},
+                    }
+                ]
+                self.logger.debug(
+                    "Applying cache_control to system prompt by wrapping it in a list of content blocks."
+                )
+
         if params.maxTokens is not None:
             base_args["max_tokens"] = params.maxTokens

@@ -178,13 +253,13 @@ async def _anthropic_completion(
                 # Convert other errors to text response
                 error_message = f"Error during generation: {error_details}"
                 response = Message(
-                    id="error",  # Required field
-                    model="error",  # Required field
+                    id="error",
+                    model="error",
                     role="assistant",
                     type="message",
                     content=[TextBlock(type="text", text=error_message)],
-                    stop_reason="end_turn",  # Must be one of the allowed values
-                    usage=Usage(input_tokens=0, output_tokens=0),  # Required field
+                    stop_reason="end_turn",
+                    usage=Usage(input_tokens=0, output_tokens=0),
                 )

             self.logger.debug(
@@ -194,7 +269,7 @@ async def _anthropic_completion(
             response_as_message = self.convert_message_to_message_param(response)
             messages.append(response_as_message)

-            if response.content[0].type == "text":
+            if response.content and response.content[0].type == "text":
                 responses.append(TextContent(type="text", text=response.content[0].text))

             if response.stop_reason == "end_turn":
@@ -254,12 +329,13 @@ async def _anthropic_completion(
             # Process all tool calls and collect results
             tool_results = []

-            for i, content in enumerate(tool_uses):
-                tool_name = content.name
-                tool_args = content.input
-                tool_use_id = content.id
+            # Use a distinct loop variable so an outer loop counter 'i' is not shadowed
+            for tool_idx, content_block in enumerate(tool_uses):
+                tool_name = content_block.name
+                tool_args = content_block.input
+                tool_use_id = content_block.id

-                if i == 0:  # Only show message for first tool use
+                if tool_idx == 0:  # Only show message for first tool use
                     await self.show_assistant_message(message_text, tool_name)

                 self.show_tool_call(available_tools, tool_name, tool_args)
@@ -284,11 +360,7 @@ async def _anthropic_completion(
         if params.use_history:
             # Get current prompt messages
             prompt_messages = self.history.get(include_completion_history=False)
-
-            # Calculate new conversation messages (excluding prompts)
             new_messages = messages[len(prompt_messages) :]
-
-            # Update conversation history
             self.history.set(new_messages)

         self._log_chat_finished(model=model)
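
Note for reviewers: a minimal sketch (not part of the patch) of what the new setting produces. AnthropicSettings is the class extended above, and the ephemeral cache_control marker matches Anthropic's prompt-caching API; the message contents are invented for illustration.

    from mcp_agent.config import AnthropicSettings

    # "prompt" mode: cache the first user message of the constructed prompt.
    settings = AnthropicSettings(cache_mode="prompt")

    message = {
        "role": "user",
        "content": [{"type": "text", "text": "a large, static prompt"}],
    }
    # The provider code marks the last content block of the selected message:
    message["content"][-1]["cache_control"] = {"type": "ephemeral"}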
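
Likewise, a small worked example (message values are stand-ins) of which indices "auto" mode selects; it is simply the tail of the constructed messages list:

    messages = ["m0", "m1", "m2", "m3", "m4"]  # stand-ins for message dicts
    indices = list(range(max(0, len(messages) - 3), len(messages)))
    assert indices == [2, 3, 4]  # the last three messages are cached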
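
Finally, a hedged illustration (all values invented) of the merge precedence prepare_provider_arguments now implements: existing base_args win over params, a None left in base_args can be filled from params, and request_params.metadata is applied last as an override layer.

    base_args = {"model": "example-model", "max_tokens": None}
    params_dict = {"model": "ignored", "max_tokens": 2048, "temperature": None}
    metadata = {"top_p": 0.9}

    arguments = base_args.copy()
    for key, value in params_dict.items():
        if value is not None and key not in arguments:
            arguments[key] = value  # new key supplied by params
        elif value is not None and key in arguments and arguments[key] is None:
            arguments[key] = value  # fill a None left in base_args
    arguments.update(metadata)  # metadata is the last layer of overrides

    assert arguments == {"model": "example-model", "max_tokens": 2048, "top_p": 0.9}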