From 51f52925bfa3c925c5a088d1040027ddd5fc812a Mon Sep 17 00:00:00 2001 From: ESoC Contributor Date: Fri, 1 May 2026 14:20:06 +0530 Subject: [PATCH] fix: classification workflows not executable end-to-end (Fixes #405) --- src/sktime_mcp/runtime/executor.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/sktime_mcp/runtime/executor.py b/src/sktime_mcp/runtime/executor.py index 0dc2673f..a1935e80 100644 --- a/src/sktime_mcp/runtime/executor.py +++ b/src/sktime_mcp/runtime/executor.py @@ -131,12 +131,16 @@ def fit( return {"success": False, "error": f"Handle not found: {handle_id}"} try: - if fh is not None: - instance.fit(y, X=X, fh=fh) - elif X is not None: - instance.fit(y, X=X) + is_classifier = getattr(instance, "_estimator_type", None) == "classifier" + if is_classifier: + instance.fit(X, y) else: - instance.fit(y) + if fh is not None: + instance.fit(y, X=X, fh=fh) + elif X is not None: + instance.fit(y, X=X) + else: + instance.fit(y) self._handle_manager.mark_fitted(handle_id) return {"success": True, "handle": handle_id, "fitted": True} @@ -159,10 +163,14 @@ def predict( return {"success": False, "error": "Estimator not fitted"} try: - if fh is None: - fh = list(range(1, 13)) + is_classifier = getattr(instance, "_estimator_type", None) == "classifier" + if is_classifier: + predictions = instance.predict(X) + else: + if fh is None: + fh = list(range(1, 13)) - predictions = instance.predict(fh=fh, X=X) if X is not None else instance.predict(fh=fh) + predictions = instance.predict(fh=fh, X=X) if X is not None else instance.predict(fh=fh) if isinstance(predictions, pd.Series): # Convert index to string to avoid JSON serialization issues with Period/DatetimeIndex