diff --git a/src/sktime_mcp/runtime/executor.py b/src/sktime_mcp/runtime/executor.py index 0dc2673f..dfcb12c1 100644 --- a/src/sktime_mcp/runtime/executor.py +++ b/src/sktime_mcp/runtime/executor.py @@ -187,9 +187,10 @@ def predict( def fit_predict( self, handle_id: str, - dataset: str, + dataset: str | None = None, horizon: int = 12, data_handle: str | None = None, + exog_handle: str | None = None, ) -> dict[str, Any]: """Convenience method: load data, fit, and predict.""" if dataset and data_handle: @@ -206,6 +207,10 @@ def fit_predict( "'data_handle' (from load_data_source) is required." ), } + + y = None + X = None + if data_handle is not None: # Use custom loaded data if data_handle not in self._data_handles: @@ -224,6 +229,19 @@ def fit_predict( return data_result y = data_result["data"] X = data_result.get("exog") + + if exog_handle is not None: + if exog_handle not in self._data_handles: + return { + "success": False, + "error": f"Unknown exog_handle: {exog_handle}", + "available_handles": list(self._data_handles.keys()), + } + exog_info = self._data_handles[exog_handle] + # Exogenous covariates might be loaded as 'y' if no target_column was specified + X = exog_info.get("X") + if X is None or (hasattr(X, "empty") and X.empty): + X = exog_info.get("y") fh = list(range(1, horizon + 1)) @@ -238,6 +256,7 @@ async def fit_predict_async( handle_id: str, dataset: str | None = None, data_handle: str | None = None, + exog_handle: str | None = None, horizon: int = 12, job_id: str | None = None, ) -> dict[str, Any]: @@ -252,6 +271,7 @@ async def fit_predict_async( handle_id: Estimator handle dataset: Demo dataset name data_handle: Data handle from load_data_source + exog_handle: Handle for exogenous variables (covariates) horizon: Forecast horizon job_id: Optional job ID for tracking (created if not provided) @@ -329,6 +349,20 @@ async def fit_predict_async( y = data_result["data"] X = data_result.get("exog") + # Resolve exog_handle if provided - overrides any X from dataset + if exog_handle: + if exog_handle not in self._data_handles: + self._job_manager.update_job( + job_id, + status=JobStatus.FAILED, + errors=[f"Unknown exog_handle: {exog_handle}"], + ) + return {"success": False, "job_id": job_id} + exog_info = self._data_handles[exog_handle] + X = exog_info.get("X") + if X is None or (hasattr(X, "empty") and X.empty): + X = exog_info.get("y") + fh = list(range(1, horizon + 1)) # Step 2: Fit model @@ -339,6 +373,7 @@ async def fit_predict_async( ) await asyncio.sleep(0.01) + loop = asyncio.get_event_loop() fit_result = await loop.run_in_executor( None, lambda: self.fit(handle_id, y, X=X, fh=fh) diff --git a/src/sktime_mcp/server.py b/src/sktime_mcp/server.py index edee569a..a930c204 100644 --- a/src/sktime_mcp/server.py +++ b/src/sktime_mcp/server.py @@ -312,6 +312,14 @@ async def list_tools() -> list[Tool]: "of dataset for custom data)" ), }, + "exog_handle": { + "type": "string", + "description": ( + "Optional handle for exogenous variables (covariates) " + "loaded via load_data_source. Passed as X to the estimator " + "alongside the target y. Required for multivariate forecasting." + ), + }, "horizon": { "type": "integer", "description": "Forecast horizon (default: 12)", @@ -343,6 +351,13 @@ async def list_tools() -> list[Tool]: "type": "string", "description": "Data handle from load_data_source (e.g. 'data_abc123')", }, + "exog_handle": { + "type": "string", + "description": ( + "Optional handle for exogenous variables (covariates) " + "loaded via load_data_source. Passed as X to the estimator." + ), + }, "horizon": { "type": "integer", "description": "Forecast horizon (default: 12)", @@ -678,6 +693,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: arguments.get("dataset", ""), arguments.get("horizon", 12), data_handle=arguments.get("data_handle"), + exog_handle=arguments.get("exog_handle"), ) elif name == "fit_predict_async": @@ -685,6 +701,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: estimator_handle=arguments["estimator_handle"], dataset=arguments.get("dataset"), data_handle=arguments.get("data_handle"), + exog_handle=arguments.get("exog_handle"), horizon=arguments.get("horizon", 12), ) diff --git a/src/sktime_mcp/tools/fit_predict.py b/src/sktime_mcp/tools/fit_predict.py index ef8f1b46..0af97e02 100644 --- a/src/sktime_mcp/tools/fit_predict.py +++ b/src/sktime_mcp/tools/fit_predict.py @@ -41,9 +41,10 @@ def _validate_horizon(horizon: int) -> dict[str, Any]: def fit_predict_tool( estimator_handle: str, - dataset: str, + dataset: str | None = None, horizon: int = 12, data_handle: str | None = None, + exog_handle: str | None = None, ) -> dict[str, Any]: """ Execute a complete fit-predict workflow. @@ -53,6 +54,8 @@ def fit_predict_tool( dataset: Name of demo dataset (e.g., "airline", "sunspots") horizon: Forecast horizon (default: 12) data_handle: Optional handle from load_data_source for custom data + exog_handle: Optional handle for exogenous variables (covariates), + also from load_data_source. Passed as X to the estimator. Returns: Dictionary with: @@ -89,7 +92,13 @@ def fit_predict_tool( ), } executor = get_executor() - return executor.fit_predict(estimator_handle, dataset, horizon, data_handle=data_handle) + return executor.fit_predict( + estimator_handle, + dataset, + horizon, + data_handle=data_handle, + exog_handle=exog_handle, + ) def predict_tool( @@ -135,6 +144,7 @@ def fit_predict_async_tool( estimator_handle: str, dataset: str | None = None, data_handle: str | None = None, + exog_handle: str | None = None, horizon: int = 12, ) -> dict[str, Any]: """ @@ -150,6 +160,8 @@ def fit_predict_async_tool( estimator_handle: Handle from instantiate_estimator dataset: Name of demo dataset (e.g., "airline", "sunspots") data_handle: Handle from load_data_source (e.g., "data_abc123") + exog_handle: Optional handle for exogenous covariates from load_data_source. + Passed as X to the estimator alongside the target y. horizon: Forecast horizon (default: 12) Returns: @@ -221,6 +233,7 @@ def fit_predict_async_tool( estimator_handle, dataset=dataset, data_handle=data_handle, + exog_handle=exog_handle, horizon=horizon, job_id=job_id, )