From 408308cd5d5fdaaf493c8a0f1130ca1791cd8b1c Mon Sep 17 00:00:00 2001 From: Abhishek Date: Sat, 4 Apr 2026 13:04:54 +0530 Subject: [PATCH 1/2] [ENH] Enable background jobs for custom data handles (#7) --- src/sktime_mcp/runtime/executor.py | 80 ++++++++++++------ src/sktime_mcp/server.py | 20 +++-- src/sktime_mcp/tools/fit_predict.py | 55 ++++++++---- tests/test_async_custom_data.py | 124 ++++++++++++++++++++++++++++ 4 files changed, 234 insertions(+), 45 deletions(-) create mode 100644 tests/test_async_custom_data.py diff --git a/src/sktime_mcp/runtime/executor.py b/src/sktime_mcp/runtime/executor.py index 35a2ab21..bafa2d83 100644 --- a/src/sktime_mcp/runtime/executor.py +++ b/src/sktime_mcp/runtime/executor.py @@ -222,19 +222,22 @@ def fit_predict( async def fit_predict_async( self, handle_id: str, - dataset: str, + dataset: Optional[str] = None, + data_handle: Optional[str] = None, horizon: int = 12, job_id: Optional[str] = None, ) -> dict[str, Any]: """ Async version of fit_predict with job tracking. - This method runs the training in the background without blocking the MCP server. - Progress is tracked via the JobManager. + Runs the training in the background without blocking the MCP server. + Accepts either a demo dataset name or a data handle from + load_data_source. Args: handle_id: Estimator handle - dataset: Dataset name + dataset: Demo dataset name + data_handle: Data handle from load_data_source horizon: Forecast horizon job_id: Optional job ID for tracking (created if not provided) @@ -249,47 +252,79 @@ async def fit_predict_async( logger.warning(f"Could not get estimator name: {e}") estimator_name = "Unknown" + source_name = dataset if dataset else data_handle + # Create job if not provided if job_id is None: job_id = self._job_manager.create_job( job_type="fit_predict", estimator_handle=handle_id, estimator_name=estimator_name, - dataset_name=dataset, + dataset_name=source_name, horizon=horizon, - total_steps=3, # load data, fit, predict + total_steps=3, ) try: # Update status to RUNNING self._job_manager.update_job(job_id, status=JobStatus.RUNNING) - # Step 1: Load dataset - self._job_manager.update_job( - job_id, completed_steps=0, current_step=f"Loading dataset '{dataset}'..." - ) - await asyncio.sleep(0.01) # Yield control to event loop + # Step 1: Load data + if data_handle: + # Use custom data from a loaded handle + self._job_manager.update_job( + job_id, + completed_steps=0, + current_step=f"Loading data from handle '{data_handle}'...", + ) + await asyncio.sleep(0.01) - data_result = self.load_dataset(dataset) - if not data_result["success"]: + if data_handle not in self._data_handles: + self._job_manager.update_job( + job_id, + status=JobStatus.FAILED, + errors=[f"Unknown data handle: {data_handle}"], + ) + return { + "success": False, + "error": f"Unknown data handle: {data_handle}", + "available_handles": list(self._data_handles.keys()), + } + + data_info = self._data_handles[data_handle] + y = data_info["y"] + X = data_info.get("X") + else: + # Use built-in demo dataset self._job_manager.update_job( job_id, - status=JobStatus.FAILED, - errors=[f"Failed to load dataset: {data_result.get('error')}"], + completed_steps=0, + current_step=f"Loading dataset '{dataset}'...", ) - return data_result + await asyncio.sleep(0.01) + + data_result = self.load_dataset(dataset) + if not data_result["success"]: + self._job_manager.update_job( + job_id, + status=JobStatus.FAILED, + errors=[f"Failed to load dataset: {data_result.get('error')}"], + ) + return data_result + + y = data_result["data"] + X = data_result.get("exog") - y = data_result["data"] - X = data_result.get("exog") fh = list(range(1, horizon + 1)) # Step 2: Fit model self._job_manager.update_job( - job_id, completed_steps=1, current_step=f"Fitting {estimator_name} on {dataset}..." + job_id, + completed_steps=1, + current_step=f"Fitting {estimator_name} on {source_name}...", ) - await asyncio.sleep(0.01) # Yield control + await asyncio.sleep(0.01) - # Run fit in executor to avoid blocking loop = asyncio.get_event_loop() fit_result = await loop.run_in_executor( None, lambda: self.fit(handle_id, y, X=X, fh=fh) @@ -309,9 +344,8 @@ async def fit_predict_async( completed_steps=2, current_step=f"Generating predictions (horizon={horizon})...", ) - await asyncio.sleep(0.01) # Yield control + await asyncio.sleep(0.01) - # Run predict in executor predict_result = await loop.run_in_executor( None, lambda: self.predict(handle_id, fh=fh, X=X) ) diff --git a/src/sktime_mcp/server.py b/src/sktime_mcp/server.py index e377ff2f..34f7b54c 100644 --- a/src/sktime_mcp/server.py +++ b/src/sktime_mcp/server.py @@ -276,8 +276,9 @@ async def list_tools() -> list[Tool]: Tool( name="fit_predict_async", description=( - "Fit an estimator on a dataset and generate predictions " - "(non-blocking background job). Returns a job_id." + "Fit an estimator and generate predictions in the background. " + "Provide exactly ONE of 'dataset' (built-in demo name) " + "or 'data_handle' (from load_data_source)." ), inputSchema={ "type": "object", @@ -288,7 +289,11 @@ async def list_tools() -> list[Tool]: }, "dataset": { "type": "string", - "description": "Dataset name: airline, sunspots, lynx, etc.", + "description": "Demo dataset name: airline, sunspots, lynx, etc.", + }, + "data_handle": { + "type": "string", + "description": "Data handle from load_data_source (e.g. 'data_abc123')", }, "horizon": { "type": "integer", @@ -296,7 +301,7 @@ async def list_tools() -> list[Tool]: "default": 12, }, }, - "required": ["estimator_handle", "dataset"], + "required": ["estimator_handle"], }, ), Tool( @@ -635,9 +640,10 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: elif name == "fit_predict_async": result = fit_predict_async_tool( - arguments["estimator_handle"], - arguments["dataset"], - arguments.get("horizon", 12), + estimator_handle=arguments["estimator_handle"], + dataset=arguments.get("dataset"), + data_handle=arguments.get("data_handle"), + horizon=arguments.get("horizon", 12), ) elif name == "fit_predict_with_data": diff --git a/src/sktime_mcp/tools/fit_predict.py b/src/sktime_mcp/tools/fit_predict.py index de9c1b5f..6663d409 100644 --- a/src/sktime_mcp/tools/fit_predict.py +++ b/src/sktime_mcp/tools/fit_predict.py @@ -4,7 +4,6 @@ Executes complete forecasting workflows. """ -import asyncio import logging from typing import Any, Optional @@ -107,18 +106,23 @@ def list_datasets_tool() -> dict[str, Any]: def fit_predict_async_tool( estimator_handle: str, - dataset: str, + dataset: Optional[str] = None, + data_handle: Optional[str] = None, horizon: int = 12, ) -> dict[str, Any]: """ Execute a fit-predict workflow in the background (non-blocking). - This tool schedules the training as a background job and returns immediately + Schedules the training as a background job and returns immediately with a job_id. Use check_job_status to monitor progress. + Accepts either a demo dataset name or a data handle from + load_data_source -- exactly one must be provided. + Args: estimator_handle: Handle from instantiate_estimator dataset: Name of demo dataset (e.g., "airline", "sunspots") + data_handle: Handle from load_data_source (e.g., "data_abc123") horizon: Forecast horizon (default: 12) Returns: @@ -128,13 +132,25 @@ def fit_predict_async_tool( - message: Information about the job Example: - >>> fit_predict_async_tool("est_abc123", "airline", horizon=12) - { - "success": True, - "job_id": "abc-123-def-456", - "message": "Training job started. Use check_job_status to monitor progress." - } + >>> fit_predict_async_tool("est_abc123", dataset="airline", horizon=12) + >>> fit_predict_async_tool("est_abc123", data_handle="data_xyz", horizon=5) """ + if dataset and data_handle: + return { + "success": False, + "error": "Provide either 'dataset' or 'data_handle', not both.", + } + + if not dataset and not data_handle: + return { + "success": False, + "error": ( + "Either 'dataset' (e.g. 'airline') or " + "'data_handle' (from load_data_source) is required." + ), + } + + import asyncio from sktime_mcp.runtime.jobs import get_job_manager @@ -149,12 +165,14 @@ def fit_predict_async_tool( logger.warning(f"Could not get estimator name: {e}") estimator_name = "Unknown" + source_name = dataset if dataset else data_handle + # Create job job_id = job_manager.create_job( job_type="fit_predict", estimator_handle=estimator_handle, estimator_name=estimator_name, - dataset_name=dataset, + dataset_name=source_name, horizon=horizon, total_steps=3, ) @@ -163,19 +181,26 @@ def fit_predict_async_tool( try: loop = asyncio.get_event_loop() except RuntimeError: - # No event loop in current thread, create one loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - # Schedule the coroutine (non-blocking!) - coro = executor.fit_predict_async(estimator_handle, dataset, horizon, job_id) + coro = executor.fit_predict_async( + estimator_handle, + dataset=dataset, + data_handle=data_handle, + horizon=horizon, + job_id=job_id, + ) asyncio.run_coroutine_threadsafe(coro, loop) return { "success": True, "job_id": job_id, - "message": f"Training job started for {estimator_name} on {dataset}. Use check_job_status('{job_id}') to monitor progress.", + "message": ( + f"Training job started for {estimator_name} on {source_name}. " + f"Use check_job_status('{job_id}') to monitor progress." + ), "estimator": estimator_name, - "dataset": dataset, + "data_source": source_name, "horizon": horizon, } diff --git a/tests/test_async_custom_data.py b/tests/test_async_custom_data.py new file mode 100644 index 00000000..f928c0cd --- /dev/null +++ b/tests/test_async_custom_data.py @@ -0,0 +1,124 @@ +""" +Tests for async fit_predict with custom data handles. +""" + +import sys + +import pytest + +sys.path.insert(0, "src") + + +class TestAsyncCustomData: + """Tests for fit_predict_async_tool with data_handle support.""" + + def _get_estimator_handle(self): + """Create a NaiveForecaster handle for reuse.""" + from sktime_mcp.runtime.executor import get_executor + + executor = get_executor() + result = executor.instantiate("NaiveForecaster", {"strategy": "last"}) + assert result["success"], f"Failed to instantiate: {result}" + return result["handle"] + + def _load_custom_data(self): + """Load custom data and return the data handle.""" + import pandas as pd + + from sktime_mcp.runtime.executor import get_executor + + executor = get_executor() + config = { + "type": "pandas", + "data": { + "date": pd.date_range("2020-01-01", periods=50, freq="D").tolist(), + "sales": [100 + i for i in range(50)], + }, + "time_column": "date", + "target_column": "sales", + } + result = executor.load_data_source(config) + assert result["success"], f"Data load failed: {result}" + return result["data_handle"] + + def test_async_with_dataset(self): + """Async with a demo dataset should return a job_id.""" + from sktime_mcp.tools.fit_predict import fit_predict_async_tool + + handle = self._get_estimator_handle() + result = fit_predict_async_tool( + estimator_handle=handle, + dataset="airline", + horizon=3, + ) + + assert result["success"], f"Expected success, got: {result}" + assert "job_id" in result + assert result["data_source"] == "airline" + + def test_async_with_data_handle(self): + """Async with a custom data handle should return a job_id.""" + from sktime_mcp.tools.fit_predict import fit_predict_async_tool + + handle = self._get_estimator_handle() + data_handle = self._load_custom_data() + + result = fit_predict_async_tool( + estimator_handle=handle, + data_handle=data_handle, + horizon=5, + ) + + assert result["success"], f"Expected success, got: {result}" + assert "job_id" in result + assert result["data_source"] == data_handle + + def test_async_both_provided_error(self): + """Providing both dataset and data_handle should fail.""" + from sktime_mcp.tools.fit_predict import fit_predict_async_tool + + handle = self._get_estimator_handle() + result = fit_predict_async_tool( + estimator_handle=handle, + dataset="airline", + data_handle="data_fake123", + horizon=3, + ) + + assert not result["success"] + assert "error" in result + assert "not both" in result["error"].lower() + + def test_async_neither_provided_error(self): + """Omitting both dataset and data_handle should fail.""" + from sktime_mcp.tools.fit_predict import fit_predict_async_tool + + handle = self._get_estimator_handle() + result = fit_predict_async_tool( + estimator_handle=handle, + horizon=3, + ) + + assert not result["success"] + assert "error" in result + + def test_async_invalid_data_handle(self): + """An invalid data_handle should fail at the executor level.""" + from sktime_mcp.tools.fit_predict import fit_predict_async_tool + + handle = self._get_estimator_handle() + result = fit_predict_async_tool( + estimator_handle=handle, + data_handle="data_nonexistent", + horizon=3, + ) + + # The tool succeeds in scheduling the job (returns job_id), + # but the actual failure happens async inside the executor. + # So we just verify the job was created successfully. + assert result["success"] + assert "job_id" in result + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From b3c3b489d3c6043d45459671ea83f04cd90adf17 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Sat, 4 Apr 2026 13:30:48 +0530 Subject: [PATCH 2/2] fix: update test_async_fit_predict to use keyword args --- tests/test_background_jobs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_background_jobs.py b/tests/test_background_jobs.py index 60d353e2..6d3d9813 100644 --- a/tests/test_background_jobs.py +++ b/tests/test_background_jobs.py @@ -147,7 +147,7 @@ async def test_async_fit_predict(): print(f"✓ Created job: {job_id}") # Run async fit_predict - result = await executor.fit_predict_async(handle, "airline", 12, job_id) + result = await executor.fit_predict_async(handle, dataset="airline", horizon=12, job_id=job_id) # Check result assert result["success"]