Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/sktime_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
release_data_handle_tool,
)
from sktime_mcp.tools.describe_estimator import describe_estimator_tool
from sktime_mcp.tools.diagnose import diagnose_residuals_tool
from sktime_mcp.tools.evaluate import evaluate_estimator_tool
from sktime_mcp.tools.fit_predict import (
fit_predict_async_tool,
Expand Down Expand Up @@ -375,6 +376,29 @@ async def list_tools() -> list[Tool]:
"required": ["estimator_handle", "dataset"],
},
),
Tool(
name="diagnose_residuals",
description="Diagnose residuals of a fitted estimator using statistical tests",
inputSchema={
"type": "object",
"properties": {
"estimator_handle": {
"type": "string",
"description": "Handle from instantiate_estimator",
},
"dataset": {
"type": "string",
"description": "Dataset name: airline, sunspots, lynx, etc.",
},
"significance_level": {
"type": "number",
"description": "Alpha for statistical tests (default 0.05)",
"default": 0.05,
},
},
"required": ["estimator_handle", "dataset"],
},
),
# -- Data ------------------------------------------------------------
Tool(
name="list_available_data",
Expand Down Expand Up @@ -695,6 +719,13 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
arguments.get("cv_folds", 3),
)

elif name == "diagnose_residuals":
result = diagnose_residuals_tool(
arguments["estimator_handle"],
arguments["dataset"],
arguments.get("significance_level", 0.05),
)

elif name == "validate_pipeline":
validator = get_composition_validator()
validation = validator.validate_pipeline(arguments["components"])
Expand Down
2 changes: 2 additions & 0 deletions src/sktime_mcp/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
release_data_handle_tool,
)
from sktime_mcp.tools.describe_estimator import describe_estimator_tool
from sktime_mcp.tools.diagnose import diagnose_residuals_tool
from sktime_mcp.tools.evaluate import evaluate_estimator_tool
from sktime_mcp.tools.fit_predict import (
fit_predict_async_tool,
Expand Down Expand Up @@ -44,6 +45,7 @@
"fit_predict_tool",
"fit_predict_async_tool",
"evaluate_estimator_tool",
"diagnose_residuals_tool",
"load_data_source_tool",
"load_data_source_async_tool",
"release_data_handle_tool",
Expand Down
169 changes: 169 additions & 0 deletions src/sktime_mcp/tools/diagnose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
"""
diagnose tool for sktime MCP.

Executes post-fit residual diagnostics on a fitted estimator and
translates statistical test results into actionable LLM hints.
"""

import logging
from typing import Any

import numpy as np
from scipy.stats import shapiro
from statsmodels.stats.diagnostic import acorr_ljungbox

from sktime_mcp.runtime.executor import get_executor

logger = logging.getLogger(__name__)


def diagnose_residuals_tool(
estimator_handle: str,
dataset: str,
significance_level: float = 0.05,
) -> dict[str, Any]:
"""
Diagnose residuals of a fitted estimator.

Computes in-sample residuals and runs statistical tests (Ljung-Box,
Shapiro-Wilk, bias) to give an LLM actionable reasoning about why
the model failed and what to try next.

Args:
estimator_handle: Handle ID from instantiate_estimator
dataset: Name of demo dataset (e.g. 'airline', 'sunspots')
significance_level: Alpha threshold for statistical tests (default 0.05)

Returns:
Dictionary with success status, diagnostics dict, and llm_hint string
"""
executor = get_executor()

# --- 1. Retrieve the fitted instance ---
try:
instance = executor._handle_manager.get_instance(estimator_handle)
except KeyError:
return {"success": False, "error": f"Handle not found: {estimator_handle}"}

if not executor._handle_manager.is_fitted(estimator_handle):
return {
"success": False,
"error": f"Estimator {estimator_handle} is not fitted. Call fit_predict first.",
}

# --- 2. Reload the training data (same pattern as evaluate.py) ---
data_result = executor.load_dataset(dataset)
if not data_result["success"]:
return data_result

y = data_result["data"]

# --- 3. Compute in-sample residuals (primary + fallback) ---
try:
residuals = instance.predict_residuals(y=y)
except Exception:
try:
preds = instance.predict(fh=y.index)
residuals = y - preds
except Exception as e:
return {"success": False, "error": f"Failed to compute residuals: {e}"}

# Sanitize: some models (ARIMA, differencing) produce leading NaNs
residuals = residuals.dropna()
res_array = np.asarray(residuals, dtype=float)

if len(res_array) < 3:
return {
"success": False,
"error": "Not enough residual data points to run statistical tests.",
}

# --- 4. Statistical tests ---
diagnostics: dict[str, Any] = {}
llm_hints: list[str] = []

# Bias (mean error)
mean_error = float(np.mean(res_array))
std_error = float(np.std(res_array))

if mean_error > 0.1 * std_error:
bias_status = "consistently under-forecasting"
llm_hints.append(
f"Model has a persistent positive bias ({mean_error:.4f}), "
"consistently under-forecasting the actual values. "
"Consider adding a drift term or using an unbiased baseline."
)
elif mean_error < -0.1 * std_error:
bias_status = "consistently over-forecasting"
llm_hints.append(
f"Model has a persistent negative bias ({mean_error:.4f}), "
"consistently over-forecasting the actual values. "
"Consider adding a drift term or using an unbiased baseline."
)
else:
bias_status = "unbiased"

diagnostics["bias"] = {
"mean_error": mean_error,
"status": bias_status,
}

# Autocorrelation (Ljung-Box)
lb_lags = max(2, min(10, len(res_array) // 5))
try:
lb_res = acorr_ljungbox(res_array, lags=[lb_lags], return_df=True)
lb_p_value = float(lb_res["lb_pvalue"].iloc[0])
lb_passed = lb_p_value >= significance_level

diagnostics["autocorrelation"] = {
"ljung_box_passed": lb_passed,
"p_value": lb_p_value,
"tested_lags": lb_lags,
}

if not lb_passed:
llm_hints.append(
f"Residuals show significant autocorrelation at lag {lb_lags} "
f"(p={lb_p_value:.4f}). The model failed to capture time dependence "
"or seasonality. Consider switching to SARIMA or adding a Deseasonalizer pipeline."
)
except Exception as e:
logger.warning(f"Ljung-Box test failed: {e}")
diagnostics["autocorrelation"] = {"error": str(e)}

# Normality (Shapiro-Wilk), capped at 500 samples
shapiro_data = res_array[-500:] if len(res_array) > 500 else res_array
try:
_, shapiro_p_value = shapiro(shapiro_data)
shapiro_p_value = float(shapiro_p_value)
shapiro_passed = shapiro_p_value >= significance_level

diagnostics["normality"] = {
"shapiro_passed": shapiro_passed,
"p_value": shapiro_p_value,
}

if not shapiro_passed:
llm_hints.append(
f"Residuals are not normally distributed (Shapiro p={shapiro_p_value:.4f}), "
"suggesting heavy-tailed noise or non-linear effects not captured by the current model."
)
except Exception as e:
logger.warning(f"Shapiro-Wilk test failed: {e}")
diagnostics["normality"] = {"error": str(e)}

# --- 5. Build final LLM hint ---
if llm_hints:
final_hint = " ".join(llm_hints)
else:
final_hint = (
"Residuals look like white noise: no significant autocorrelation, "
"no strong bias, and approximately normal distribution. "
"The model has captured the underlying signal well."
)

return {
"success": True,
"diagnostics": diagnostics,
"llm_hint": final_hint,
}
58 changes: 58 additions & 0 deletions tests/test_diagnose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Tests for the diagnose_residuals_tool."""

from sktime_mcp.tools.diagnose import diagnose_residuals_tool
from sktime_mcp.tools.fit_predict import fit_predict_tool
from sktime_mcp.tools.instantiate import instantiate_estimator_tool


def test_diagnose_residuals_tool_success():
"""Test successful residual diagnosis of a fitted NaiveForecaster."""
inst_result = instantiate_estimator_tool("sktime.forecasting.naive.NaiveForecaster")
assert inst_result["success"] is True
handle_id = inst_result["handle"]

fit_result = fit_predict_tool(
estimator_handle=handle_id,
dataset="airline",
horizon=1,
)
assert fit_result["success"] is True

diag_result = diagnose_residuals_tool(
estimator_handle=handle_id,
dataset="airline",
)

assert diag_result["success"] is True
assert "diagnostics" in diag_result
assert "bias" in diag_result["diagnostics"]
assert "autocorrelation" in diag_result["diagnostics"]
assert "llm_hint" in diag_result
assert isinstance(diag_result["llm_hint"], str)
assert len(diag_result["llm_hint"]) > 0


def test_diagnose_not_fitted():
"""Test that diagnosis returns a clear error on an unfitted estimator."""
inst_result = instantiate_estimator_tool("sktime.forecasting.naive.NaiveForecaster")
assert inst_result["success"] is True
handle_id = inst_result["handle"]

diag_result = diagnose_residuals_tool(
estimator_handle=handle_id,
dataset="airline",
)

assert diag_result["success"] is False
assert "not fitted" in diag_result["error"].lower()


def test_diagnose_invalid_handle():
"""Test that diagnosis returns a clear error for a non-existent handle."""
diag_result = diagnose_residuals_tool(
estimator_handle="invalid_handle_xyz",
dataset="airline",
)

assert diag_result["success"] is False
assert "not found" in diag_result["error"].lower()