Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ coverage-html: coverage
open htmlcov/index.html || echo "Open htmlcov/index.html in your browser to view the coverage report."

.PHONY: container-local
DOCKER_TAG := $(shell git branch --show-current | tr -c '[:alnum:]._-' '-')

container-local:
docker buildx build -t ghcr.io/cisco-foundation-ai/peak-assistant:$(shell git branch --show-current) --load .
docker buildx build -t ghcr.io/cisco-foundation-ai/peak-assistant:$(DOCKER_TAG) --load .


.PHONY: container-run
Expand All @@ -56,4 +58,4 @@ container-run: container-local
--mount "type=bind,src=$(PWD)/.env,target=/home/peakassistant/.env" \
--mount "type=bind,src=$(PWD)/mcp_servers.json,target=/home/peakassistant/mcp_servers.json" \
-p "127.0.0.1:8501:8501" \
ghcr.io/cisco-foundation-ai/peak-assistant:$(shell git branch --show-current)
ghcr.io/cisco-foundation-ai/peak-assistant:$(DOCKER_TAG)
55 changes: 48 additions & 7 deletions evaluations/research-agent-team-eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import argparse
import asyncio
import aiohttp
import ipaddress
from typing import Any, Dict, List, Tuple, Optional
from collections import defaultdict
from dataclasses import dataclass, field
Expand Down Expand Up @@ -985,8 +986,46 @@ async def evaluate_url_validity_async(

urls = re.findall(r"https?://[^\s\)]+", report)

def is_safe_public_url(url: str) -> bool:
"""Block private/internal URL targets to reduce SSRF risk."""
try:
parsed = urlparse(url)
if parsed.scheme not in {"http", "https"}:
return False
if not parsed.hostname:
return False

hostname = parsed.hostname.strip("[]").lower()
if hostname in {"localhost", "localhost.localdomain"}:
return False
if hostname.endswith(".local"):
return False

try:
ip = ipaddress.ip_address(hostname)
if (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_multicast
or ip.is_reserved
or ip.is_unspecified
):
return False
except ValueError:
# Hostname is not a direct IP literal; keep it eligible.
pass

return True
except Exception:
return False

safe_urls = [url for url in urls if is_safe_public_url(url)]
blocked_urls = [url for url in urls if not is_safe_public_url(url)]

results = {
"total_urls": len(urls),
"blocked_urls": len(blocked_urls),
"valid_urls": 0,
"invalid_urls": 0,
"timeout_urls": 0,
Expand All @@ -995,19 +1034,19 @@ async def evaluate_url_validity_async(
}

# Decide which URLs to check
if len(urls) <= 20:
if len(safe_urls) <= 20:
# Check all URLs if 20 or fewer
urls_to_check = urls
results["sample_size"] = len(urls)
urls_to_check = safe_urls
results["sample_size"] = len(safe_urls)
Comment on lines +1037 to +1040

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Count blocked URLs in validity score denominator

The URL validity score now samples only safe_urls, so private/internal links are excluded from both sample_size and invalid_urls. In reports that mix blocked links with a few reachable public links, this can produce an artificially high URL-validity score (e.g., 1 valid public URL + many blocked URLs can still score near 100), which skews evaluation results. To preserve metric integrity while keeping SSRF protections, blocked URLs should still contribute to the invalid/denominator side of scoring.

Useful? React with 👍 / 👎.

else:
# Random sample of 20 if more than 20
urls_to_check = random.sample(urls, 20)
urls_to_check = random.sample(safe_urls, 20)
results["sample_size"] = 20

async def check_url(session, url):
try:
async with session.head(
url, timeout=5, allow_redirects=True
url, timeout=5, allow_redirects=False
) as response:
if response.status < 400:
return url, "valid"
Expand Down Expand Up @@ -1039,13 +1078,15 @@ async def check_url(session, url):
score = 0

# Update feedback to indicate sampling
if len(urls) > 20:
feedback = f"{results['valid_urls']}/{results['sample_size']} URLs valid (random sample from {len(urls)} total)"
if len(safe_urls) > 20:
feedback = f"{results['valid_urls']}/{results['sample_size']} URLs valid (random sample from {len(safe_urls)} eligible)"
else:
feedback = f"{results['valid_urls']}/{results['sample_size']} URLs valid"

if results["broken_links"]:
feedback += f", {len(results['broken_links'])} broken"
if blocked_urls:
feedback += f", skipped {len(blocked_urls)} private/internal URL(s)"

return MetricResult(score=score, details=results, feedback=feedback)

Expand Down
7 changes: 6 additions & 1 deletion peak_assistant/peak_mcp/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,12 @@ async def hypothesizer(
"""
try:
user_input = ""
result = await async_hypothesizer(user_input, research_document, local_context, local_data_search_results)
result = await async_hypothesizer(
user_input=user_input,
research_document=research_document,
local_data_document=local_data_search_results,
local_context=local_context,
)
return embeddable_object(data=result)
except Exception as e:
return embeddable_object(data=f"Error during hypothesis generation: {str(e)}")
Expand Down
8 changes: 4 additions & 4 deletions peak_assistant/streamlit/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,10 +351,10 @@
if st.session_state["last_hypothesis_for_data_discovery"] != current_hypothesis:
st.session_state["last_hypothesis_for_data_discovery"] = current_hypothesis
# Clear the data sources document to show run button
if "Data Sources_document" in st.session_state:
del st.session_state["Data Sources_document"]
if "Data Sources_messages" in st.session_state:
del st.session_state["Data Sources_messages"]
if "Discovery_document" in st.session_state:
del st.session_state["Discovery_document"]
if "Discovery_messages" in st.session_state:
del st.session_state["Discovery_messages"]

peak_assistant_chat(
title="Data Discovery",
Expand Down
18 changes: 16 additions & 2 deletions peak_assistant/utils/mcp_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,13 +957,27 @@ async def _connect_stdio_server(self, server_name: str, config: MCPServerConfig,
# Create stdio server parameters
# Fix: Ensure args is properly formatted for StdioServerParams
args_list = config.args or []
timeout_seconds = 30.0
try:
timeout_seconds = float(config.timeout)
except (TypeError, ValueError):
logger.warning(
"Invalid timeout value %r for stdio server %s; using default timeout %.1f seconds",
config.timeout,
server_name,
timeout_seconds,
)

server_params = StdioServerParams(
command=config.command,
args=args_list,
env=env,
read_timeout_seconds=float(config.timeout)
read_timeout_seconds=timeout_seconds
)
logger.debug(
f"Created StdioServerParams with command: {config.command}, args: {args_list}, "
f"read_timeout_seconds: {timeout_seconds}, env keys: {list(env.keys()) if env else 'None'}"
)
logger.debug(f"Created StdioServerParams with command: {config.command}, args: {args_list}, read_timeout_seconds: {float(config.timeout)}, env keys: {list(env.keys()) if env else 'None'}")

# Create workbench
# Fix: Pass server_params directly, not as a list
Expand Down
66 changes: 66 additions & 0 deletions tests/unit_tests/test_mcp_stdio_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) 2025 Cisco Systems, Inc. and its affiliates
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# SPDX-License-Identifier: MIT

"""Tests for robust timeout handling in MCP stdio connections."""

from unittest.mock import MagicMock, patch

import pytest

from peak_assistant.utils.mcp_config import MCPClientManager, MCPServerConfig, TransportType


@pytest.fixture
def client_manager():
"""Create an MCPClientManager with a lightweight mocked config manager."""
manager = MagicMock()
manager.user_session_manager = MagicMock()
return MCPClientManager(manager)


@pytest.mark.asyncio
@pytest.mark.parametrize("timeout_value", [None, "not-a-number"])
async def test_connect_stdio_server_uses_default_timeout_for_invalid_values(client_manager, timeout_value):
"""Invalid timeout config should not crash stdio server setup."""
config = MCPServerConfig(
name="test-server",
transport=TransportType.STDIO,
command="echo",
args=["ok"],
timeout=timeout_value,
)

captured = {}

class _FakeWorkbench:
async def __aenter__(self):
return self

def _capture_workbench(server_params):
captured["read_timeout_seconds"] = server_params.read_timeout_seconds
return _FakeWorkbench()

with patch("peak_assistant.utils.mcp_config.McpWorkbench", side_effect=_capture_workbench):
result = await client_manager._connect_stdio_server("test-server", config)

assert result is True
assert captured["read_timeout_seconds"] == 30.0
Loading