Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 57 additions & 23 deletions — src/sktime_mcp/runtime/executor.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -227,19 +227,22 @@ def fit_predict(
async def fit_predict_async(
self,
handle_id: str,
dataset: str,
dataset: Optional[str] = None,
data_handle: Optional[str] = None,
horizon: int = 12,
job_id: Optional[str] = None,
) -> dict[str, Any]:
"""
Async version of fit_predict with job tracking.

This method runs the training in the background without blocking the MCP server.
Progress is tracked via the JobManager.
Runs the training in the background without blocking the MCP server.
Accepts either a demo dataset name or a data handle from
load_data_source.

Args:
handle_id: Estimator handle
dataset: Dataset name
dataset: Demo dataset name
data_handle: Data handle from load_data_source
horizon: Forecast horizon
job_id: Optional job ID for tracking (created if not provided)

Expand All @@ -254,47 +257,79 @@ async def fit_predict_async(
logger.warning(f"Could not get estimator name: {e}")
estimator_name = "Unknown"

source_name = dataset if dataset else data_handle

# Create job if not provided
if job_id is None:
job_id = self._job_manager.create_job(
job_type="fit_predict",
estimator_handle=handle_id,
estimator_name=estimator_name,
dataset_name=dataset,
dataset_name=source_name,
horizon=horizon,
total_steps=3, # load data, fit, predict
total_steps=3,
)

try:
# Update status to RUNNING
self._job_manager.update_job(job_id, status=JobStatus.RUNNING)

# Step 1: Load dataset
self._job_manager.update_job(
job_id, completed_steps=0, current_step=f"Loading dataset '{dataset}'..."
)
await asyncio.sleep(0.01) # Yield control to event loop
# Step 1: Load data
if data_handle:
# Use custom data from a loaded handle
self._job_manager.update_job(
job_id,
completed_steps=0,
current_step=f"Loading data from handle '{data_handle}'...",
)
await asyncio.sleep(0.01)

data_result = self.load_dataset(dataset)
if not data_result["success"]:
if data_handle not in self._data_handles:
self._job_manager.update_job(
job_id,
status=JobStatus.FAILED,
errors=[f"Unknown data handle: {data_handle}"],
)
return {
"success": False,
"error": f"Unknown data handle: {data_handle}",
"available_handles": list(self._data_handles.keys()),
}

data_info = self._data_handles[data_handle]
y = data_info["y"]
X = data_info.get("X")
else:
# Use built-in demo dataset
self._job_manager.update_job(
job_id,
status=JobStatus.FAILED,
errors=[f"Failed to load dataset: {data_result.get('error')}"],
completed_steps=0,
current_step=f"Loading dataset '{dataset}'...",
)
return data_result
await asyncio.sleep(0.01)

data_result = self.load_dataset(dataset)
if not data_result["success"]:
self._job_manager.update_job(
job_id,
status=JobStatus.FAILED,
errors=[f"Failed to load dataset: {data_result.get('error')}"],
)
return data_result

y = data_result["data"]
X = data_result.get("exog")

y = data_result["data"]
X = data_result.get("exog")
fh = list(range(1, horizon + 1))

# Step 2: Fit model
self._job_manager.update_job(
job_id, completed_steps=1, current_step=f"Fitting {estimator_name} on {dataset}..."
job_id,
completed_steps=1,
current_step=f"Fitting {estimator_name} on {source_name}...",
)
await asyncio.sleep(0.01) # Yield control
await asyncio.sleep(0.01)

# Run fit in executor to avoid blocking
loop = asyncio.get_event_loop()
fit_result = await loop.run_in_executor(
None, lambda: self.fit(handle_id, y, X=X, fh=fh)
Expand All @@ -314,9 +349,8 @@ async def fit_predict_async(
completed_steps=2,
current_step=f"Generating predictions (horizon={horizon})...",
)
await asyncio.sleep(0.01) # Yield control
await asyncio.sleep(0.01)

# Run predict in executor
predict_result = await loop.run_in_executor(
None, lambda: self.predict(handle_id, fh=fh, X=X)
)
Expand Down
20 changes: 13 additions & 7 deletions — src/sktime_mcp/server.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -264,8 +264,9 @@ async def list_tools() -> list[Tool]:
Tool(
name="fit_predict_async",
description=(
"Fit an estimator on a dataset and generate predictions "
"(non-blocking background job). Returns a job_id."
"Fit an estimator and generate predictions in the background. "
"Provide exactly ONE of 'dataset' (built-in demo name) "
"or 'data_handle' (from load_data_source)."
),
inputSchema={
"type": "object",
Expand All @@ -276,15 +277,19 @@ async def list_tools() -> list[Tool]:
},
"dataset": {
"type": "string",
"description": "Dataset name: airline, sunspots, lynx, etc.",
"description": "Demo dataset name: airline, sunspots, lynx, etc.",
},
"data_handle": {
"type": "string",
"description": "Data handle from load_data_source (e.g. 'data_abc123')",
},
"horizon": {
"type": "integer",
"description": "Forecast horizon (default: 12)",
"default": 12,
},
},
"required": ["estimator_handle", "dataset"],
"required": ["estimator_handle"],
},
),
Tool(
Expand Down Expand Up @@ -618,9 +623,10 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:

elif name == "fit_predict_async":
result = fit_predict_async_tool(
arguments["estimator_handle"],
arguments["dataset"],
arguments.get("horizon", 12),
estimator_handle=arguments["estimator_handle"],
dataset=arguments.get("dataset"),
data_handle=arguments.get("data_handle"),
horizon=arguments.get("horizon", 12),
)

elif name == "evaluate_estimator":
Expand Down
55 changes: 40 additions & 15 deletions — src/sktime_mcp/tools/fit_predict.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,7 +4,6 @@
Executes complete forecasting workflows.
"""

import asyncio
import logging
from typing import Any, Optional

Expand Down Expand Up @@ -112,18 +111,23 @@ def list_datasets_tool() -> dict[str, Any]:

def fit_predict_async_tool(
estimator_handle: str,
dataset: str,
dataset: Optional[str] = None,
data_handle: Optional[str] = None,
horizon: int = 12,
) -> dict[str, Any]:
"""
Execute a fit-predict workflow in the background (non-blocking).

This tool schedules the training as a background job and returns immediately
Schedules the training as a background job and returns immediately
with a job_id. Use check_job_status to monitor progress.

Accepts either a demo dataset name or a data handle from
load_data_source -- exactly one must be provided.

Args:
estimator_handle: Handle from instantiate_estimator
dataset: Name of demo dataset (e.g., "airline", "sunspots")
data_handle: Handle from load_data_source (e.g., "data_abc123")
horizon: Forecast horizon (default: 12)

Returns:
Expand All @@ -133,13 +137,25 @@ def fit_predict_async_tool(
- message: Information about the job

Example:
>>> fit_predict_async_tool("est_abc123", "airline", horizon=12)
{
"success": True,
"job_id": "abc-123-def-456",
"message": "Training job started. Use check_job_status to monitor progress."
}
>>> fit_predict_async_tool("est_abc123", dataset="airline", horizon=12)
>>> fit_predict_async_tool("est_abc123", data_handle="data_xyz", horizon=5)
"""
if dataset and data_handle:
return {
"success": False,
"error": "Provide either 'dataset' or 'data_handle', not both.",
}

if not dataset and not data_handle:
return {
"success": False,
"error": (
"Either 'dataset' (e.g. 'airline') or "
"'data_handle' (from load_data_source) is required."
),
}

import asyncio

from sktime_mcp.runtime.jobs import get_job_manager

Expand All @@ -154,12 +170,14 @@ def fit_predict_async_tool(
logger.warning(f"Could not get estimator name: {e}")
estimator_name = "Unknown"

source_name = dataset if dataset else data_handle

# Create job
job_id = job_manager.create_job(
job_type="fit_predict",
estimator_handle=estimator_handle,
estimator_name=estimator_name,
dataset_name=dataset,
dataset_name=source_name,
horizon=horizon,
total_steps=3,
)
Expand All @@ -168,19 +186,26 @@ def fit_predict_async_tool(
try:
loop = asyncio.get_event_loop()
except RuntimeError:
# No event loop in current thread, create one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

# Schedule the coroutine (non-blocking!)
coro = executor.fit_predict_async(estimator_handle, dataset, horizon, job_id)
coro = executor.fit_predict_async(
estimator_handle,
dataset=dataset,
data_handle=data_handle,
horizon=horizon,
job_id=job_id,
)
asyncio.run_coroutine_threadsafe(coro, loop)

return {
"success": True,
"job_id": job_id,
"message": f"Training job started for {estimator_name} on {dataset}. Use check_job_status('{job_id}') to monitor progress.",
"message": (
f"Training job started for {estimator_name} on {source_name}. "
f"Use check_job_status('{job_id}') to monitor progress."
),
"estimator": estimator_name,
"dataset": dataset,
"data_source": source_name,
"horizon": horizon,
}
Loading
Loading