Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 87 additions & 42 deletions src/sktime_mcp/runtime/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,86 @@ def _get_demo_datasets() -> dict:
return _DEMO_DATASETS


def _apply_calendar_frequency_and_gap_fill(
y: pd.Series,
X: pd.DataFrame | None,
changes_made: dict[str, Any],
) -> tuple[pd.Series, pd.DataFrame | None]:
"""
Infer frequency and reindex to a full calendar range when safe.

Only ``DatetimeIndex`` and ``PeriodIndex`` support this path. Other index
types (e.g. ``RangeIndex``) are left unchanged: inferring a daily frequency
and calling ``pd.date_range`` on non-datetime bounds corrupts or empties
the series while still appearing to succeed.
"""
idx = y.index

if isinstance(idx, pd.DatetimeIndex):
freq = idx.freq

if freq is None:
try:
freq = pd.infer_freq(idx)
except (ValueError, TypeError):
freq = None

if freq is None:
time_diffs = idx.to_series().diff().dropna()
if len(time_diffs) > 0:
most_common_diff = time_diffs.mode()[0]

if most_common_diff == pd.Timedelta(days=1):
freq = "D"
elif most_common_diff == pd.Timedelta(hours=1):
freq = "h"
elif most_common_diff == pd.Timedelta(minutes=1):
freq = "min"
elif most_common_diff == pd.Timedelta(seconds=1):
freq = "s"
elif most_common_diff == pd.Timedelta(days=7):
freq = "W"
elif most_common_diff.days >= 28 and most_common_diff.days <= 31:
freq = "MS"
else:
freq = "D"

if freq:
full_range = pd.date_range(start=idx.min(), end=idx.max(), freq=freq)

n_gaps = len(full_range) - len(y)

y = y.reindex(full_range)
if X is not None:
X = X.reindex(full_range)

changes_made["gaps_filled"] = n_gaps
changes_made["frequency_set"] = True
changes_made["frequency"] = freq

elif isinstance(idx, pd.PeriodIndex):
try:
freq = idx.freq or pd.infer_freq(idx)
except (ValueError, TypeError):
freq = idx.freq
if freq:
full_range = pd.period_range(start=idx.min(), end=idx.max(), freq=freq)
n_gaps = len(full_range) - len(y)
y = y.reindex(full_range)
if X is not None:
X = X.reindex(full_range)
changes_made["gaps_filled"] = n_gaps
changes_made["frequency_set"] = True
changes_made["frequency"] = freq
else:
changes_made["calendar_gap_fill_skipped"] = True
changes_made["calendar_gap_fill_reason"] = (
f"Unsupported index type {type(idx).__name__} for calendar frequency inference"
)

return y, X


class Executor:
"""
Execution runtime for sktime estimators.
Expand Down Expand Up @@ -795,48 +875,9 @@ def format_data_handle(
if X is not None:
X = X.sort_index()

# 3. Infer and set frequency
# 3. Infer and set frequency (calendar gap-fill only for DatetimeIndex / PeriodIndex)
if auto_infer_freq:
freq = y.index.freq

if freq is None:
# Try to infer
freq = pd.infer_freq(y.index)

if freq is None:
# Manual inference
time_diffs = y.index.to_series().diff().dropna()
if len(time_diffs) > 0:
most_common_diff = time_diffs.mode()[0]

if most_common_diff == pd.Timedelta(days=1):
freq = "D"
elif most_common_diff == pd.Timedelta(hours=1):
freq = "h"
elif most_common_diff == pd.Timedelta(minutes=1):
freq = "min"
elif most_common_diff == pd.Timedelta(seconds=1):
freq = "s"
elif most_common_diff == pd.Timedelta(days=7):
freq = "W"
elif most_common_diff.days >= 28 and most_common_diff.days <= 31:
freq = "MS"
else:
freq = "D"

# Create complete date range
if freq:
full_range = pd.date_range(start=y.index.min(), end=y.index.max(), freq=freq)

n_gaps = len(full_range) - len(y)

y = y.reindex(full_range)
if X is not None:
X = X.reindex(full_range)

changes_made["gaps_filled"] = n_gaps
changes_made["frequency_set"] = True
changes_made["frequency"] = freq
y, X = _apply_calendar_frequency_and_gap_fill(y, X, changes_made)

# 4. Fill missing values
if fill_missing and y.isna().any():
Expand All @@ -862,7 +903,11 @@ def format_data_handle(
"metadata": {
**data_info["metadata"],
"formatted": True,
"frequency": str(y.index.freq) if y.index.freq else changes_made.get("frequency"),
"frequency": (
str(ix_freq)
if (ix_freq := getattr(y.index, "freq", None)) is not None
else changes_made.get("frequency")
),
"rows": len(y),
"start_date": str(y.index.min()),
"end_date": str(y.index.max()),
Expand Down
81 changes: 81 additions & 0 deletions tests/test_format_data_handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Tests for calendar-safe formatting (non-datetime indices must not be corrupted)."""

import sys

import pandas as pd
import pytest

sys.path.insert(0, "src")


@pytest.mark.parametrize(
"index_factory,expect_skip",
[
(lambda: pd.RangeIndex(3), True),
(lambda: pd.Index([10, 20, 30], dtype="int64"), True),
],
)
def test_calendar_gap_fill_skipped_for_non_datetime_index(index_factory, expect_skip):
from sktime_mcp.runtime.executor import _apply_calendar_frequency_and_gap_fill

y = pd.Series([1.0, 2.0, 3.0], index=index_factory())
changes_made = {
"frequency_set": False,
"duplicates_removed": 0,
"missing_filled": 0,
"gaps_filled": 0,
}

y_out, X_out = _apply_calendar_frequency_and_gap_fill(y, None, changes_made)

assert X_out is None
assert len(y_out) == 3
assert list(y_out.values) == pytest.approx([1.0, 2.0, 3.0])
if expect_skip:
assert changes_made.get("calendar_gap_fill_skipped") is True
assert "Unsupported index type" in changes_made.get("calendar_gap_fill_reason", "")


def test_datetime_index_gap_fill_still_runs():
from sktime_mcp.runtime.executor import _apply_calendar_frequency_and_gap_fill

idx = pd.to_datetime(["2020-01-01", "2020-01-03"])
y = pd.Series([1.0, 3.0], index=idx)
changes_made = {
"frequency_set": False,
"duplicates_removed": 0,
"missing_filled": 0,
"gaps_filled": 0,
}

y_out, X_out = _apply_calendar_frequency_and_gap_fill(y, None, changes_made)

assert X_out is None
assert isinstance(y_out.index, pd.DatetimeIndex)
assert changes_made.get("calendar_gap_fill_skipped") is not True
assert len(y_out) >= 2


def test_format_data_handle_range_index_via_executor():
"""format_data_handle must not destroy RangeIndex series (no __init__ → no registry)."""
from sktime_mcp.runtime.executor import Executor

ex = Executor.__new__(Executor)
ex._data_handles = {}

h = "data_testfmt1"
ex._data_handles[h] = {
"y": pd.Series([10.0, 20.0, 30.0]),
"X": None,
"metadata": {"source": "test"},
"validation": {},
"config": {},
}

result = Executor.format_data_handle(ex, h, auto_infer_freq=True)
assert result["success"]
new_h = result["data_handle"]
y_new = ex._data_handles[new_h]["y"]
assert len(y_new) == 3
assert list(y_new.values) == pytest.approx([10.0, 20.0, 30.0])
assert result["changes_made"].get("calendar_gap_fill_skipped") is True