diff --git a/CHANGELOG.md b/CHANGELOG.md index fe66d213..678e8a89 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Enable `pin_memory` in DataLoaders when GPU is available for faster async CPU-to-GPU data transfers [\#236](https://github.com/mllam/neural-lam/pull/236) @abhaygoudannavar +- Add `--dry_run_data` mode to `train_model` to preflight-validate dataloader batches for pipeline-specific temporal alignment and forcing window consistency before creating model/trainer [\#510](https://github.com/mllam/neural-lam/issues/510) @AR10129 + ### Changed - Change the default ensemble-loading behavior in `WeatherDataset` / `WeatherDataModule` to use all ensemble members as independent samples for ensemble datastores (with matching ensemble-member selection for forcing when available); single-member behavior now requires explicitly opting in via `--load_single_member` [\#332](https://github.com/mllam/neural-lam/pull/332) @kshirajahere diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 06cd608a..5fc431ba 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -3,6 +3,7 @@ import random import time from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser +from typing import Iterable # Third-party # for logging the model: @@ -24,6 +25,111 @@ } +def _validate_preflight_batch( + batch, + split_name: str, + expected_ar_steps: int, + forcing_window_size: int, +): + """Validate one batch for pipeline-specific temporal/forcing invariants.""" + if not isinstance(batch, (tuple, list)) or len(batch) != 4: + raise ValueError( + f"{split_name}: expected batch to be a 4-tuple " + "(init_states, target_states, forcing, target_times)." 
+ ) + + init_states, target_states, forcing, target_times = batch + + if not all(isinstance(t, torch.Tensor) for t in batch): + raise ValueError(f"{split_name}: all batch entries must be tensors.") + + if target_states.shape[1] != expected_ar_steps: + raise ValueError( + f"{split_name}: target_states time dimension must match " + f"expected_ar_steps={expected_ar_steps}, got " + f"{target_states.shape[1]}." + ) + + if forcing.shape[1] != expected_ar_steps: + raise ValueError( + f"{split_name}: forcing time dimension must match " + f"expected_ar_steps={expected_ar_steps}, got " + f"{forcing.shape[1]}." + ) + + if target_times.shape[1] != expected_ar_steps: + raise ValueError( + f"{split_name}: target_times length must match " + f"expected_ar_steps={expected_ar_steps}, got " + f"{target_times.shape[1]}." + ) + + if forcing.shape[-1] % forcing_window_size != 0: + raise ValueError( + f"{split_name}: forcing feature size ({forcing.shape[-1]}) is " + f"not divisible by forcing window size ({forcing_window_size})." 
def _validate_preflight_loader(
    loader: Iterable,
    split_name: str,
    expected_ar_steps: int,
    forcing_window_size: int,
):
    """Pull the first batch from ``loader`` and run the batch-level checks.

    Raises ValueError if the loader yields nothing or its first batch
    violates the preflight invariants.
    """
    batch_iter = iter(loader)
    try:
        first_batch = next(batch_iter)
    except StopIteration as exc:
        raise ValueError(f"{split_name}: dataloader is empty.") from exc

    _validate_preflight_batch(
        batch=first_batch,
        split_name=split_name,
        expected_ar_steps=expected_ar_steps,
        forcing_window_size=forcing_window_size,
    )


def _run_data_preflight(data_module: WeatherDataModule, args):
    """Run one-batch data pipeline checks and fail fast on invalid data.

    In eval mode only the test loader for ``args.eval`` is checked;
    otherwise the train and val loaders are checked in that order.
    """
    window_size = (
        args.num_past_forcing_steps + args.num_future_forcing_steps + 1
    )

    if args.eval:
        data_module.setup(stage="test")
        _validate_preflight_loader(
            loader=data_module.test_dataloader(),
            split_name=f"eval_{args.eval}",
            expected_ar_steps=args.ar_steps_eval,
            forcing_window_size=window_size,
        )
        logger.info(
            f"Data preflight passed for eval split '{args.eval}' with "
            f"ar_steps={args.ar_steps_eval}."
        )
        return

    data_module.setup(stage="fit")
    # Build each loader lazily (factories, not instances) so side effects
    # occur in the same train-then-val order as sequential validation.
    split_plan = (
        ("train", data_module.train_dataloader, args.ar_steps_train),
        ("val", data_module.val_dataloader, args.ar_steps_eval),
    )
    for split_name, make_loader, ar_steps in split_plan:
        _validate_preflight_loader(
            loader=make_loader(),
            split_name=split_name,
            expected_ar_steps=ar_steps,
            forcing_window_size=window_size,
        )
    logger.info(
        "Data preflight passed for train/val splits with "
        f"ar_steps_train={args.ar_steps_train} and "
        f"ar_steps_eval={args.ar_steps_eval}."
    )
def test_cli_entrypoints_importable():
    """Each CLI entry-point can be imported."""
    assert neural_lam is not None
    assert neural_lam.create_graph is not None
    assert neural_lam.train_model is not None


def _make_batch(ar_steps, forcing_dim):
    """Create a synthetic weather batch with DataLoader-like batch dim."""
    n_batch, n_grid, n_state_features = 2, 5, 3

    init_states = torch.randn(n_batch, 2, n_grid, n_state_features)
    target_states = torch.randn(n_batch, ar_steps, n_grid, n_state_features)
    forcing = torch.randn(n_batch, ar_steps, n_grid, forcing_dim)
    # Two rows of monotonically increasing "times", one per batch element.
    step_index = torch.arange(ar_steps, dtype=torch.int64)
    target_times = torch.stack((step_index, step_index + 10), dim=0)

    return init_states, target_states, forcing, target_times


def test_dry_run_data_preflight_success_skips_training_setup():
    """--dry_run_data runs preflight checks and exits before trainer setup."""

    class PreflightOnlyDataModule:
        """Stand-in data module whose loaders yield valid synthetic batches."""

        def __init__(self, *args, **kwargs):
            pass

        def setup(self, stage=None):
            self.stage = stage

        def train_dataloader(self):
            # window size = 1 + 1 + 1 = 3, forcing dim=6 is valid
            return [_make_batch(ar_steps=2, forcing_dim=6)]

        def val_dataloader(self):
            return [_make_batch(ar_steps=3, forcing_dim=6)]

        def test_dataloader(self):
            return [_make_batch(ar_steps=3, forcing_dim=6)]

    cli_args = [
        "--config_path",
        "dummy.yaml",
        "--dry_run_data",
        "--ar_steps_train",
        "2",
        "--ar_steps_eval",
        "3",
        "--val_steps_to_log",
        "1",
        "2",
        "3",
        "--num_workers",
        "0",
    ]

    with patch(
        "neural_lam.train_model.load_config_and_datastore",
        return_value=(MagicMock(), MagicMock()),
    ), patch(
        "neural_lam.train_model.WeatherDataModule",
        PreflightOnlyDataModule,
    ), patch(
        "neural_lam.train_model.pl.Trainer"
    ) as mock_trainer, patch(
        "neural_lam.train_model.utils.setup_training_logger"
    ) as mock_setup_logger:
        neural_lam.train_model.main(cli_args)

    # Preflight must return before any model/trainer construction.
    mock_trainer.assert_not_called()
    mock_setup_logger.assert_not_called()


def test_dry_run_data_preflight_failure_raises_value_error():
    """--dry_run_data fails fast when preflight detects invalid batches."""

    class BrokenForcingDataModule:
        """Stand-in whose train split has a non-divisible forcing dim."""

        def __init__(self, *args, **kwargs):
            pass

        def setup(self, stage=None):
            self.stage = stage

        def train_dataloader(self):
            # window size = 1 + 1 + 1 = 3, forcing dim=5 is invalid
            return [_make_batch(ar_steps=3, forcing_dim=5)]

        def val_dataloader(self):
            return [_make_batch(ar_steps=3, forcing_dim=6)]

        def test_dataloader(self):
            return [_make_batch(ar_steps=3, forcing_dim=6)]

    preflight_args = SimpleNamespace(
        eval=None,
        ar_steps_train=3,
        ar_steps_eval=3,
        num_past_forcing_steps=1,
        num_future_forcing_steps=1,
    )

    with pytest.raises(ValueError, match="forcing feature size"):
        neural_lam.train_model._run_data_preflight(
            data_module=BrokenForcingDataModule(),
            args=preflight_args,
        )