Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion src/sktime_mcp/runtime/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,10 @@ def predict(
def fit_predict(
self,
handle_id: str,
dataset: str,
dataset: str | None = None,
horizon: int = 12,
data_handle: str | None = None,
exog_handle: str | None = None,
) -> dict[str, Any]:
"""Convenience method: load data, fit, and predict."""
if dataset and data_handle:
Expand All @@ -206,6 +207,10 @@ def fit_predict(
"'data_handle' (from load_data_source) is required."
),
}

y = None
X = None

if data_handle is not None:
# Use custom loaded data
if data_handle not in self._data_handles:
Expand All @@ -224,6 +229,19 @@ def fit_predict(
return data_result
y = data_result["data"]
X = data_result.get("exog")

if exog_handle is not None:
if exog_handle not in self._data_handles:
return {
"success": False,
"error": f"Unknown exog_handle: {exog_handle}",
"available_handles": list(self._data_handles.keys()),
}
exog_info = self._data_handles[exog_handle]
# Exogenous covariates might be loaded as 'y' if no target_column was specified
X = exog_info.get("X")
if X is None or (hasattr(X, "empty") and X.empty):
X = exog_info.get("y")

fh = list(range(1, horizon + 1))

Expand All @@ -238,6 +256,7 @@ async def fit_predict_async(
handle_id: str,
dataset: str | None = None,
data_handle: str | None = None,
exog_handle: str | None = None,
horizon: int = 12,
job_id: str | None = None,
) -> dict[str, Any]:
Expand All @@ -252,6 +271,7 @@ async def fit_predict_async(
handle_id: Estimator handle
dataset: Demo dataset name
data_handle: Data handle from load_data_source
exog_handle: Handle for exogenous variables (covariates)
horizon: Forecast horizon
job_id: Optional job ID for tracking (created if not provided)

Expand Down Expand Up @@ -329,6 +349,20 @@ async def fit_predict_async(
y = data_result["data"]
X = data_result.get("exog")

# Resolve exog_handle if provided - overrides any X from dataset
if exog_handle:
if exog_handle not in self._data_handles:
self._job_manager.update_job(
job_id,
status=JobStatus.FAILED,
errors=[f"Unknown exog_handle: {exog_handle}"],
)
return {"success": False, "job_id": job_id}
exog_info = self._data_handles[exog_handle]
X = exog_info.get("X")
if X is None or (hasattr(X, "empty") and X.empty):
X = exog_info.get("y")

fh = list(range(1, horizon + 1))

# Step 2: Fit model
Expand All @@ -339,6 +373,7 @@ async def fit_predict_async(
)
await asyncio.sleep(0.01)


loop = asyncio.get_event_loop()
fit_result = await loop.run_in_executor(
None, lambda: self.fit(handle_id, y, X=X, fh=fh)
Expand Down
17 changes: 17 additions & 0 deletions src/sktime_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,14 @@ async def list_tools() -> list[Tool]:
"of dataset for custom data)"
),
},
"exog_handle": {
"type": "string",
"description": (
"Optional handle for exogenous variables (covariates) "
"loaded via load_data_source. Passed as X to the estimator "
"alongside the target y. Required for multivariate forecasting."
),
},
"horizon": {
"type": "integer",
"description": "Forecast horizon (default: 12)",
Expand Down Expand Up @@ -343,6 +351,13 @@ async def list_tools() -> list[Tool]:
"type": "string",
"description": "Data handle from load_data_source (e.g. 'data_abc123')",
},
"exog_handle": {
"type": "string",
"description": (
"Optional handle for exogenous variables (covariates) "
"loaded via load_data_source. Passed as X to the estimator."
),
},
"horizon": {
"type": "integer",
"description": "Forecast horizon (default: 12)",
Expand Down Expand Up @@ -678,13 +693,15 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
arguments.get("dataset", ""),
arguments.get("horizon", 12),
data_handle=arguments.get("data_handle"),
exog_handle=arguments.get("exog_handle"),
)

elif name == "fit_predict_async":
result = fit_predict_async_tool(
estimator_handle=arguments["estimator_handle"],
dataset=arguments.get("dataset"),
data_handle=arguments.get("data_handle"),
exog_handle=arguments.get("exog_handle"),
horizon=arguments.get("horizon", 12),
)

Expand Down
17 changes: 15 additions & 2 deletions src/sktime_mcp/tools/fit_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ def _validate_horizon(horizon: int) -> dict[str, Any]:

def fit_predict_tool(
estimator_handle: str,
dataset: str,
dataset: str | None = None,
horizon: int = 12,
data_handle: str | None = None,
exog_handle: str | None = None,
) -> dict[str, Any]:
"""
Execute a complete fit-predict workflow.
Expand All @@ -53,6 +54,8 @@ def fit_predict_tool(
dataset: Name of demo dataset (e.g., "airline", "sunspots")
horizon: Forecast horizon (default: 12)
data_handle: Optional handle from load_data_source for custom data
exog_handle: Optional handle for exogenous variables (covariates),
also from load_data_source. Passed as X to the estimator.

Returns:
Dictionary with:
Expand Down Expand Up @@ -89,7 +92,13 @@ def fit_predict_tool(
),
}
executor = get_executor()
return executor.fit_predict(estimator_handle, dataset, horizon, data_handle=data_handle)
return executor.fit_predict(
estimator_handle,
dataset,
horizon,
data_handle=data_handle,
exog_handle=exog_handle,
)


def predict_tool(
Expand Down Expand Up @@ -135,6 +144,7 @@ def fit_predict_async_tool(
estimator_handle: str,
dataset: str | None = None,
data_handle: str | None = None,
exog_handle: str | None = None,
horizon: int = 12,
) -> dict[str, Any]:
"""
Expand All @@ -150,6 +160,8 @@ def fit_predict_async_tool(
estimator_handle: Handle from instantiate_estimator
dataset: Name of demo dataset (e.g., "airline", "sunspots")
data_handle: Handle from load_data_source (e.g., "data_abc123")
exog_handle: Optional handle for exogenous covariates from load_data_source.
Passed as X to the estimator alongside the target y.
horizon: Forecast horizon (default: 12)

Returns:
Expand Down Expand Up @@ -221,6 +233,7 @@ def fit_predict_async_tool(
estimator_handle,
dataset=dataset,
data_handle=data_handle,
exog_handle=exog_handle,
horizon=horizon,
job_id=job_id,
)
Expand Down