diff --git a/src/sktime_mcp/runtime/executor.py b/src/sktime_mcp/runtime/executor.py index 4c4f46eb..552e43d9 100644 --- a/src/sktime_mcp/runtime/executor.py +++ b/src/sktime_mcp/runtime/executor.py @@ -32,6 +32,28 @@ } +def _merge_adapter_validation_warnings( + validation_report: dict[str, Any], + metadata: dict[str, Any], +) -> dict[str, Any]: + """Merge warnings added during adapter conversion into validation output.""" + metadata_validation = metadata.get("validation") + if not isinstance(metadata_validation, dict): + return validation_report + + metadata_warnings = metadata_validation.get("warnings", []) + if not metadata_warnings: + return validation_report + + merged = validation_report.copy() + existing_warnings = list(merged.get("warnings", [])) + for warning in metadata_warnings: + if warning not in existing_warnings: + existing_warnings.append(warning) + merged["warnings"] = existing_warnings + return merged + + class Executor: """ Execution runtime for sktime estimators. @@ -497,6 +519,7 @@ def load_data_source(self, config: dict[str, Any]) -> dict[str, Any]: # Update metadata to reflect the target and used columns metadata = adapter.get_metadata().copy() + validation_report = _merge_adapter_validation_warnings(validation_report, metadata) metadata["columns"] = [y.name if hasattr(y, "name") and y.name else "target"] if X is not None: metadata["exog_columns"] = list(X.columns) @@ -620,6 +643,7 @@ async def load_data_source_async( y, X = adapter.to_sktime_format(data) metadata = adapter.get_metadata().copy() + validation_report = _merge_adapter_validation_warnings(validation_report, metadata) metadata["columns"] = [y.name if hasattr(y, "name") and y.name else "target"] if X is not None: metadata["exog_columns"] = list(X.columns) diff --git a/tests/test_data_validation_warnings.py b/tests/test_data_validation_warnings.py new file mode 100644 index 00000000..336aec51 --- /dev/null +++ b/tests/test_data_validation_warnings.py @@ -0,0 +1,38 @@ +"""Tests for data validation warning propagation.""" + +import asyncio +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from sktime_mcp.runtime.executor import Executor + + +def _ambiguous_target_config(): + return { + "type": "pandas", + "data": { + "date": ["2020-01-01", "2020-01-02", "2020-01-03"], + "value": [1, 2, 3], + }, + } + + +def test_load_data_source_propagates_default_target_warning(): + """Default target warnings should appear in the top-level validation result.""" + result = Executor().load_data_source(_ambiguous_target_config()) + + assert result["success"] is True + warnings = result["validation"]["warnings"] + assert any("Target column not specified" in warning for warning in warnings) + + +def test_load_data_source_async_propagates_default_target_warning(): + """Async load should expose the same default-target warning as sync load.""" + result = asyncio.run(Executor().load_data_source_async(_ambiguous_target_config())) + + assert result["success"] is True + warnings = result["validation"]["warnings"] + assert any("Target column not specified" in warning for warning in warnings) +