diff --git a/src/mcp_agent/core/agent_app.py b/src/mcp_agent/core/agent_app.py index 729b1a1b..35a87c21 100644 --- a/src/mcp_agent/core/agent_app.py +++ b/src/mcp_agent/core/agent_app.py @@ -47,6 +47,10 @@ def __getattr__(self, name: str) -> Agent: if name in self._agents: return self._agents[name] raise AttributeError(f"Agent '{name}' not found") + + def add_agent(self, name: str, agent: Agent) -> None: + """Add a new agent to the app.""" + self._agents[name] = agent async def __call__( self, diff --git a/src/mcp_agent/core/direct_decorators.py b/src/mcp_agent/core/direct_decorators.py index 4b57f08c..897c1ad3 100644 --- a/src/mcp_agent/core/direct_decorators.py +++ b/src/mcp_agent/core/direct_decorators.py @@ -114,6 +114,16 @@ def _decorator_impl( default: Whether to mark this as the default agent **extra_kwargs: Additional agent/workflow-specific parameters """ + # Check for duplicate agent registration + if name in self.agents: + # We can't use the logger here as it's not configured yet, so print a warning + print(f"Warning: Agent '{name}' is already registered. Skipping duplicate definition.") + + # Return a no-op decorator to prevent re-registration + def no_op_decorator(func: AgentCallable[P, R]) -> AgentCallable[P, R]: + return func + + return no_op_decorator def decorator(func: AgentCallable[P, R]) -> DecoratedAgentProtocol[P, R]: is_async = inspect.iscoroutinefunction(func) diff --git a/src/mcp_agent/core/direct_factory.py b/src/mcp_agent/core/direct_factory.py index 649c184f..dde081a1 100644 --- a/src/mcp_agent/core/direct_factory.py +++ b/src/mcp_agent/core/direct_factory.py @@ -3,6 +3,7 @@ Implements type-safe factories with improved error handling. """ +import re from typing import Any, Callable, Dict, Optional, Protocol, TypeVar from mcp_agent.agents.agent import Agent, AgentConfig @@ -15,7 +16,7 @@ from mcp_agent.agents.workflow.router_agent import RouterAgent from mcp_agent.app import MCPApp from mcp_agent.core.agent_types import AgentType -from mcp_agent.core.exceptions import AgentConfigError +from mcp_agent.core.exceptions import AgentConfigError, ServerInitializationError from mcp_agent.core.validation import get_dependencies_groups from mcp_agent.event_progress import ProgressAction from mcp_agent.llm.augmented_llm import RequestParams @@ -125,13 +126,6 @@ def model_factory_func(model=None, request_params=None): # Get all agents of the specified type for name, agent_data in agents_dict.items(): - logger.info( - f"Loaded {name}", - data={ - "progress_action": ProgressAction.LOADED, - "agent_name": name, - }, - ) # Compare type string from config with Enum value if agent_data["type"] == agent_type.value: @@ -156,6 +150,16 @@ def model_factory_func(model=None, request_params=None): api_key=config.api_key ) result_agents[name] = agent + + # Log successful agent creation + logger.info( + f"Loaded {name}", + data={ + "progress_action": ProgressAction.LOADED, + "agent_name": name, + "target": name, + }, + ) elif agent_type == AgentType.CUSTOM: # Get the class to instantiate @@ -175,6 +179,16 @@ def model_factory_func(model=None, request_params=None): api_key=config.api_key ) result_agents[name] = agent + + # Log successful agent creation + logger.info( + f"Loaded {name}", + data={ + "progress_action": ProgressAction.LOADED, + "agent_name": name, + "target": name, + }, + ) elif agent_type == AgentType.ORCHESTRATOR: # Get base params configured with model settings @@ -370,111 +384,78 @@ async def create_agents_in_dependency_order( # Create agent proxies for each group in dependency order for group in dependencies: - # Create basic agents first - # Note: We compare string values from config with the Enum's string value - if AgentType.BASIC.value in [agents_dict[name]["type"] for name in group]: - basic_agents = await create_agents_by_type( - app_instance, - { - name: agents_dict[name] - for name in group - if agents_dict[name]["type"] == AgentType.BASIC.value - }, - AgentType.BASIC, - active_agents, - model_factory_func, - ) - active_agents.update(basic_agents) - - # Create custom agents first - if AgentType.CUSTOM.value in [agents_dict[name]["type"] for name in group]: - basic_agents = await create_agents_by_type( - app_instance, - { - name: agents_dict[name] - for name in group - if agents_dict[name]["type"] == AgentType.CUSTOM.value - }, - AgentType.CUSTOM, - active_agents, - model_factory_func, - ) - active_agents.update(basic_agents) - - # Create parallel agents - if AgentType.PARALLEL.value in [agents_dict[name]["type"] for name in group]: - parallel_agents = await create_agents_by_type( - app_instance, - { - name: agents_dict[name] - for name in group - if agents_dict[name]["type"] == AgentType.PARALLEL.value - }, - AgentType.PARALLEL, - active_agents, - model_factory_func, - ) - active_agents.update(parallel_agents) - - # Create router agents - if AgentType.ROUTER.value in [agents_dict[name]["type"] for name in group]: - router_agents = await create_agents_by_type( - app_instance, - { - name: agents_dict[name] - for name in group - if agents_dict[name]["type"] == AgentType.ROUTER.value - }, - AgentType.ROUTER, - active_agents, - model_factory_func, - ) - active_agents.update(router_agents) - - # Create chain agents - if AgentType.CHAIN.value in [agents_dict[name]["type"] for name in group]: - chain_agents = await create_agents_by_type( - app_instance, - { - name: agents_dict[name] - for name in group - if agents_dict[name]["type"] == AgentType.CHAIN.value - }, - AgentType.CHAIN, - active_agents, - model_factory_func, - ) - active_agents.update(chain_agents) - - # Create evaluator-optimizer agents - if AgentType.EVALUATOR_OPTIMIZER.value in [agents_dict[name]["type"] for name in group]: - evaluator_agents = await create_agents_by_type( - app_instance, - { - name: agents_dict[name] - for name in group - if agents_dict[name]["type"] == AgentType.EVALUATOR_OPTIMIZER.value - }, - AgentType.EVALUATOR_OPTIMIZER, - active_agents, - model_factory_func, - ) - active_agents.update(evaluator_agents) - - # Create orchestrator agents last since they might depend on other agents - if AgentType.ORCHESTRATOR.value in [agents_dict[name]["type"] for name in group]: - orchestrator_agents = await create_agents_by_type( - app_instance, - { - name: agents_dict[name] - for name in group - if agents_dict[name]["type"] == AgentType.ORCHESTRATOR.value - }, - AgentType.ORCHESTRATOR, - active_agents, - model_factory_func, - ) - active_agents.update(orchestrator_agents) + for agent_name in group: + agent_config = agents_dict[agent_name] + + # Check if any required servers for this agent are unavailable + # Only do this check if polling is enabled (mcp_polling_interval > 0) + if (hasattr(app_instance.fast_agent, 'mcp_polling_interval') and + app_instance.fast_agent.mcp_polling_interval is not None and + app_instance.fast_agent.mcp_polling_interval > 0): + + agent_config_obj = agent_config.get("config") + required_servers = agent_config_obj.servers if agent_config_obj else [] + if any(server in app_instance.fast_agent.unavailable_servers for server in required_servers): + logger.warning( + f"Skipping agent '{agent_name}' because it depends on an unavailable server." + ) + app_instance.fast_agent.deactivated_agents[agent_name] = agent_config + continue + + try: + agent_type = AgentType(agent_config["type"]) + # Create agents of a specific type for the current group + created_agents = await create_agents_by_type( + app_instance, + {agent_name: agent_config}, + agent_type, + active_agents, + model_factory_func, + ) + active_agents.update(created_agents) + + except ServerInitializationError as e: + # Only handle graceful deactivation if polling is enabled + if (hasattr(app_instance.fast_agent, 'mcp_polling_interval') and + app_instance.fast_agent.mcp_polling_interval is not None and + app_instance.fast_agent.mcp_polling_interval > 0): + + # The original error (e.g., ConnectError) is the cause + server_name = "unknown" + + # We need to find the server name from the original exception text + # as ServerInitializationError doesn't carry it directly. + match = re.search(r"MCP Server: '([^']*)'", str(e)) + if match: + server_name = match.group(1) + + app_instance.fast_agent.unavailable_servers.add(server_name) + app_instance.fast_agent.deactivated_agents[agent_name] = agent_config + + logger.debug(f"MCP server '{server_name}' is not available. Agent '{agent_name}' will be deactivated.") + + logger.info( + f"Agent '{agent_name}' deactivated", + data={ + "progress_action": ProgressAction.DEACTIVATED, + "agent_name": agent_name, + }, + ) + else: + # If polling is not enabled, let the error propagate normally (old behavior) + raise + + except Exception as e: + # Only handle graceful deactivation if polling is enabled + if (hasattr(app_instance.fast_agent, 'mcp_polling_interval') and + app_instance.fast_agent.mcp_polling_interval is not None and + app_instance.fast_agent.mcp_polling_interval > 0): + + logger.error(f"Failed to create agent '{agent_name}': {e}") + app_instance.fast_agent.deactivated_agents[agent_name] = agent_config + else: + # If polling is not enabled, let the error propagate normally (old behavior) + raise return active_agents diff --git a/src/mcp_agent/core/fastagent.py b/src/mcp_agent/core/fastagent.py index e2abbbe3..df2dfe09 100644 --- a/src/mcp_agent/core/fastagent.py +++ b/src/mcp_agent/core/fastagent.py @@ -6,14 +6,14 @@ import argparse import asyncio +import re import sys from contextlib import asynccontextmanager from importlib.metadata import version as get_version from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar +from typing import Any, Callable, Dict, Optional, TypeVar import yaml -from opentelemetry import trace from mcp_agent import config from mcp_agent.app import MCPApp @@ -60,12 +60,8 @@ validate_server_references, validate_workflow_references, ) +from mcp_agent.event_progress import ProgressAction from mcp_agent.logging.logger import get_logger -from mcp_agent.mcp.prompts.prompt_load import load_prompt_multipart - -if TYPE_CHECKING: - from mcp_agent.agents.agent import Agent - from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart F = TypeVar("F", bound=Callable[..., Any]) # For decorated functions logger = get_logger(__name__) @@ -73,7 +69,7 @@ class FastAgent: """ - A simplified FastAgent implementation that directly creates Agent instances + A simplified FastAgent implementation that directly creates agent instances without using proxies. """ @@ -84,6 +80,7 @@ def __init__( ignore_unknown_args: bool = False, parse_cli_args: bool = True, quiet: bool = False, # Add quiet parameter + mcp_polling_interval: int = None, # Add MCP polling interval parameter ) -> None: """ Initialize the fast-agent application. @@ -97,9 +94,11 @@ def __init__( Set to False when embedding FastAgent in another framework (like FastAPI/Uvicorn) that handles its own arguments. quiet: If True, disable progress display, tool and message logging for cleaner output + mcp_polling_interval: Interval in seconds for checking unavailable MCP servers (default: None - no polling) """ self.args = argparse.Namespace() # Initialize args always - self._programmatic_quiet = quiet # Store the programmatic quiet setting + self._programmatic_quiet = quiet + self.mcp_polling_interval = mcp_polling_interval # Store the programmatic quiet setting # --- Wrap argument parsing logic --- if parse_cli_args: @@ -196,11 +195,20 @@ def __init__( self.config["logger"]["show_chat"] = False self.config["logger"]["show_tools"] = False + # Apply CLI quiet flag to config if it was set + if hasattr(self.args, 'quiet') and self.args.quiet and hasattr(self, "config"): + if "logger" not in self.config: + self.config["logger"] = {} + self.config["logger"]["progress_display"] = False + self.config["logger"]["show_chat"] = False + self.config["logger"]["show_tools"] = False + # Create the app with our local settings self.app = MCPApp( name=name, settings=config.Settings(**self.config) if hasattr(self, "config") else None, ) + self.app.fast_agent = self # Stop progress display immediately if quiet mode is requested if self._programmatic_quiet: @@ -208,6 +216,12 @@ def __init__( progress_display.stop() + # Also stop progress display if CLI quiet flag is set + if hasattr(self.args, 'quiet') and self.args.quiet: + from mcp_agent.progress_display import progress_display + + progress_display.stop() + except yaml.parser.ParserError as e: handle_error( e, @@ -219,6 +233,12 @@ def __init__( # Dictionary to store agent configurations from decorators self.agents: Dict[str, Dict[str, Any]] = {} + # Dictionary to store deactivated agents that failed to load + self.deactivated_agents: Dict[str, Dict[str, Any]] = {} + + # Set to store unavailable server names + self.unavailable_servers: set[str] = set() + def _load_config(self) -> None: """Load configuration from YAML file including secrets using get_settings but without relying on the global cache.""" @@ -257,199 +277,408 @@ def context(self) -> Context: @asynccontextmanager async def run(self): """ - Context manager for running the application. - Initializes all registered agents. + Run the application context, creating agents and handling their lifecycle. + This simplified version directly creates agent instances. """ - active_agents: Dict[str, Agent] = {} + polling_task = None + active_agents = {} had_error = False - await self.app.initialize() - - # Handle quiet mode and CLI model override safely - # Define these *before* they are used, checking if self.args exists and has the attributes - quiet_mode = hasattr(self.args, "quiet") and self.args.quiet - cli_model_override = ( - self.args.model if hasattr(self.args, "model") and self.args.model else None - ) # Define cli_model_override here - tracer = trace.get_tracer(__name__) - with tracer.start_as_current_span(self.name): - try: - async with self.app.run(): - # Apply quiet mode if requested - if ( - quiet_mode - and hasattr(self.app.context, "config") - and hasattr(self.app.context.config, "logger") - ): - # Update our app's config directly - self.app.context.config.logger.progress_display = False - self.app.context.config.logger.show_chat = False - self.app.context.config.logger.show_tools = False - - # Directly disable the progress display singleton - from mcp_agent.progress_display import progress_display - - progress_display.stop() - - # Pre-flight validation - if 0 == len(self.agents): - raise AgentConfigError( - "No agents defined. Please define at least one agent." + + try: + async with self.app.run() as running_app: + # Store the running app instance so agents can access it + self.app = running_app + + # Add initial polling progress entry after app runs to control display order + # This makes it appear 2nd (after FastAgent process, before agents) + if self.mcp_polling_interval is None or self.mcp_polling_interval <= 0: + details = "Disabled" + progress_action = ProgressAction.DEACTIVATED + elif self.unavailable_servers: + details = f"{len(self.unavailable_servers)} unavailable" + progress_action = ProgressAction.READY + else: + details = "All servers online" + progress_action = ProgressAction.READY + + logger.info( + "Polling MCP Servers", + data={ + "progress_action": progress_action, + "agent_name": "MCP Server Polling", + "target": "Polling MCP Servers", + "details": details + } + ) + + # Pre-flight validation + if 0 == len(self.agents): + raise AgentConfigError( + "No agents defined. Please define at least one agent." + ) + validate_server_references(self.context, self.agents) + validate_workflow_references(self.agents) + + # Define a model factory function that can be passed to agent creation + def model_factory_func(model=None, request_params=None): + return get_model_factory( + self.context, + model=model, + request_params=request_params, + default_model=self.config.get("default_model"), + cli_model=self.args.model if hasattr(self.args, "model") else None, + ) + + # Create agents in dependency order + active_agents = await create_agents_in_dependency_order( + app_instance=self.app, + agents_dict=self.agents, + model_factory_func=model_factory_func, + ) + + # After attempting to load all agents, validate provider keys for active agents + validate_provider_keys_post_creation(active_agents) + + # Create the agent app with the successfully created agents + agent_app = AgentApp(agents=active_agents) + + # Handle CLI arguments if they were provided + if hasattr(self, 'args'): + # Handle --server argument (start server mode) + if hasattr(self.args, 'server') and self.args.server: + from mcp_agent.mcp_server import AgentMCPServer + + # Create and start the MCP server + server = AgentMCPServer( + agent_app=agent_app, + server_name=self.name, + server_description=f"MCP Server for {self.name}", ) - validate_server_references(self.context, self.agents) - validate_workflow_references(self.agents) - - # Get a model factory function - # Now cli_model_override is guaranteed to be defined - def model_factory_func(model=None, request_params=None): - return get_model_factory( - self.context, - model=model, - request_params=request_params, - cli_model=cli_model_override, # Use the variable defined above + + # Run the server with the specified transport + await server.run_async( + transport=self.args.transport, + host=self.args.host, + port=self.args.port, ) - - # Create all agents in dependency order - active_agents = await create_agents_in_dependency_order( - self.app, - self.agents, - model_factory_func, - ) + sys.exit(0) - # Validate API keys after agent creation - validate_provider_keys_post_creation(active_agents) - - # Create a wrapper with all agents for simplified access - wrapper = AgentApp(active_agents) - - # Handle command line options that should be processed after agent initialization - - # Handle --server option - # Check if parse_cli_args was True before checking self.args.server - if hasattr(self.args, "server") and self.args.server: - try: - # Print info message if not in quiet mode - if not quiet_mode: - print(f"Starting FastAgent '{self.name}' in server mode") - print(f"Transport: {self.args.transport}") - if self.args.transport == "sse": - print(f"Listening on {self.args.host}:{self.args.port}") - print("Press Ctrl+C to stop") - - # Create the MCP server - from mcp_agent.mcp_server import AgentMCPServer - - mcp_server = AgentMCPServer( - agent_app=wrapper, - server_name=f"{self.name}-MCP-Server", - ) - - # Run the server directly (this is a blocking call) - await mcp_server.run_async( - transport=self.args.transport, - host=self.args.host, - port=self.args.port, - ) - except KeyboardInterrupt: - if not quiet_mode: - print("\nServer stopped by user (Ctrl+C)") - except Exception as e: - if not quiet_mode: - import traceback - - traceback.print_exc() - print(f"\nServer stopped with error: {e}") - - # Exit after server shutdown - raise SystemExit(0) - - # Handle direct message sending if --message is provided - if hasattr(self.args, "message") and self.args.message: - agent_name = self.args.agent - message = self.args.message - - if agent_name not in active_agents: + # Handle --message argument + if hasattr(self.args, 'message') and self.args.message: + agent_name = self.args.agent if hasattr(self.args, 'agent') and self.args.agent != "default" else None + + # Validate agent exists if specific agent was requested + if agent_name and agent_name not in active_agents: available_agents = ", ".join(active_agents.keys()) - print( - f"\n\nError: Agent '{agent_name}' not found. Available agents: {available_agents}" - ) - raise SystemExit(1) - + print(f"\n\nError: Agent '{agent_name}' not found. Available agents: {available_agents}") + sys.exit(1) + try: - # Get response from the agent - agent = active_agents[agent_name] - response = await agent.send(message) - - # In quiet mode, just print the raw response - # The chat display should already be turned off by the configuration - if self.args.quiet: + response = await agent_app.send(self.args.message, agent_name=agent_name) + if hasattr(self.args, 'quiet') and self.args.quiet: print(f"{response}") - - raise SystemExit(0) + sys.exit(0) except Exception as e: print(f"\n\nError sending message to agent '{agent_name}': {str(e)}") - raise SystemExit(1) - - if hasattr(self.args, "prompt_file") and self.args.prompt_file: - agent_name = self.args.agent - prompt: List[PromptMessageMultipart] = load_prompt_multipart( - Path(self.args.prompt_file) - ) - if agent_name not in active_agents: + sys.exit(1) + + # Handle --prompt-file argument + if hasattr(self.args, 'prompt_file') and self.args.prompt_file: + from mcp_agent.mcp.prompts.prompt_load import load_prompt_multipart + + agent_name = self.args.agent if hasattr(self.args, 'agent') and self.args.agent != "default" else None + + # Validate agent exists if specific agent was requested + if agent_name and agent_name not in active_agents: available_agents = ", ".join(active_agents.keys()) - print( - f"\n\nError: Agent '{agent_name}' not found. Available agents: {available_agents}" - ) - raise SystemExit(1) - + print(f"\n\nError: Agent '{agent_name}' not found. Available agents: {available_agents}") + sys.exit(1) + try: - # Get response from the agent - agent = active_agents[agent_name] - response = await agent.generate(prompt) - - # In quiet mode, just print the raw response - # The chat display should already be turned off by the configuration - if self.args.quiet: + prompt = load_prompt_multipart(Path(self.args.prompt_file)) + response = await agent_app._agent(agent_name).generate(prompt) + if hasattr(self.args, 'quiet') and self.args.quiet: print(f"{response.last_text()}") - - raise SystemExit(0) + sys.exit(0) except Exception as e: print(f"\n\nError sending message to agent '{agent_name}': {str(e)}") - raise SystemExit(1) - - yield wrapper - - except PromptExitError as e: - # User requested exit - not an error, show usage report - self._handle_error(e) - raise SystemExit(0) - except ( - ServerConfigError, - ProviderKeyError, - AgentConfigError, - ServerInitializationError, - ModelConfigError, - CircularDependencyError, - ) as e: - had_error = True - self._handle_error(e) - raise SystemExit(1) - - finally: - # Print usage report before cleanup (show for user exits too) - if active_agents and not had_error: - self._print_usage_report(active_agents) - - # Clean up any active agents (always cleanup, even on errors) - if active_agents: - for agent in active_agents.values(): - try: - await agent.shutdown() - except Exception: - pass + sys.exit(1) + + # Start the background polling task to reactivate agents (only if polling is enabled) + polling_task = None + if self.mcp_polling_interval is not None and self.mcp_polling_interval > 0: + polling_task = asyncio.create_task(self._poll_and_reactivate_servers(agent_app)) + + yield agent_app + + except PromptExitError as e: + # User requested exit - not an error, show usage report + self._handle_error(e) + raise SystemExit(0) + except ( + ServerConfigError, + ProviderKeyError, + AgentConfigError, + ServerInitializationError, + ModelConfigError, + CircularDependencyError, + ) as e: + had_error = True + self._handle_error(e) + raise SystemExit(1) + + finally: + if polling_task: + polling_task.cancel() + try: + await polling_task + except asyncio.CancelledError: + logger.info("Agent reactivation polling task cancelled.") + + # Print usage report before cleanup (show for user exits too) + if active_agents and not had_error: + self._print_usage_report(active_agents) + + # Clean up any active agents (always cleanup, even on errors) + if active_agents: + for agent in active_agents.values(): + try: + await agent.shutdown() + except Exception: + pass + + logger.info("FastAgent run context finished.") + + async def _poll_and_reactivate_servers(self, agent_app: AgentApp): + """ + Periodically poll ALL MCP servers and set agent status based on server availability. + """ + while True: + await asyncio.sleep(self.mcp_polling_interval) # Poll at configurable interval + + # Get all servers used by all agents (active and deactivated) + all_servers = set() + + # Get servers from active agents + for agent in agent_app._agents.values(): + if hasattr(agent, 'server_names'): + all_servers.update(agent.server_names) + + # Get servers from deactivated agents + for agent_config in self.deactivated_agents.values(): + if hasattr(agent_config, 'servers'): + all_servers.update(agent_config.servers) + elif isinstance(agent_config, dict) and 'config' in agent_config: + config_obj = agent_config['config'] + if hasattr(config_obj, 'servers'): + all_servers.update(config_obj.servers) + + if not all_servers: + # No servers to poll + logger.info( + "Polling MCP Servers", + data={ + "progress_action": ProgressAction.READY, + "agent_name": "MCP Server Polling", + "target": "Polling MCP Servers", + "details": "No servers to monitor" + } + ) + continue + + # Update progress to show we're actively polling + logger.info( + "Polling MCP Servers", + data={ + "progress_action": ProgressAction.RUNNING, + "agent_name": "MCP Server Polling", + "target": "Polling MCP Servers", + "details": f"Checking {len(all_servers)} servers" + } + ) + + # Record start time for minimum display duration + polling_start_time = asyncio.get_event_loop().time() + + # Check health of all servers by directly testing connectivity + server_status = {} # server_name -> is_healthy (bool) + + for server_name in all_servers: + try: + # Test server connectivity directly + async with asyncio.timeout(10): # 10 second timeout + # Import here to avoid circular imports + from mcp_agent.mcp.gen_client import gen_client + from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession + + # Create a temporary connection to test server health + async with gen_client( + server_name, + server_registry=self.context.server_registry, + client_session_factory=MCPAgentClientSession + ) as client: + # Try to make a simple request to test connectivity + await client.list_tools() + server_status[server_name] = True + logger.debug(f"Server '{server_name}' health check: HEALTHY") + + except Exception as e: + server_status[server_name] = False + logger.debug(f"Server '{server_name}' health check: UNHEALTHY ({e})") + + # Update unavailable_servers set based on current server status + self.unavailable_servers.clear() + for server_name, is_healthy in server_status.items(): + if not is_healthy: + self.unavailable_servers.add(server_name) + + # Update agent status based on server availability + await self._update_agent_status_from_servers(server_status, agent_app) + + # Ensure minimum 1 second display time for "Running" status + elapsed_time = asyncio.get_event_loop().time() - polling_start_time + if elapsed_time < 1.0: + await asyncio.sleep(1.0 - elapsed_time) + + # Update progress to show polling cycle is complete + offline_servers = [name for name, healthy in server_status.items() if not healthy] + if offline_servers: + logger.info( + "Polling MCP Servers", + data={ + "progress_action": ProgressAction.READY, + "agent_name": "MCP Server Polling", + "target": "Polling MCP Servers", + "details": f"Waiting ({len(offline_servers)} offline: {', '.join(offline_servers)})" + } + ) + else: + logger.info( + "Polling MCP Servers", + data={ + "progress_action": ProgressAction.READY, + "agent_name": "MCP Server Polling", + "target": "Polling MCP Servers", + "details": "All servers online" + } + ) + + async def _update_agent_status_from_servers(self, server_status: Dict[str, bool], agent_app: AgentApp): + """ + Simple logic: + - If server is offline and agent is active -> deactivate agent + - If server is online and agent is deactivated -> reactivate agent + + Args: + server_status: Dictionary mapping server_name -> is_healthy (bool) + agent_app: The agent application instance + """ + # Check active agents - deactivate if any required server is offline + for agent_name, agent in list(agent_app._agents.items()): + if hasattr(agent, 'server_names'): + required_servers = agent.server_names + # If any required server is offline, deactivate the agent + if any(not server_status.get(server_name, False) for server_name in required_servers): + offline_servers = [s for s in required_servers if not server_status.get(s, False)] + + # Get agent config in the same format as stored in self.agents + # We need to reconstruct the agent configuration dictionary + if agent_name in self.agents: + # Use the original agent configuration from self.agents + agent_config = self.agents[agent_name] + else: + # Fallback: create a minimal config structure + agent_config = { + "config": getattr(agent, 'config', {}), + "type": "basic", # Default type + } + + # Move to deactivated agents + self.deactivated_agents[agent_name] = agent_config + + # Remove from active agents + del agent_app._agents[agent_name] + + logger.debug(f"Agent '{agent_name}' deactivated due to offline servers: {', '.join(offline_servers)}") + + # Log progress update + logger.info( + f"Agent '{agent_name}' deactivated", + data={ + "progress_action": ProgressAction.DEACTIVATED, + "agent_name": agent_name, + }, + ) + + # Check deactivated agents - reactivate if all required servers are online + for agent_name, agent_config in list(self.deactivated_agents.items()): + # Get required servers for this agent + required_servers = [] + if hasattr(agent_config, 'servers'): + required_servers = agent_config.servers + elif isinstance(agent_config, dict) and 'config' in agent_config: + config_obj = agent_config['config'] + if hasattr(config_obj, 'servers'): + required_servers = config_obj.servers + + # If all required servers are online, reactivate the agent + if required_servers and all(server_status.get(server_name, False) for server_name in required_servers): + logger.debug(f"All required servers available for agent '{agent_name}', attempting reactivation") + await self._reactivate_agent(agent_name, agent_config, agent_app) + + async def _reactivate_agent(self, agent_name: str, agent_config: Dict, agent_app: AgentApp): + """ + Reactivate a single agent that was previously offline. + """ + logger.debug(f"Attempting to reactivate agent: {agent_name}") + try: + # Define a model factory function for reactivation + def model_factory_func(model=None, request_params=None): + return get_model_factory( + self.context, + model=model, + request_params=request_params, + default_model=self.config.get("default_model"), + cli_model=self.args.model if hasattr(self.args, "model") else None, + ) + + # Create the agent + created_agents = await create_agents_in_dependency_order( + app_instance=self.app, + agents_dict={agent_name: agent_config}, + model_factory_func=model_factory_func, + ) + + if agent_name in created_agents: + # Add to the running agent_app + agent_app.add_agent(agent_name, created_agents[agent_name]) + # Remove from deactivated list + del self.deactivated_agents[agent_name] + + logger.debug(f"Agent '{agent_name}' has been successfully reactivated!") + else: + logger.error(f"Failed to create agent '{agent_name}' during reactivation.") + + except ServerInitializationError as e: + # This can happen if the server goes down again right as we try to reactivate + match = re.search(r"MCP Server: '([^']*)'", str(e)) + if match: + server_name = match.group(1) + self.unavailable_servers.add(server_name) + logger.info( + f"Server '{server_name}' became unavailable during reactivation of agent '{agent_name}'. " + "Reactivation will be re-attempted later." + ) + else: + logger.error(f"Could not determine server name from reactivation error: {e}") + + except Exception as e: + logger.error(f"Error reactivating agent '{agent_name}': {e}") + def _handle_error(self, e: Exception, error_type: Optional[str] = None) -> None: """ - Handle errors with consistent formatting and messaging. + Centralized error handling for the application. Args: e: The exception that was raised diff --git a/src/mcp_agent/event_progress.py b/src/mcp_agent/event_progress.py index 223d9cd1..9f824535 100644 --- a/src/mcp_agent/event_progress.py +++ b/src/mcp_agent/event_progress.py @@ -21,10 +21,13 @@ class ProgressAction(str, Enum): READY = "Ready" CALLING_TOOL = "Calling Tool" UPDATED = "Updated" + RUNNING = "Running" FINISHED = "Finished" SHUTDOWN = "Shutdown" AGGREGATOR_INITIALIZED = "Running" FATAL_ERROR = "Error" + ERROR = "Error" + DEACTIVATED = "Deactivated" class ProgressEvent(BaseModel): @@ -69,6 +72,11 @@ def convert_log_event(event: Event) -> Optional[ProgressEvent]: progress_action = event_data.get("progress_action") if not progress_action: return None + + # Filter out MCP server lifecycle events - only show agent events + if "mcp_connection_manager" in event.namespace and progress_action == ProgressAction.DEACTIVATED: + # Skip MCP server lifecycle events, only show agent deactivation events + return None # Build target string based on the event type. # Progress display is currently [time] [event] --- [target] [details] diff --git a/src/mcp_agent/logging/rich_progress.py b/src/mcp_agent/logging/rich_progress.py index 6ce73fb8..19a8ab7d 100644 --- a/src/mcp_agent/logging/rich_progress.py +++ b/src/mcp_agent/logging/rich_progress.py @@ -82,6 +82,7 @@ def _get_action_style(self, action: ProgressAction) -> str: ProgressAction.SHUTDOWN: "black on red", ProgressAction.AGGREGATOR_INITIALIZED: "bold green", ProgressAction.FATAL_ERROR: "black on red", + ProgressAction.DEACTIVATED: "dim yellow", }.get(action, "white") def update(self, event: ProgressEvent) -> None: @@ -120,6 +121,7 @@ def update(self, event: ProgressEvent) -> None: event.action == ProgressAction.INITIALIZED or event.action == ProgressAction.READY or event.action == ProgressAction.LOADED + or event.action == ProgressAction.DEACTIVATED ): self._progress.update(task_id, completed=100, total=100) elif event.action == ProgressAction.FINISHED: diff --git a/src/mcp_agent/mcp/mcp_aggregator.py b/src/mcp_agent/mcp/mcp_aggregator.py index 42d10063..ee06de9f 100644 --- a/src/mcp_agent/mcp/mcp_aggregator.py +++ b/src/mcp_agent/mcp/mcp_aggregator.py @@ -345,7 +345,8 @@ def create_session(read_stream, write_stream, read_timeout, **kwargs): for result in results: if isinstance(result, BaseException): logger.error(f"Error loading server data: {result}") - continue + # Re-raise the exception to propagate it up to the agent creation logic + raise result server_name, tools, prompts = result @@ -470,7 +471,7 @@ async def _execute_on_server( error_factory: Callable[[str], R] = None, ) -> R: """ - Generic method to execute operations on a specific server. + Generic method to execute operations on a specific server with retry logic. Args: server_name: Name of the server to execute the operation on @@ -483,49 +484,77 @@ async def _execute_on_server( Returns: Result from the operation or an error result """ - + import asyncio + async def try_execute(client: ClientSession): try: method = getattr(client, method_name) return await method(**method_args) except Exception as e: - error_msg = ( - f"Failed to {method_name} '{operation_name}' on server '{server_name}': {e}" - ) - logger.error(error_msg) - if error_factory: - return error_factory(error_msg) + # Re-raise the original exception to be handled by retry logic + raise e + + # Retry logic with exponential backoff + max_retries = 3 + base_delay = 0.5 # Start with 0.5 seconds + + for attempt in range(max_retries): + try: + if self.connection_persistence: + server_connection = await self._persistent_connection_manager.get_server( + server_name, client_session_factory=MCPAgentClientSession + ) + return await try_execute(server_connection.session) else: - # Re-raise the original exception to propagate it - raise e - - if self.connection_persistence: - server_connection = await self._persistent_connection_manager.get_server( - server_name, client_session_factory=MCPAgentClientSession - ) - return await try_execute(server_connection.session) - else: - logger.debug( - f"Creating temporary connection to server: {server_name}", - data={ - "progress_action": ProgressAction.STARTING, - "server_name": server_name, - "agent_name": self.agent_name, - }, - ) - async with gen_client( - server_name, server_registry=self.context.server_registry - ) as client: - result = await try_execute(client) - logger.debug( - f"Closing temporary connection to server: {server_name}", - data={ - "progress_action": ProgressAction.SHUTDOWN, - "server_name": server_name, - "agent_name": self.agent_name, - }, - ) - return result + logger.debug( + f"Creating temporary connection to server: {server_name}", + data={ + "progress_action": ProgressAction.STARTING, + "server_name": server_name, + "agent_name": self.agent_name, + }, + ) + async with gen_client( + server_name, server_registry=self.context.server_registry + ) as client: + result = await try_execute(client) + logger.debug( + f"Closing temporary connection to server: {server_name}", + data={ + "progress_action": ProgressAction.SHUTDOWN, + "server_name": server_name, + "agent_name": self.agent_name, + }, + ) + return result + + except Exception as e: + is_last_attempt = (attempt == max_retries - 1) + + if is_last_attempt: + # Final failure - log and return error + error_msg = ( + f"Failed to {method_name} '{operation_name}' on server '{server_name}' after {max_retries} attempts: {e}" + ) + logger.debug(error_msg) + + # Check if this is a connection-related failure that should trigger agent deactivation + if self._is_connection_failure(e): + await self._handle_server_failure(server_name) + + if error_factory: + return error_factory(error_msg) + else: + # Re-raise the original exception to propagate it + raise e + else: + # Retry attempt - log and wait + delay = base_delay * (2 ** attempt) # Exponential backoff + logger.debug( + f"Attempt {attempt + 1} failed for {method_name} '{operation_name}' on server '{server_name}': {e}. " + f"Retrying in {delay:.1f}s..." + ) + await asyncio.sleep(delay) async def _parse_resource_name(self, name: str, resource_type: str) -> tuple[str, str]: """ @@ -1201,3 +1230,99 @@ async def list_resources(self, server_name: str | None = None) -> Dict[str, List logger.error(f"Error fetching resources from {s_name}: {e}") return results + + def _is_connection_failure(self, exception: Exception) -> bool: + """ + Check if an exception indicates a connection failure that should trigger agent deactivation. + + Args: + exception: The exception to check + + Returns: + True if this is a connection failure, False otherwise + """ + # Check for common connection-related exceptions + import httpx + + connection_error_types = ( + ConnectionError, + ConnectionRefusedError, + ConnectionResetError, + ConnectionAbortedError, + OSError, + httpx.ConnectError, + httpx.TimeoutException, + httpx.NetworkError, + ) + + # Check if it's a direct connection error + if isinstance(exception, connection_error_types): + return True + + # Check for connection-related error messages + error_msg = str(exception).lower() + connection_keywords = [ + "connection", + "timeout", + "network", + "refused", + "reset", + "aborted", + "peer closed", + "incomplete chunked read", + "server unavailable", + "transport error" + ] + + return any(keyword in error_msg for keyword in connection_keywords) + + async def _handle_server_failure(self, server_name: str) -> None: + """ + Handle server failure by notifying the FastAgent system to deactivate related agents. + + Args: + server_name: Name of the server that failed + """ + try: + # Try to get the FastAgent instance from context to report server failure + from mcp_agent.context import get_current_context + + context = get_current_context() + if context and hasattr(context, 'fast_agent'): + fast_agent = context.fast_agent + if hasattr(fast_agent, 'unavailable_servers'): + fast_agent.unavailable_servers.add(server_name) + + # Find agents that use this server and deactivate them + if hasattr(fast_agent, 'app') and hasattr(fast_agent.app, '_agents'): + agents_to_deactivate = [] + for agent_name, agent in fast_agent.app._agents.items(): + if hasattr(agent, 'server_names') and server_name in agent.server_names: + agents_to_deactivate.append(agent_name) + + # Move agents to deactivated list + for agent_name in agents_to_deactivate: + if agent_name in fast_agent.app._agents: + # Get agent config from active agents + agent_config = getattr(fast_agent.app._agents[agent_name], 'config', {}) + + # Move to deactivated agents + if hasattr(fast_agent, 'deactivated_agents'): + fast_agent.deactivated_agents[agent_name] = agent_config + + # Remove from active agents + del fast_agent.app._agents[agent_name] + + logger.debug(f"Agent '{agent_name}' deactivated due to server '{server_name}' failure during runtime.") + + # Log progress update + logger.info( + f"Agent '{agent_name}' deactivated", + data={ + "progress_action": ProgressAction.DEACTIVATED, + "agent_name": agent_name, + }, + ) + + except Exception as e: + logger.error(f"Error handling server failure for '{server_name}': {e}") diff --git a/src/mcp_agent/mcp/mcp_connection_manager.py b/src/mcp_agent/mcp/mcp_connection_manager.py index 3cb531e9..723dcb47 100644 --- a/src/mcp_agent/mcp/mcp_connection_manager.py +++ b/src/mcp_agent/mcp/mcp_connection_manager.py @@ -3,7 +3,6 @@ """ import asyncio -import traceback from datetime import timedelta from typing import ( TYPE_CHECKING, @@ -32,7 +31,6 @@ from mcp_agent.event_progress import ProgressAction from mcp_agent.logging.logger import get_logger from mcp_agent.mcp.logger_textio import get_stderr_handler -from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession if TYPE_CHECKING: from mcp_agent.context import Context @@ -104,7 +102,7 @@ def __init__( # Track error state self._error_occurred = False - self._error_message = None + self._original_exception: Optional[Exception] = None def is_healthy(self) -> bool: """Check if the server connection is healthy and ready to use.""" @@ -191,12 +189,13 @@ async def _server_lifecycle_task(server_conn: ServerConnection) -> None: await server_conn.wait_for_shutdown_request() except HTTPStatusError as http_exc: - logger.error( - f"{server_name}: Lifecycle task encountered HTTP error: {http_exc}", - exc_info=True, + logger.info( + f"{server_name}: MCP server unavailable", data={ - "progress_action": ProgressAction.FATAL_ERROR, + "progress_action": ProgressAction.DEACTIVATED, "server_name": server_name, + "agent_name": server_name, + "target": server_name, }, ) server_conn._error_occurred = True @@ -205,53 +204,17 @@ async def _server_lifecycle_task(server_conn: ServerConnection) -> None: # No raise - let get_server handle it with a friendly message except Exception as exc: - logger.error( - f"{server_name}: Lifecycle task encountered an error: {exc}", - exc_info=True, + logger.info( + f"{server_name}: MCP server unavailable", data={ - "progress_action": ProgressAction.FATAL_ERROR, + "progress_action": ProgressAction.DEACTIVATED, "server_name": server_name, + "agent_name": server_name, + "target": server_name, }, ) server_conn._error_occurred = True - - if "ExceptionGroup" in type(exc).__name__ and hasattr(exc, "exceptions"): - # Handle ExceptionGroup better by extracting the actual errors - def extract_errors(exception_group): - """Recursively extract meaningful errors from ExceptionGroups""" - messages = [] - for subexc in exception_group.exceptions: - if "ExceptionGroup" in type(subexc).__name__ and hasattr(subexc, "exceptions"): - # Recursively handle nested ExceptionGroups - messages.extend(extract_errors(subexc)) - elif isinstance(subexc, HTTPStatusError): - # Special handling for HTTP errors to make them more user-friendly - messages.append( - f"HTTP Error: {subexc.response.status_code} {subexc.response.reason_phrase} for URL: {subexc.request.url}" - ) - else: - # Show the exception type and message, plus the root cause if available - error_msg = f"{type(subexc).__name__}: {subexc}" - messages.append(error_msg) - - # If there's a root cause, show that too as it's often the most informative - if hasattr(subexc, "__cause__") and subexc.__cause__: - messages.append( - f"Caused by: {type(subexc.__cause__).__name__}: {subexc.__cause__}" - ) - return messages - - error_messages = extract_errors(exc) - # If we didn't extract any meaningful errors, fall back to the original exception - if not error_messages: - error_messages = [f"{type(exc).__name__}: {exc}"] - server_conn._error_message = error_messages - else: - # For regular exceptions, keep the traceback but format it more cleanly - server_conn._error_message = traceback.format_exception(exc) - - # If there's an error, we should also set the event so that - # 'get_server' won't hang + server_conn._original_exception = exc server_conn._initialized_event.set() # No raise - allow graceful exit @@ -405,28 +368,17 @@ async def get_server( # Check if the server is healthy after initialization if not server_conn.is_healthy(): - error_msg = server_conn._error_message or "Unknown error" - - # Format the error message for better display - if isinstance(error_msg, list): - # Join the list with newlines for better readability - formatted_error = "\n".join(error_msg) - else: - formatted_error = str(error_msg) - raise ServerInitializationError( - f"MCP Server: '{server_name}': Failed to initialize - see details. Check fastagent.config.yaml?", - formatted_error, - ) + f"MCP Server: '{server_name}': Failed to initialize - see details. Check fastagent.config.yaml?" + ) from server_conn._original_exception return server_conn async def get_server_capabilities(self, server_name: str) -> ServerCapabilities | None: """Get the capabilities of a specific server.""" - server_conn = await self.get_server( - server_name, client_session_factory=MCPAgentClientSession - ) - return server_conn.server_capabilities if server_conn else None + async with self._lock: + server_conn = self.running_servers.get(server_name) + return server_conn.server_capabilities if server_conn else None async def disconnect_server(self, server_name: str) -> None: """