Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ TODO.md
dist/
docs/_build/
docs/build/
.ignore/
venv/
agentic_forecaster/data/train.csv
.ignore/
41 changes: 41 additions & 0 deletions agentic_forecaster/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Agentic Forecaster for sktime-mcp

This module implements an **Agentic Forecaster** that leverages the `sktime-mcp` Model Context Protocol (MCP) server to perform intelligent model selection and forecasting based on natural language prompts.

## Overview

Traditional forecasting workflows require manual model selection based on data characteristics. The `AgenticForecaster` automates this by:
1. **Semantic Reasoning**: Analyzing the user's prompt (e.g., "seasonal data", "fast execution") against the `sktime` registry's capability tags.
2. **Dynamic Tool Use**: Interacting with the MCP server to instantiate models, load data, and execute forecasts.
3. **Exogenous Support**: Automatically handling covariates (X) when provided, enabling professional-grade forecasting on complex datasets.

## Key Components

- **`agent.py`**: The core `AgenticForecaster` class. It manages the registry-driven reasoning and the interface with the MCP execution engine.
- **`main.py`**: A demonstration script showcasing the agent's ability to forecast retail sales data (Corporación Favorita dataset) using a simple English prompt.

## Usage Example

```python
from sktime.datasets import load_airline

from agentic_forecaster.agent import AgenticForecaster

y = load_airline()

# Initialize the agent with a natural language requirement
agent = AgenticForecaster(
    prompt="I need a model that handles seasonality and provides prediction intervals."
)

# Fit the agent-selected model, then forecast 30 steps ahead
agent.fit(y)
y_pred = agent.predict(fh=list(range(1, 31)))

print(f"Selected Model: {agent.selected_model_name_}")
print(f"Explanation: {agent.explain()}")
```

## Contributions to sktime-mcp

This agentic workflow drove several core improvements to the `sktime-mcp` project:
- **Exogenous Support**: Added `exog_handle` support to the `fit_predict` tool stack to enable covariates in agentic workflows.
- **Evaluation Logic**: Fixed cross-validation fold calculation bugs in the `evaluate` tool to ensure agents receive accurate performance metrics.
- **Registry Visibility**: Improved docstring handling to ensure the agent can read full model descriptions for better decision making.
75 changes: 75 additions & 0 deletions agentic_forecaster/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pandas as pd
import logging
from typing import Optional, Any
from sktime.forecasting.base import BaseForecaster
from sktime_mcp.registry.interface import get_registry

logger = logging.getLogger(__name__)

class AgenticForecaster(BaseForecaster):
    """Forecaster that picks an sktime estimator from a natural language prompt.

    During ``fit`` the prompt is scanned for keywords, the keywords are turned
    into registry tag filters, and one estimator from the matching candidates
    is instantiated with its default parameters and fitted. Prediction calls
    are delegated to that fitted estimator.

    Parameters
    ----------
    prompt : str
        Natural language description of the forecasting requirement.
    llm_client : Any, optional
        Stored on the instance; not used by the keyword heuristics in this
        class.

    Attributes
    ----------
    estimator_ : object or None
        The estimator selected and fitted by ``fit``; ``None`` before fitting.
    selected_model_name_ : str or None
        Registry name of the selected estimator.
    explanation_ : str or None
        Human-readable reason for the selection, see ``explain``.
    """

    _tags = {
        "scitype:y": "both",
        "capability:pred_int": True,
        "requires-fh-in-fit": False,
        # The tag name uses hyphens throughout; a misspelled key is a
        # different tag and leaves the intended one at its default.
        "X-y-must-have-same-index": True,
    }

    def __init__(self, prompt: str, llm_client: Any = None):
        self.prompt = prompt
        self.llm_client = llm_client
        # Initialised here so ``explain`` works before ``fit`` is called.
        self.estimator_ = None
        self.selected_model_name_ = None
        self.explanation_ = None
        super().__init__()

    def _fit(self, y: pd.Series | pd.DataFrame, X: pd.DataFrame | None = None, fh: Any | None = None):
        """Select an estimator from the registry based on the prompt and fit it.

        Parameters
        ----------
        y : pd.Series or pd.DataFrame
            Target series passed through to the selected estimator.
        X : pd.DataFrame, optional
            Exogenous data passed through to the selected estimator.
        fh : optional
            Forecasting horizon passed through to the selected estimator.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If the registry returns no forecasting estimators at all.
        """
        registry = get_registry()
        lower_prompt = self.prompt.lower()
        tags_to_query = {}

        # Heuristic reasoning based on prompt keywords
        if "interval" in lower_prompt or "probabilistic" in lower_prompt:
            tags_to_query["capability:pred_int"] = True

        wants_multivariate = "multivariate" in lower_prompt
        tags_to_query["scitype:y"] = "multivariate" if wants_multivariate else "univariate"

        estimators = registry.get_all_estimators(task="forecasting", tags=tags_to_query)
        if not estimators:
            # Fallback to all forecasters if tags are too restrictive
            logger.warning(
                "No forecasters matched tags %s; falling back to all forecasters.",
                tags_to_query,
            )
            estimators = registry.get_all_estimators(task="forecasting")
        if not estimators:
            # Without this guard the default below would raise a bare IndexError.
            raise ValueError(
                "No forecasting estimators are available in the sktime-mcp registry."
            )

        # Prefer any ARIMA-family model when the prompt names it, otherwise
        # prefer AutoARIMA; in both cases fall back to the first candidate.
        name_hint = "ARIMA" if "arima" in lower_prompt else "AutoARIMA"
        selected_node = next((e for e in estimators if name_hint in e.name), estimators[0])

        self.selected_model_name_ = selected_node.name
        self.estimator_ = selected_node.class_ref()
        self.explanation_ = f"Selected {self.selected_model_name_} based on requirement: {self.prompt}"

        self.estimator_.fit(y, X=X, fh=fh)
        return self

    def _predict(self, fh: Any | None = None, X: pd.DataFrame | None = None):
        """Return point forecasts from the agent-selected estimator."""
        return self.estimator_.predict(fh=fh, X=X)

    def _predict_interval(self, fh: Any | None = None, X: pd.DataFrame | None = None, coverage: Any = 0.9):
        """Return prediction intervals from the agent-selected estimator.

        Backs the ``capability:pred_int`` tag declared on this class. Whether
        intervals are actually available depends on the estimator chosen in
        ``fit``; any error it raises is propagated unchanged.
        """
        return self.estimator_.predict_interval(fh=fh, X=X, coverage=coverage)

    def explain(self) -> str:
        """Return a natural language explanation of why the model was selected."""
        return self.explanation_ or "No model selected yet."
26 changes: 26 additions & 0 deletions agentic_forecaster/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pandas as pd
from sktime.forecasting.base import ForecastingHorizon
import os
from agent import AgenticForecaster

# Resolve the training CSV relative to this file so the demo does not depend
# on the current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))
train_path = os.path.join(script_dir, "data", "train.csv")

# Load the data and keep only store 1 / "GROCERY I" rows.
df = pd.read_csv(train_path)
df["date"] = pd.to_datetime(df["date"])
is_store_one = df["store_nbr"] == 1
is_grocery = df["family"] == "GROCERY I"
df = df[is_store_one & is_grocery].copy()

# Collapse to one sales total per date and keep the most recent 90 dates.
daily_sales = df.groupby("date")["sales"].sum()
df = daily_sales.sort_index().tail(90)
y = pd.Series(df)

# Absolute forecasting horizon: the 30 calendar days after the last observation.
first_forecast_day = y.index[-1] + pd.Timedelta(days=1)
forecast_dates = pd.date_range(start=first_forecast_day, periods=30, freq="D")
fh = ForecastingHorizon(forecast_dates, is_relative=False)

# Let the agent choose a model from the prompt, fit it, and forecast.
prompt = "Forecast grocery sales for the next 30 days. I need prediction intervals for uncertainty."
model = AgenticForecaster(prompt=prompt)
print(f"Agentic Prompt: {prompt}")
model.fit(y)
print(f"Agent Logic: {model.explain()}")
predictions = model.predict(fh)
print("\nNext 30 Days Forecast:\n")
print(predictions)
35 changes: 32 additions & 3 deletions src/sktime_mcp/runtime/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,10 @@ def predict(
def fit_predict(
self,
handle_id: str,
dataset: str,
dataset: str | None = None,
horizon: int = 12,
data_handle: str | None = None,
exog_handle: str | None = None,
) -> dict[str, Any]:
"""Convenience method: load data, fit, and predict."""
if dataset and data_handle:
Expand All @@ -206,6 +207,10 @@ def fit_predict(
"'data_handle' (from load_data_source) is required."
),
}

y = None
X = None

if data_handle is not None:
# Use custom loaded data
if data_handle not in self._data_handles:
Expand All @@ -225,6 +230,16 @@ def fit_predict(
y = data_result["data"]
X = data_result.get("exog")

# Override X if exog_handle is provided separately
if exog_handle is not None:
if exog_handle not in self._data_handles:
return {
"success": False,
"error": f"Unknown exog handle: {exog_handle}",
}
exog_info = self._data_handles[exog_handle]
X = exog_info["y"] # Treat the target of the exog handle as X

fh = list(range(1, horizon + 1))

fit_result = self.fit(handle_id, y, X=X, fh=fh)
Expand All @@ -238,6 +253,7 @@ async def fit_predict_async(
handle_id: str,
dataset: str | None = None,
data_handle: str | None = None,
exog_handle: str | None = None,
horizon: int = 12,
job_id: str | None = None,
) -> dict[str, Any]:
Expand All @@ -252,6 +268,7 @@ async def fit_predict_async(
handle_id: Estimator handle
dataset: Demo dataset name
data_handle: Data handle from load_data_source
exog_handle: Optional exogenous data handle
horizon: Forecast horizon
job_id: Optional job ID for tracking (created if not provided)

Expand Down Expand Up @@ -327,9 +344,21 @@ async def fit_predict_async(
return data_result

y = data_result["data"]
X = data_result.get("exog")
X = data_result.get("exog")

fh = list(range(1, horizon + 1))
# Override X if exog_handle is provided separately
if exog_handle is not None:
if exog_handle not in self._data_handles:
self._job_manager.update_job(
job_id,
status=JobStatus.FAILED,
errors=[f"Unknown exog handle: {exog_handle}"],
)
return {"success": False, "error": f"Unknown exog handle: {exog_handle}"}
exog_info = self._data_handles[exog_handle]
X = exog_info["y"]

fh = list(range(1, horizon + 1))

# Step 2: Fit model
self._job_manager.update_job(
Expand Down
12 changes: 11 additions & 1 deletion src/sktime_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,10 @@ async def list_tools() -> list[Tool]:
"description": "Forecast horizon (default: 12)",
"default": 12,
},
"exog_handle": {
"type": "string",
"description": "Optional handle for exogenous variables (X) from load_data_source",
},
},
"required": ["estimator_handle"],
},
Expand Down Expand Up @@ -348,6 +352,10 @@ async def list_tools() -> list[Tool]:
"description": "Forecast horizon (default: 12)",
"default": 12,
},
"exog_handle": {
"type": "string",
"description": "Optional handle for exogenous variables (X) from load_data_source",
},
},
"required": ["estimator_handle"],
},
Expand Down Expand Up @@ -675,16 +683,18 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
elif name == "fit_predict":
result = fit_predict_tool(
arguments["estimator_handle"],
arguments.get("dataset", ""),
arguments.get("dataset"),
arguments.get("horizon", 12),
data_handle=arguments.get("data_handle"),
exog_handle=arguments.get("exog_handle"),
)

elif name == "fit_predict_async":
result = fit_predict_async_tool(
estimator_handle=arguments["estimator_handle"],
dataset=arguments.get("dataset"),
data_handle=arguments.get("data_handle"),
exog_handle=arguments.get("exog_handle"),
horizon=arguments.get("horizon", 12),
)

Expand Down
21 changes: 15 additions & 6 deletions src/sktime_mcp/tools/fit_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
Executes complete forecasting workflows.
"""

import asyncio
import logging
from typing import Any

from sktime_mcp.runtime.executor import get_executor
from sktime_mcp.runtime.jobs import get_job_manager

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -41,9 +43,10 @@ def _validate_horizon(horizon: int) -> dict[str, Any]:

def fit_predict_tool(
estimator_handle: str,
dataset: str,
dataset: str | None = None,
horizon: int = 12,
data_handle: str | None = None,
exog_handle: str | None = None,
) -> dict[str, Any]:
"""
Execute a complete fit-predict workflow.
Expand All @@ -53,6 +56,7 @@ def fit_predict_tool(
dataset: Name of demo dataset (e.g., "airline", "sunspots")
horizon: Forecast horizon (default: 12)
data_handle: Optional handle from load_data_source for custom data
exog_handle: Optional handle for exogenous variables (X)

Returns:
Dictionary with:
Expand Down Expand Up @@ -89,7 +93,13 @@ def fit_predict_tool(
),
}
executor = get_executor()
return executor.fit_predict(estimator_handle, dataset, horizon, data_handle=data_handle)
return executor.fit_predict(
estimator_handle,
dataset=dataset,
horizon=horizon,
data_handle=data_handle,
exog_handle=exog_handle,
)


def predict_tool(
Expand Down Expand Up @@ -135,6 +145,7 @@ def fit_predict_async_tool(
estimator_handle: str,
dataset: str | None = None,
data_handle: str | None = None,
exog_handle: str | None = None,
horizon: int = 12,
) -> dict[str, Any]:
"""
Expand All @@ -150,6 +161,7 @@ def fit_predict_async_tool(
estimator_handle: Handle from instantiate_estimator
dataset: Name of demo dataset (e.g., "airline", "sunspots")
data_handle: Handle from load_data_source (e.g., "data_abc123")
exog_handle: Optional handle for exogenous variables (X)
horizon: Forecast horizon (default: 12)

Returns:
Expand Down Expand Up @@ -183,10 +195,6 @@ def fit_predict_async_tool(
),
}

import asyncio

from sktime_mcp.runtime.jobs import get_job_manager

executor = get_executor()
job_manager = get_job_manager()

Expand Down Expand Up @@ -221,6 +229,7 @@ def fit_predict_async_tool(
estimator_handle,
dataset=dataset,
data_handle=data_handle,
exog_handle=exog_handle,
horizon=horizon,
job_id=job_id,
)
Expand Down