diff --git a/src/sktime_mcp/runtime/executor.py b/src/sktime_mcp/runtime/executor.py index 0dc2673f..a1935e80 100644 --- a/src/sktime_mcp/runtime/executor.py +++ b/src/sktime_mcp/runtime/executor.py @@ -131,12 +131,16 @@ def fit( return {"success": False, "error": f"Handle not found: {handle_id}"} try: - if fh is not None: - instance.fit(y, X=X, fh=fh) - elif X is not None: - instance.fit(y, X=X) + is_classifier = getattr(instance, "_estimator_type", None) == "classifier" + if is_classifier: + instance.fit(X, y) else: - instance.fit(y) + if fh is not None: + instance.fit(y, X=X, fh=fh) + elif X is not None: + instance.fit(y, X=X) + else: + instance.fit(y) self._handle_manager.mark_fitted(handle_id) return {"success": True, "handle": handle_id, "fitted": True} @@ -159,10 +163,14 @@ def predict( return {"success": False, "error": "Estimator not fitted"} try: - if fh is None: - fh = list(range(1, 13)) + is_classifier = getattr(instance, "_estimator_type", None) == "classifier" + if is_classifier: + predictions = instance.predict(X) + else: + if fh is None: + fh = list(range(1, 13)) - predictions = instance.predict(fh=fh, X=X) if X is not None else instance.predict(fh=fh) + predictions = instance.predict(fh=fh, X=X) if X is not None else instance.predict(fh=fh) if isinstance(predictions, pd.Series): # Convert index to string to avoid JSON serialization issues with Period/DatetimeIndex