Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Enable `pin_memory` in DataLoaders when GPU is available for faster async CPU-to-GPU data transfers [\#236](https://github.com/mllam/neural-lam/pull/236) @abhaygoudannavar

- Add `--dry_run_data` mode to `train_model` to preflight-validate dataloader batches for pipeline-specific temporal alignment and forcing window consistency before creating model/trainer [\#510](https://github.com/mllam/neural-lam/issues/510) @AR10129

### Changed

- Change the default ensemble-loading behavior in `WeatherDataset` / `WeatherDataModule` to use all ensemble members as independent samples for ensemble datastores (with matching ensemble-member selection for forcing when available); single-member behavior now requires explicitly opting in via `--load_single_member` [\#332](https://github.com/mllam/neural-lam/pull/332) @kshirajahere
Expand Down
119 changes: 119 additions & 0 deletions neural_lam/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import random
import time
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from typing import Iterable

# Third-party
# for logging the model:
Expand All @@ -24,6 +25,111 @@
}


def _validate_preflight_batch(
batch,
split_name: str,
expected_ar_steps: int,
forcing_window_size: int,
):
"""Validate one batch for pipeline-specific temporal/forcing invariants."""
if not isinstance(batch, (tuple, list)) or len(batch) != 4:
raise ValueError(
f"{split_name}: expected batch to be a 4-tuple "
"(init_states, target_states, forcing, target_times)."
)

init_states, target_states, forcing, target_times = batch

if not all(isinstance(t, torch.Tensor) for t in batch):
raise ValueError(f"{split_name}: all batch entries must be tensors.")

if target_states.shape[1] != expected_ar_steps:
raise ValueError(
f"{split_name}: target_states time dimension must match "
f"expected_ar_steps={expected_ar_steps}, got "
f"{target_states.shape[1]}."
)

if forcing.shape[1] != expected_ar_steps:
raise ValueError(
f"{split_name}: forcing time dimension must match "
f"expected_ar_steps={expected_ar_steps}, got "
f"{forcing.shape[1]}."
)

if target_times.shape[1] != expected_ar_steps:
raise ValueError(
f"{split_name}: target_times length must match "
f"expected_ar_steps={expected_ar_steps}, got "
f"{target_times.shape[1]}."
)

if forcing.shape[-1] % forcing_window_size != 0:
raise ValueError(
f"{split_name}: forcing feature size ({forcing.shape[-1]}) is "
f"not divisible by forcing window size ({forcing_window_size})."
)


def _validate_preflight_loader(
loader: Iterable,
split_name: str,
expected_ar_steps: int,
forcing_window_size: int,
):
"""Validate the first batch produced by a dataloader."""
try:
batch = next(iter(loader))
except StopIteration as exc:
raise ValueError(f"{split_name}: dataloader is empty.") from exc

_validate_preflight_batch(
batch=batch,
split_name=split_name,
expected_ar_steps=expected_ar_steps,
forcing_window_size=forcing_window_size,
)


def _run_data_preflight(data_module: WeatherDataModule, args):
    """Run one-batch data pipeline checks and fail fast on invalid data.

    In eval mode only the test dataloader is checked; otherwise both the
    train and val dataloaders are checked, in that order.
    """
    # Total forcing window: past steps + future steps + the current step.
    forcing_window_size = (
        args.num_past_forcing_steps + args.num_future_forcing_steps + 1
    )

    if args.eval:
        data_module.setup(stage="test")
        _validate_preflight_loader(
            loader=data_module.test_dataloader(),
            split_name=f"eval_{args.eval}",
            expected_ar_steps=args.ar_steps_eval,
            forcing_window_size=forcing_window_size,
        )
        logger.info(
            f"Data preflight passed for eval split '{args.eval}' with "
            f"ar_steps={args.ar_steps_eval}."
        )
        return

    data_module.setup(stage="fit")
    # Loader factories are called lazily so that a failure on "train" never
    # constructs the "val" dataloader.
    split_specs = (
        ("train", data_module.train_dataloader, args.ar_steps_train),
        ("val", data_module.val_dataloader, args.ar_steps_eval),
    )
    for name, make_loader, n_ar_steps in split_specs:
        _validate_preflight_loader(
            loader=make_loader(),
            split_name=name,
            expected_ar_steps=n_ar_steps,
            forcing_window_size=forcing_window_size,
        )
    logger.info(
        "Data preflight passed for train/val splits with "
        f"ar_steps_train={args.ar_steps_train} and "
        f"ar_steps_eval={args.ar_steps_eval}."
    )


@logger.catch
def main(input_args=None):
"""Main function for training and evaluating models."""
Expand Down Expand Up @@ -244,6 +350,14 @@ def main(input_args=None):
"ensemble members as independent samples."
),
)
parser.add_argument(
"--dry_run_data",
action="store_true",
help=(
"Validate one batch from each relevant dataloader and exit "
"without creating model/trainer or running train/test."
),
)
args = parser.parse_args(input_args)
args.var_leads_metrics_watch = {
int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items()
Expand Down Expand Up @@ -297,6 +411,11 @@ def main(input_args=None):
eval_split=args.eval or "test",
)

if args.dry_run_data:
_run_data_preflight(data_module=data_module, args=args)
logger.info("Exiting after successful data preflight.")
return

# Instantiate model + trainer
if torch.cuda.is_available():
device_name = "cuda"
Expand Down
113 changes: 110 additions & 3 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
# Standard library
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

# Third-party
import pytest
import torch

# First-party
import neural_lam
import neural_lam.create_graph
import neural_lam.train_model


def test_import():
"""This test just ensures that each cli entry-point can be imported for now,
eventually we should test their execution too."""
def test_cli_entrypoints_importable():
    """Each CLI entry-point module can be imported."""
    entrypoints = (
        neural_lam,
        neural_lam.create_graph,
        neural_lam.train_model,
    )
    for module in entrypoints:
        assert module is not None
Expand Down Expand Up @@ -117,3 +118,109 @@ def test_wandb_id_ignored_with_mlflow_warns():
warning_msg = mock_log.warning.call_args[0][0]
assert "--wandb_id is set but logger is" in warning_msg
assert "mlflow" in warning_msg


def _make_batch(ar_steps, forcing_dim):
"""Create a synthetic weather batch with DataLoader-like batch dim."""
batch_size = 2
n_grid = 5
d_state = 3

init_states = torch.randn(batch_size, 2, n_grid, d_state)
target_states = torch.randn(batch_size, ar_steps, n_grid, d_state)
forcing = torch.randn(batch_size, ar_steps, n_grid, forcing_dim)
base = torch.arange(ar_steps, dtype=torch.int64)
target_times = torch.stack([base, base + 10], dim=0)

return init_states, target_states, forcing, target_times


def test_dry_run_data_preflight_success_skips_training_setup():
    """--dry_run_data runs preflight checks and exits before trainer setup."""

    class DummyDataModule:
        # Minimal stand-in for WeatherDataModule: serves one synthetic batch
        # per split with shapes that satisfy the preflight invariants.
        def __init__(self, *args, **kwargs):
            pass

        def setup(self, stage=None):
            self.stage = stage

        def train_dataloader(self):
            # window size = 1 + 1 + 1 = 3, forcing dim=6 is valid
            # ar_steps=2 matches --ar_steps_train below
            return [_make_batch(ar_steps=2, forcing_dim=6)]

        def val_dataloader(self):
            # ar_steps=3 matches --ar_steps_eval below
            return [_make_batch(ar_steps=3, forcing_dim=6)]

        def test_dataloader(self):
            return [_make_batch(ar_steps=3, forcing_dim=6)]

    # Patch out config loading and the data module so main() never touches
    # disk; patch Trainer / logger setup so we can assert they stay unused.
    with (
        patch(
            "neural_lam.train_model.load_config_and_datastore",
            return_value=(MagicMock(), MagicMock()),
        ),
        patch(
            "neural_lam.train_model.WeatherDataModule",
            DummyDataModule,
        ),
        patch("neural_lam.train_model.pl.Trainer") as mock_trainer,
        patch(
            "neural_lam.train_model.utils.setup_training_logger"
        ) as mock_setup_logger,
    ):
        neural_lam.train_model.main(
            [
                "--config_path",
                "dummy.yaml",
                "--dry_run_data",
                "--ar_steps_train",
                "2",
                "--ar_steps_eval",
                "3",
                "--val_steps_to_log",
                "1",
                "2",
                "3",
                "--num_workers",
                "0",
            ]
        )

    # The dry run must return before any model/trainer construction.
    mock_trainer.assert_not_called()
    mock_setup_logger.assert_not_called()


def test_dry_run_data_preflight_failure_raises_value_error():
    """--dry_run_data fails fast when preflight detects invalid batches."""

    class DummyBadDataModule:
        # Data module whose train split violates the forcing-window
        # divisibility invariant; other splits are valid.
        def __init__(self, *args, **kwargs):
            pass

        def setup(self, stage=None):
            self.stage = stage

        def train_dataloader(self):
            # forcing dim 5 cannot be split into windows of size 1 + 1 + 1 = 3
            return [_make_batch(ar_steps=3, forcing_dim=5)]

        def val_dataloader(self):
            return [_make_batch(ar_steps=3, forcing_dim=6)]

        def test_dataloader(self):
            return [_make_batch(ar_steps=3, forcing_dim=6)]

    preflight_args = SimpleNamespace(
        eval=None,
        ar_steps_train=3,
        ar_steps_eval=3,
        num_past_forcing_steps=1,
        num_future_forcing_steps=1,
    )

    with pytest.raises(ValueError, match="forcing feature size"):
        neural_lam.train_model._run_data_preflight(
            data_module=DummyBadDataModule(),
            args=preflight_args,
        )