diff --git a/Makefile b/Makefile index 02154af..0587e38 100644 --- a/Makefile +++ b/Makefile @@ -43,8 +43,10 @@ coverage-html: coverage open htmlcov/index.html || echo "Open htmlcov/index.html in your browser to view the coverage report." .PHONY: container-local +DOCKER_TAG := $(shell git branch --show-current | tr -c '[:alnum:]._-' '-') + container-local: - docker buildx build -t ghcr.io/cisco-foundation-ai/peak-assistant:$(shell git branch --show-current) --load . + docker buildx build -t ghcr.io/cisco-foundation-ai/peak-assistant:$(DOCKER_TAG) --load . .PHONY: container-run @@ -56,4 +58,4 @@ container-run: container-local --mount "type=bind,src=$(PWD)/.env,target=/home/peakassistant/.env" \ --mount "type=bind,src=$(PWD)/mcp_servers.json,target=/home/peakassistant/mcp_servers.json" \ -p "127.0.0.1:8501:8501" \ - ghcr.io/cisco-foundation-ai/peak-assistant:$(shell git branch --show-current) \ No newline at end of file + ghcr.io/cisco-foundation-ai/peak-assistant:$(DOCKER_TAG) diff --git a/evaluations/research-agent-team-eval/evaluator.py b/evaluations/research-agent-team-eval/evaluator.py index d00a4d0..ef90599 100755 --- a/evaluations/research-agent-team-eval/evaluator.py +++ b/evaluations/research-agent-team-eval/evaluator.py @@ -36,6 +36,7 @@ import argparse import asyncio import aiohttp +import ipaddress from typing import Any, Dict, List, Tuple, Optional from collections import defaultdict from dataclasses import dataclass, field @@ -985,8 +986,46 @@ async def evaluate_url_validity_async( urls = re.findall(r"https?://[^\s\)]+", report) + def is_safe_public_url(url: str) -> bool: + """Block private/internal URL targets to reduce SSRF risk.""" + try: + parsed = urlparse(url) + if parsed.scheme not in {"http", "https"}: + return False + if not parsed.hostname: + return False + + hostname = parsed.hostname.strip("[]").lower() + if hostname in {"localhost", "localhost.localdomain"}: + return False + if hostname.endswith(".local"): + return False + + try: + ip = ipaddress.ip_address(hostname) + if ( + ip.is_private + or ip.is_loopback + or ip.is_link_local + or ip.is_multicast + or ip.is_reserved + or ip.is_unspecified + ): + return False + except ValueError: + # Hostname is not a direct IP literal; keep it eligible. + pass + + return True + except Exception: + return False + + safe_urls = [url for url in urls if is_safe_public_url(url)] + blocked_urls = [url for url in urls if not is_safe_public_url(url)] + results = { "total_urls": len(urls), + "blocked_urls": len(blocked_urls), "valid_urls": 0, "invalid_urls": 0, "timeout_urls": 0, @@ -995,19 +1034,19 @@ async def evaluate_url_validity_async( } # Decide which URLs to check - if len(urls) <= 20: + if len(safe_urls) <= 20: # Check all URLs if 20 or fewer - urls_to_check = urls - results["sample_size"] = len(urls) + urls_to_check = safe_urls + results["sample_size"] = len(safe_urls) else: # Random sample of 20 if more than 20 - urls_to_check = random.sample(urls, 20) + urls_to_check = random.sample(safe_urls, 20) results["sample_size"] = 20 async def check_url(session, url): try: async with session.head( - url, timeout=5, allow_redirects=True + url, timeout=5, allow_redirects=False ) as response: if response.status < 400: return url, "valid" @@ -1039,13 +1078,15 @@ async def check_url(session, url): score = 0 # Update feedback to indicate sampling - if len(urls) > 20: - feedback = f"{results['valid_urls']}/{results['sample_size']} URLs valid (random sample from {len(urls)} total)" + if len(safe_urls) > 20: + feedback = f"{results['valid_urls']}/{results['sample_size']} URLs valid (random sample from {len(safe_urls)} eligible)" else: feedback = f"{results['valid_urls']}/{results['sample_size']} URLs valid" if results["broken_links"]: feedback += f", {len(results['broken_links'])} broken" + if blocked_urls: + feedback += f", skipped {len(blocked_urls)} private/internal URL(s)" return MetricResult(score=score, details=results, feedback=feedback) diff --git a/peak_assistant/peak_mcp/__main__.py b/peak_assistant/peak_mcp/__main__.py index 4ca5b54..c73df73 100755 --- a/peak_assistant/peak_mcp/__main__.py +++ b/peak_assistant/peak_mcp/__main__.py @@ -296,7 +296,12 @@ async def hypothesizer( """ try: user_input = "" - result = await async_hypothesizer(user_input, research_document, local_context, local_data_search_results) + result = await async_hypothesizer( + user_input=user_input, + research_document=research_document, + local_data_document=local_data_search_results, + local_context=local_context, + ) return embeddable_object(data=result) except Exception as e: return embeddable_object(data=f"Error during hypothesis generation: {str(e)}") diff --git a/peak_assistant/streamlit/app.py b/peak_assistant/streamlit/app.py index 03d8a3e..801f384 100644 --- a/peak_assistant/streamlit/app.py +++ b/peak_assistant/streamlit/app.py @@ -351,10 +351,10 @@ if st.session_state["last_hypothesis_for_data_discovery"] != current_hypothesis: st.session_state["last_hypothesis_for_data_discovery"] = current_hypothesis # Clear the data sources document to show run button - if "Data Sources_document" in st.session_state: - del st.session_state["Data Sources_document"] - if "Data Sources_messages" in st.session_state: - del st.session_state["Data Sources_messages"] + if "Discovery_document" in st.session_state: + del st.session_state["Discovery_document"] + if "Discovery_messages" in st.session_state: + del st.session_state["Discovery_messages"] peak_assistant_chat( title="Data Discovery", diff --git a/peak_assistant/utils/mcp_config.py b/peak_assistant/utils/mcp_config.py index 590d592..b344d3d 100644 --- a/peak_assistant/utils/mcp_config.py +++ b/peak_assistant/utils/mcp_config.py @@ -957,13 +957,27 @@ async def _connect_stdio_server(self, server_name: str, config: MCPServerConfig, # Create stdio server parameters # Fix: Ensure args is properly formatted for StdioServerParams args_list = config.args or [] + timeout_seconds = 30.0 + try: + timeout_seconds = float(config.timeout) + except (TypeError, ValueError): + logger.warning( + "Invalid timeout value %r for stdio server %s; using default timeout %.1f seconds", + config.timeout, + server_name, + timeout_seconds, + ) + server_params = StdioServerParams( command=config.command, args=args_list, env=env, - read_timeout_seconds=float(config.timeout) + read_timeout_seconds=timeout_seconds + ) + logger.debug( + f"Created StdioServerParams with command: {config.command}, args: {args_list}, " + f"read_timeout_seconds: {timeout_seconds}, env keys: {list(env.keys()) if env else 'None'}" ) - logger.debug(f"Created StdioServerParams with command: {config.command}, args: {args_list}, read_timeout_seconds: {float(config.timeout)}, env keys: {list(env.keys()) if env else 'None'}") # Create workbench # Fix: Pass server_params directly, not as a list diff --git a/tests/unit_tests/test_mcp_stdio_timeout.py b/tests/unit_tests/test_mcp_stdio_timeout.py new file mode 100644 index 0000000..1436936 --- /dev/null +++ b/tests/unit_tests/test_mcp_stdio_timeout.py @@ -0,0 +1,66 @@ +# Copyright (c) 2025 Cisco Systems, Inc. and its affiliates +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +# SPDX-License-Identifier: MIT + +"""Tests for robust timeout handling in MCP stdio connections.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from peak_assistant.utils.mcp_config import MCPClientManager, MCPServerConfig, TransportType + + +@pytest.fixture +def client_manager(): + """Create an MCPClientManager with a lightweight mocked config manager.""" + manager = MagicMock() + manager.user_session_manager = MagicMock() + return MCPClientManager(manager) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("timeout_value", [None, "not-a-number"]) +async def test_connect_stdio_server_uses_default_timeout_for_invalid_values(client_manager, timeout_value): + """Invalid timeout config should not crash stdio server setup.""" + config = MCPServerConfig( + name="test-server", + transport=TransportType.STDIO, + command="echo", + args=["ok"], + timeout=timeout_value, + ) + + captured = {} + + class _FakeWorkbench: + async def __aenter__(self): + return self + + def _capture_workbench(server_params): + captured["read_timeout_seconds"] = server_params.read_timeout_seconds + return _FakeWorkbench() + + with patch("peak_assistant.utils.mcp_config.McpWorkbench", side_effect=_capture_workbench): + result = await client_manager._connect_stdio_server("test-server", config) + + assert result is True + assert captured["read_timeout_seconds"] == 30.0