diff --git a/src/sktime_mcp/runtime/executor.py b/src/sktime_mcp/runtime/executor.py index 094c63b9..172cede1 100644 --- a/src/sktime_mcp/runtime/executor.py +++ b/src/sktime_mcp/runtime/executor.py @@ -50,6 +50,86 @@ def _get_demo_datasets() -> dict: return _DEMO_DATASETS +def _apply_calendar_frequency_and_gap_fill( + y: pd.Series, + X: pd.DataFrame | None, + changes_made: dict[str, Any], +) -> tuple[pd.Series, pd.DataFrame | None]: + """ + Infer frequency and reindex to a full calendar range when safe. + + Only ``DatetimeIndex`` and ``PeriodIndex`` support this path. Other index + types (e.g. ``RangeIndex``) are left unchanged: inferring a daily frequency + and calling ``pd.date_range`` on non-datetime bounds corrupts or empties + the series while still appearing to succeed. + """ + idx = y.index + + if isinstance(idx, pd.DatetimeIndex): + freq = idx.freq + + if freq is None: + try: + freq = pd.infer_freq(idx) + except (ValueError, TypeError): + freq = None + + if freq is None: + time_diffs = idx.to_series().diff().dropna() + if len(time_diffs) > 0: + most_common_diff = time_diffs.mode()[0] + + if most_common_diff == pd.Timedelta(days=1): + freq = "D" + elif most_common_diff == pd.Timedelta(hours=1): + freq = "h" + elif most_common_diff == pd.Timedelta(minutes=1): + freq = "min" + elif most_common_diff == pd.Timedelta(seconds=1): + freq = "s" + elif most_common_diff == pd.Timedelta(days=7): + freq = "W" + elif most_common_diff.days >= 28 and most_common_diff.days <= 31: + freq = "MS" + else: + freq = "D" + + if freq: + full_range = pd.date_range(start=idx.min(), end=idx.max(), freq=freq) + + n_gaps = len(full_range) - len(y) + + y = y.reindex(full_range) + if X is not None: + X = X.reindex(full_range) + + changes_made["gaps_filled"] = n_gaps + changes_made["frequency_set"] = True + changes_made["frequency"] = freq + + elif isinstance(idx, pd.PeriodIndex): + try: + freq = idx.freq or pd.infer_freq(idx) + except (ValueError, TypeError): + freq = idx.freq + if freq: 
+ full_range = pd.period_range(start=idx.min(), end=idx.max(), freq=freq) + n_gaps = len(full_range) - len(y) + y = y.reindex(full_range) + if X is not None: + X = X.reindex(full_range) + changes_made["gaps_filled"] = n_gaps + changes_made["frequency_set"] = True + changes_made["frequency"] = freq + else: + changes_made["calendar_gap_fill_skipped"] = True + changes_made["calendar_gap_fill_reason"] = ( + f"Unsupported index type {type(idx).__name__} for calendar frequency inference" + ) + + return y, X + + class Executor: """ Execution runtime for sktime estimators. @@ -795,48 +875,9 @@ def format_data_handle( if X is not None: X = X.sort_index() - # 3. Infer and set frequency + # 3. Infer and set frequency (calendar gap-fill only for DatetimeIndex / PeriodIndex) if auto_infer_freq: - freq = y.index.freq - - if freq is None: - # Try to infer - freq = pd.infer_freq(y.index) - - if freq is None: - # Manual inference - time_diffs = y.index.to_series().diff().dropna() - if len(time_diffs) > 0: - most_common_diff = time_diffs.mode()[0] - - if most_common_diff == pd.Timedelta(days=1): - freq = "D" - elif most_common_diff == pd.Timedelta(hours=1): - freq = "h" - elif most_common_diff == pd.Timedelta(minutes=1): - freq = "min" - elif most_common_diff == pd.Timedelta(seconds=1): - freq = "s" - elif most_common_diff == pd.Timedelta(days=7): - freq = "W" - elif most_common_diff.days >= 28 and most_common_diff.days <= 31: - freq = "MS" - else: - freq = "D" - - # Create complete date range - if freq: - full_range = pd.date_range(start=y.index.min(), end=y.index.max(), freq=freq) - - n_gaps = len(full_range) - len(y) - - y = y.reindex(full_range) - if X is not None: - X = X.reindex(full_range) - - changes_made["gaps_filled"] = n_gaps - changes_made["frequency_set"] = True - changes_made["frequency"] = freq + y, X = _apply_calendar_frequency_and_gap_fill(y, X, changes_made) # 4. 
"""Tests for calendar-safe formatting (non-datetime indices must not be corrupted)."""

import sys
from pathlib import Path

import pandas as pd
import pytest

# Resolve ``src`` relative to this file so the tests do not depend on the
# directory pytest happens to be launched from.
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))


def _fresh_changes() -> dict:
    """Return a change log shaped like the one the tests pass to the helper."""
    return {
        "frequency_set": False,
        "duplicates_removed": 0,
        "missing_filled": 0,
        "gaps_filled": 0,
    }


@pytest.mark.parametrize(
    "index",
    [pd.RangeIndex(3), pd.Index([10, 20, 30], dtype="int64")],
    ids=["range-index", "int64-index"],
)
def test_calendar_gap_fill_skipped_for_non_datetime_index(index):
    """Non-calendar indices pass through untouched and the skip is recorded."""
    from sktime_mcp.runtime.executor import _apply_calendar_frequency_and_gap_fill

    y = pd.Series([1.0, 2.0, 3.0], index=index)
    changes_made = _fresh_changes()

    y_out, X_out = _apply_calendar_frequency_and_gap_fill(y, None, changes_made)

    assert X_out is None
    assert len(y_out) == 3
    assert list(y_out.values) == pytest.approx([1.0, 2.0, 3.0])
    assert changes_made.get("calendar_gap_fill_skipped") is True
    assert "Unsupported index type" in changes_made.get("calendar_gap_fill_reason", "")
    # Skipping must not claim that a frequency was applied.
    assert changes_made["frequency_set"] is False
    assert changes_made["gaps_filled"] == 0


def test_datetime_index_gap_fill_still_runs():
    """A DatetimeIndex with a missing day is expanded and the gap is counted."""
    from sktime_mcp.runtime.executor import _apply_calendar_frequency_and_gap_fill

    idx = pd.to_datetime(["2020-01-01", "2020-01-03"])
    y = pd.Series([1.0, 3.0], index=idx)
    changes_made = _fresh_changes()

    y_out, X_out = _apply_calendar_frequency_and_gap_fill(y, None, changes_made)

    assert X_out is None
    assert isinstance(y_out.index, pd.DatetimeIndex)
    assert changes_made.get("calendar_gap_fill_skipped") is not True
    # Two observations two days apart fall back to a daily calendar, so
    # exactly one new (NaN) row appears for 2020-01-02.
    assert len(y_out) == 3
    assert int(y_out.isna().sum()) == 1
    assert list(y_out.dropna().values) == pytest.approx([1.0, 3.0])
    assert changes_made["gaps_filled"] == 1
    assert changes_made["frequency_set"] is True


def test_format_data_handle_range_index_via_executor():
    """format_data_handle must not destroy RangeIndex series (no __init__ → no registry)."""
    from sktime_mcp.runtime.executor import Executor

    # Bypass __init__ so no registry is required; only the handle store is set.
    ex = Executor.__new__(Executor)
    ex._data_handles = {}

    h = "data_testfmt1"
    ex._data_handles[h] = {
        "y": pd.Series([10.0, 20.0, 30.0]),
        "X": None,
        "metadata": {"source": "test"},
        "validation": {},
        "config": {},
    }

    result = Executor.format_data_handle(ex, h, auto_infer_freq=True)
    assert result["success"]
    new_h = result["data_handle"]
    y_new = ex._data_handles[new_h]["y"]
    assert len(y_new) == 3
    assert list(y_new.values) == pytest.approx([10.0, 20.0, 30.0])
    assert result["changes_made"].get("calendar_gap_fill_skipped") is True