diff --git a/fme/ace/stepper/test_single_module_csfno.py b/fme/ace/stepper/test_single_module_csfno.py new file mode 100644 index 000000000..25c64d6a1 --- /dev/null +++ b/fme/ace/stepper/test_single_module_csfno.py @@ -0,0 +1,245 @@ +""" +Parallel regression tests for the SingleModuleStepper with NoiseConditionedSFNO. + +These tests verify that the forward pass and loss computation produce identical +results regardless of spatial decomposition (nproc=1 vs model-parallel). +""" + +import dataclasses +import datetime +import os +from collections.abc import Mapping + +import numpy as np +import pytest +import torch +import xarray as xr + +from fme.ace.data_loading.batch_data import BatchData +from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNOBuilder +from fme.ace.stepper.single_module import ( + StepperConfig, + TrainOutput, + TrainStepper, + TrainStepperConfig, +) +from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates +from fme.core.dataset_info import DatasetInfo +from fme.core.device import get_device +from fme.core.distributed.distributed import Distributed +from fme.core.loss import StepLossConfig +from fme.core.normalizer import NetworkAndLossNormalizationConfig, NormalizationConfig +from fme.core.optimization import NullOptimization, OptimizationConfig +from fme.core.registry.module import ModuleSelector +from fme.core.step import SingleModuleStepConfig, StepSelector +from fme.core.testing.regression import validate_tensor_dict +from fme.core.typing_ import EnsembleTensorDict + +DIR = os.path.abspath(os.path.dirname(__file__)) +TIMESTEP = datetime.timedelta(hours=6) + + +def get_dataset_info( + img_shape=(5, 5), +) -> DatasetInfo: + horizontal_coordinate = LatLonCoordinates( + lat=torch.zeros(img_shape[-2]), + lon=torch.zeros(img_shape[-1]), + ) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7), bk=torch.arange(7) + ) + return DatasetInfo( + horizontal_coordinates=horizontal_coordinate, + vertical_coordinate=vertical_coordinate, + timestep=TIMESTEP, + ) + + +def _get_train_stepper( + stepper_config: StepperConfig, + dataset_info: DatasetInfo, + **train_config_kwargs, +) -> TrainStepper: + train_config = TrainStepperConfig(**train_config_kwargs) + return train_config.get_train_stepper(stepper_config, dataset_info) + + +def get_regression_stepper_and_data() -> ( + tuple[TrainStepper, BatchData, tuple[int, int]] +): + in_names = ["a", "b"] + out_names = ["b", "c"] + n_forward_steps = 2 + n_samples = 3 + img_shape = (9, 18) + device = get_device() + + all_names = list(set(in_names + out_names)) + + loss = StepLossConfig(type="AreaWeightedMSE") + + config = StepperConfig( + step=StepSelector( + type="single_module", + config=dataclasses.asdict( + SingleModuleStepConfig( + builder=ModuleSelector( + type="NoiseConditionedSFNO", + config=dataclasses.asdict( + NoiseConditionedSFNOBuilder( + embed_dim=16, + num_layers=2, + noise_embed_dim=16, + noise_type="isotropic", + ) + ), + ), + in_names=in_names, + out_names=out_names, + normalization=NetworkAndLossNormalizationConfig( + network=NormalizationConfig( + means={n: 0.1 for n in all_names}, + stds={n: 1.1 for n in all_names}, + ), + ), + ocean=None, + ) + ), + ), + ) + + dataset_info = get_dataset_info(img_shape=img_shape) + train_stepper = _get_train_stepper(config, dataset_info, loss=loss) + data = BatchData.new_on_device( + data={ + "a": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device), + "b": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device), + "c": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device), + }, + time=xr.DataArray( + np.zeros((n_samples, n_forward_steps + 1)), + dims=["sample", "time"], + ), + labels=None, + epoch=0, + horizontal_dims=["lat", "lon"], + ) + data = data.scatter_spatial(img_shape) + return train_stepper, data, img_shape + + +def flatten_dict( + d: Mapping[str, Mapping[str, torch.Tensor]], +) -> dict[str, torch.Tensor]: + return_dict = {} + for k, v in d.items(): + for k2, v2 in v.items(): + return_dict[f"{k}.{k2}"] = v2 + return return_dict + + +def _get_train_output_tensor_dict(data: TrainOutput) -> dict[str, torch.Tensor]: + return_dict = {} + for k, v in data.metrics.items(): + return_dict[f"metrics.{k}"] = v + for k, v in data.gen_data.items(): + return_dict[f"gen_data.{k}"] = v + for k, v in data.target_data.items(): + assert v.shape[1] == 1 + return_dict[f"target_data.{k}"] = v + return return_dict + + +def get_train_outputs_tensor_dict( + step_1: TrainOutput, step_2: TrainOutput +) -> dict[str, torch.Tensor]: + return flatten_dict( + { + "step_1": _get_train_output_tensor_dict(step_1), + "step_2": _get_train_output_tensor_dict(step_2), + } + ) + + +@pytest.mark.parallel +def test_stepper_train_on_batch_regression(): + torch.manual_seed(0) + train_stepper, data, img_shape = get_regression_stepper_and_data() + optimization = NullOptimization() + result1 = train_stepper.train_on_batch(data, optimization) + result2 = train_stepper.train_on_batch(data, optimization) + dist = Distributed.get_instance() + for result in [result1, result2]: + result.gen_data = EnsembleTensorDict( + dist.gather_spatial(dict(result.gen_data), img_shape) + ) + result.target_data = EnsembleTensorDict( + dist.gather_spatial(dict(result.target_data), img_shape) + ) + output_dict = get_train_outputs_tensor_dict(result1, result2) + validate_tensor_dict( + output_dict, + os.path.join( + DIR, + "testdata/csfno_stepper_train_on_batch_regression.pt", + ), + atol=1e-4, + rtol=1e-4, + ) + + +@pytest.mark.parallel +def test_stepper_train_on_batch_with_optimization_regression(): + torch.manual_seed(0) + train_stepper, data, img_shape = get_regression_stepper_and_data() + optimization = OptimizationConfig( + optimizer_type="Adam", + lr=0.0001, + ).build(train_stepper.modules, max_epochs=1) + result1 = train_stepper.train_on_batch(data, optimization) + result2 = train_stepper.train_on_batch(data, optimization) + dist = Distributed.get_instance() + for result in [result1, result2]: + result.gen_data = EnsembleTensorDict( + dist.gather_spatial(dict(result.gen_data), img_shape) + ) + result.target_data = EnsembleTensorDict( + dist.gather_spatial(dict(result.target_data), img_shape) + ) + output_dict = get_train_outputs_tensor_dict(result1, result2) + validate_tensor_dict( + output_dict, + os.path.join( + DIR, + "testdata/csfno_stepper_train_on_batch_with_optimization_regression.pt", + ), + atol=1e-2, + rtol=1e-2, + ) + + +@pytest.mark.parallel +def test_stepper_predict_regression(): + torch.manual_seed(0) + train_stepper, data, img_shape = get_regression_stepper_and_data() + stepper = train_stepper._stepper + initial_condition = data.get_start( + prognostic_names=["b"], + n_ic_timesteps=1, + ) + output, next_state = stepper.predict( + initial_condition, data, compute_derived_variables=True + ) + dist = Distributed.get_instance() + output_data = dist.gather_spatial(dict(output.data), img_shape) + next_state_data = dist.gather_spatial( + dict(next_state.as_batch_data().data), img_shape + ) + output_dict = flatten_dict({"output": output_data, "next_state": next_state_data}) + validate_tensor_dict( + output_dict, + os.path.join(DIR, "testdata/csfno_stepper_predict_regression.pt"), + atol=1e-4, + rtol=1e-4, + ) diff --git a/fme/ace/stepper/testdata/csfno_stepper_predict_regression.pt b/fme/ace/stepper/testdata/csfno_stepper_predict_regression.pt new file mode 100644 index 000000000..c1ebe8925 Binary files /dev/null and b/fme/ace/stepper/testdata/csfno_stepper_predict_regression.pt differ diff --git a/fme/ace/stepper/testdata/csfno_stepper_train_on_batch_regression.pt b/fme/ace/stepper/testdata/csfno_stepper_train_on_batch_regression.pt new file mode 100644 index 000000000..8e50478ab Binary files /dev/null and b/fme/ace/stepper/testdata/csfno_stepper_train_on_batch_regression.pt differ diff --git a/fme/ace/stepper/testdata/csfno_stepper_train_on_batch_with_optimization_regression.pt b/fme/ace/stepper/testdata/csfno_stepper_train_on_batch_with_optimization_regression.pt new file mode 100644 index 000000000..143c3c078 Binary files /dev/null and b/fme/ace/stepper/testdata/csfno_stepper_train_on_batch_with_optimization_regression.pt differ diff --git a/fme/core/distributed/model_torch_distributed.py b/fme/core/distributed/model_torch_distributed.py index 731cdf2dc..4b08eb8e3 100644 --- a/fme/core/distributed/model_torch_distributed.py +++ b/fme/core/distributed/model_torch_distributed.py @@ -23,8 +23,10 @@ import torch import torch.distributed +import torch.distributed as pt_dist import torch.nn as nn import torch_harmonics.distributed as thd +from torch.amp import custom_bwd, custom_fwd from torch.nn import SyncBatchNorm from torch.nn.parallel import DistributedDataParallel @@ -42,6 +44,35 @@ T = TypeVar("T") +class _AutogradAllReduce(torch.autograd.Function): + """Autograd-aware all-reduce (sum) for spatial parallelism. + Forward: all-reduce (sum) the input across the given process group. + Backward: identity — gradients pass through without communication. + This makes ``spatial_reduce_sum`` differentiable so that gradients + flow correctly through the loss computation path:: + AreaWeightedMSELoss → area_weighted_mean → weighted_mean + → spatial_reduce_sum (uses this function) + Without this, the raw ``torch.distributed.all_reduce`` would break + the autograd graph because it is an in-place, non-differentiable op. + """ + + @staticmethod + @custom_fwd(device_type="cuda") + def forward( + ctx, + input: torch.Tensor, + group: torch.distributed.ProcessGroup, + ) -> torch.Tensor: + output = input.clone() + torch.distributed.all_reduce(output, group=group) + return output + + @staticmethod + @custom_bwd(device_type="cuda") + def backward(ctx, grad_output: torch.Tensor): + return grad_output.clone(), None + + class ModelTorchDistributed(DistributedBackend): """Distributed backend with spatial model parallelism. @@ -307,23 +338,65 @@ def _device_ids(self) -> list[int] | None: def wrap_module(self, module: torch.nn.Module) -> torch.nn.Module: """Wrap with DDP over the **data** process group. - For now, we assume spatial communication is expected to be handled - inside the model layers themselves. If we need to change course, we - can revisit... + Spatial model parallelism is handled by: + - Forward: communication inside model layers (distributed SHT/iSHT) + - Backward: gradient hooks registered here that all-reduce across + spatial ranks, so every rank sees the global-mean gradient. + + ``broadcast_buffers=False`` is required because the SHT/iSHT layers + store precomputed Legendre polynomial buffers. DDP's default + buffer broadcast modifies these in-place between forward calls, + which breaks autograd's tensor-version tracking. """ if any(p.requires_grad for p in module.parameters()): if using_gpu(): output_device = [self._device_id] else: output_device = None - return DistributedDataParallel( + wrapped = DistributedDataParallel( SyncBatchNorm.convert_sync_batchnorm(module), device_ids=self._device_ids, output_device=output_device, process_group=self._data_group, + broadcast_buffers=False, ) + self._register_spatial_grad_hooks(wrapped) + return wrapped return DummyWrapper(module) + def _register_spatial_grad_hooks(self, module: torch.nn.Module) -> None: + """All-reduce gradients across spatial ranks after each backward. + + Each spatial rank only sees its local slice of the input, so its + gradient is a partial sum. This hook sums those partials so + that every rank applies the same weight update. + + The hook fires via ``register_hook`` on each parameter, which is + invoked with the per-backward gradient tensor before it is + accumulated into ``.grad`` and before DDP's data-parallel + all-reduce. The two reductions commute (orthogonal groups), so + ordering does not matter. + """ + if self._h_size <= 1 and self._w_size <= 1: + return + spatial_group = self._spatial_group + + def _hook(grad: torch.Tensor) -> torch.Tensor: + if grad is None: + return grad + + reduced = grad.contiguous().clone() + torch.distributed.all_reduce(reduced, group=spatial_group) + + # If we want mean gradient instead of sum, we want: + # reduced /= (self._h_size * self._w_size) + + return reduced + + for p in module.parameters(): + if p.requires_grad: + p.register_hook(_hook) + def barrier(self): """Global barrier across all ranks.""" logger.debug("Barrier on rank %d", self._rank) @@ -331,7 +404,7 @@ def barrier(self): def spatial_reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor: if self._h_size > 1 or self._w_size > 1: - torch.distributed.all_reduce(tensor, group=self._spatial_group) + return _AutogradAllReduce.apply(tensor, self._spatial_group) return tensor def weighted_mean( @@ -341,6 +414,7 @@ def weighted_mean( dim: tuple[int, ...], keepdim: bool = False, ) -> torch.Tensor: + from fme.core.metrics import weighted_sum local_weighted_sum = weighted_sum(data, weights, dim=dim, keepdim=keepdim) diff --git a/fme/core/distributed/parallel_tests/test_backward_step.py b/fme/core/distributed/parallel_tests/test_backward_step.py new file mode 100644 index 000000000..66beeff55 --- /dev/null +++ b/fme/core/distributed/parallel_tests/test_backward_step.py @@ -0,0 +1,135 @@ +import pathlib + +import numpy as np +import pytest +import torch + +import fme +from fme.core.distributed.distributed import Distributed +from fme.core.gridded_ops import LatLonOperations +from fme.core.typing_ import TensorDict + +DATA_DIR = pathlib.Path(__file__).parent / "testdata" +BASELINE_FILE = DATA_DIR / "backward_step_baseline.pt" + + +def _run_forward_backward( + img_shape: tuple[int, int], +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Run a single forward + backward step and return: + - loss on this rank + - global gradient + """ + dist = Distributed.get_instance() + + batch_size = 4 + n_channels = 2 + nlat, nlon = img_shape + + # Build 2D weights with correct spatial shape + lat = torch.linspace(-np.pi / 2, np.pi / 2, nlat, device=fme.get_device()) + area_weights_lat = torch.cos(lat).clamp_min(1e-3) # (nlat,) + area_weights_global = area_weights_lat.unsqueeze(-1).repeat(1, nlon) # (nlat, nlon) + + global_ops = LatLonOperations(area_weights=area_weights_global) + + # Global tensors + torch.manual_seed(0) + x_global = torch.randn( + batch_size, n_channels, nlat, nlon, device=fme.get_device(), requires_grad=True + ) + y_global = torch.randn_like(x_global) + + global_inputs: TensorDict = { + "x": x_global, + "y": y_global, + } + local_inputs = dist.scatter_spatial(global_inputs, img_shape=(nlat, nlon)) + + x_local = local_inputs["x"] + y_local = local_inputs["y"] + x_local.retain_grad() + + sht = global_ops.get_real_sht().to(fme.get_device()) + isht = global_ops.get_real_isht().to(fme.get_device()) + # Forward: x -> sht -> isht -> y_pred + y_hat_local = sht(x_local) + y_pred_local = isht(y_hat_local) + + mse = (y_pred_local - y_local) ** 2 + # Global, area-weighted MSE over spatial dims via LatLonOperations + mse_spatial = global_ops.area_weighted_mean(mse) + loss = mse_spatial.mean() + + loss.backward() + # Gather grad_x back to global grid + grad_local = x_local.grad.detach() + grad_global_dict = dist.gather_spatial({"x": grad_local}, img_shape=img_shape) + grad_x_global = grad_global_dict["x"] + + return loss.detach().cpu(), grad_x_global.cpu() + + +@pytest.mark.parametrize("img_shape", [(16, 32)]) +@pytest.mark.parallel +def test_spatial_parallel_backward_step(img_shape): + """ + Test: run forward + backward under + ModelTorchDistributed with spatial parallelism. + + Asserts: + - Loss is same with sp decomp compared with NonDistributed baseline + - Gradient is element-wise same with sp decomp compared with NonDistributed baseline + """ + dist = Distributed.get_instance() + torch.manual_seed(0) + + # Run forwards/backwards + loss, grad = _run_forward_backward(img_shape) + + # Only root does I/O + if not dist.is_root(): + return + + if not BASELINE_FILE.exists(): + # Baseline generation mode: expect non-distributed backend here. + # Save loss and grads for later regression. + BASELINE_FILE.parent.mkdir(parents=True, exist_ok=True) + torch.save( + { + "img_shape": img_shape, + "loss": loss, + "grad": grad, + }, + BASELINE_FILE, + ) + return + + # Regression mode: compare against existing baseline. + baseline = torch.load(BASELINE_FILE, map_location="cpu") + assert tuple(baseline["img_shape"]) == tuple(img_shape) + + baseline_loss = baseline["loss"].item() + baseline_grad = baseline["grad"] + + # 1) Loss finite and close to baseline. + assert torch.isfinite(loss), "Loss is not finite on this rank" + + # Compare loss (scalar) with a small relative tolerance + actual_loss = loss.item() + rel_loss = abs(actual_loss - baseline_loss) / max(abs(baseline_loss), 1e-12) + assert rel_loss < 1e-6, ( + f"Loss deviates from baseline: " + f"actual={actual_loss:.8f}, expected={baseline_loss:.8f}, rel_diff={rel_loss:.3e}" + ) + max_rel = ( + ((grad - baseline_grad).abs() / baseline_grad.abs().clamp_min(1e-12)) + .max() + .item() + ) + assert torch.allclose(grad, baseline_grad, rtol=1e-6, atol=1e-7), ( + f"grad_x differs from baseline: " + f"max_abs={(grad - baseline_grad).abs().max().item():.3e}, " + f"max_rel={max_rel:.3e}" + ) diff --git a/fme/core/distributed/parallel_tests/test_step.py b/fme/core/distributed/parallel_tests/test_step.py index 6800d0b83..a052e2a43 100644 --- a/fme/core/distributed/parallel_tests/test_step.py +++ b/fme/core/distributed/parallel_tests/test_step.py @@ -18,10 +18,18 @@ import numpy as np import pytest import torch +import xarray as xr from torch import nn import fme +from fme.ace.data_loading.batch_data import BatchData from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNOBuilder +from fme.ace.stepper.single_module import ( + StepperConfig, + TrainOutput, + TrainStepper, + TrainStepperConfig, +) from fme.ace.testing.fv3gfs_data import get_scalar_dataset from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig, EnergyBudgetConfig @@ -29,12 +37,20 @@ from fme.core.distributed.distributed import Distributed from fme.core.distributed.non_distributed import DummyWrapper from fme.core.labels import BatchLabels +from fme.core.loss import StepLossConfig from fme.core.normalizer import NetworkAndLossNormalizationConfig, NormalizationConfig +from fme.core.optimization import Optimization, OptimizationConfig, SchedulerConfig from fme.core.registry import ModuleSelector +from fme.core.step import SingleModuleStepConfig, StepSelector from fme.core.step.args import StepArgs from fme.core.step.multi_call import MultiCallConfig, MultiCallStepConfig from fme.core.step.secondary_decoder import SecondaryDecoderConfig -from fme.core.step.single_module import SingleModuleStepConfig + +# from fme.core.step.single_module import ( +# SingleModuleStepConfig, +# TrainOutput, +# TrainStepper, +# ) from fme.core.step.step import StepABC, StepSelector from fme.core.typing_ import TensorDict @@ -43,6 +59,32 @@ DATA_DIR = pathlib.Path(__file__).parent / "testdata" +def get_dataset_info( + img_shape=(5, 5), +) -> DatasetInfo: + horizontal_coordinate = LatLonCoordinates( + lat=torch.zeros(img_shape[-2]), + lon=torch.zeros(img_shape[-1]), + ) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7), bk=torch.arange(7) + ) + return DatasetInfo( + horizontal_coordinates=horizontal_coordinate, + vertical_coordinate=vertical_coordinate, + timestep=TIMESTEP, + ) + + +def _get_train_stepper( + stepper_config: StepperConfig, + dataset_info: DatasetInfo, + **train_config_kwargs, +) -> TrainStepper: + train_config = TrainStepperConfig(**train_config_kwargs) + return train_config.get_train_stepper(stepper_config, dataset_info) + + def get_network_and_loss_normalization_config( names: list[str], dir: pathlib.Path | None = None, @@ -467,3 +509,197 @@ def test_step_regression( output = dist.gather_spatial(output, img_shape) cache_step_output(output, DATA_DIR / f"{case_name}_output.pt") + + +def _run_stepper_backward_with_optimization( + img_shape: tuple[int, int], + n_samples: int, +) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + """ + Single forward + backward through TrainStepper using Optimization. + Returns scalar loss and per-parameter gradients (CPU tensors). + """ + torch.manual_seed(0) + device = fme.get_device() + dist = Distributed.get_instance() + + # Reuse the same config pattern as get_regression_stepper_and_data + in_names = ["a", "b"] + out_names = ["b", "c"] + n_forward_steps = 2 + all_names = list(set(in_names + out_names)) + + loss_cfg = StepLossConfig(type="AreaWeightedMSE") + + stepper_config = StepperConfig( + step=StepSelector( + type="single_module", + config=dataclasses.asdict( + SingleModuleStepConfig( + builder=ModuleSelector( + type="NoiseConditionedSFNO", + config=dataclasses.asdict( + NoiseConditionedSFNOBuilder( + embed_dim=16, + num_layers=2, + noise_embed_dim=16, + noise_type="isotropic", + ) + ), + ), + in_names=in_names, + out_names=out_names, + normalization=NetworkAndLossNormalizationConfig( + network=NormalizationConfig( + means={n: 0.1 for n in all_names}, + stds={n: 1.1 for n in all_names}, + ), + ), + ocean=None, + ) + ), + ), + ) + + dataset_info = get_dataset_info(img_shape=img_shape) + + train_stepper: TrainStepper = _get_train_stepper( + stepper_config, + dataset_info, + loss=loss_cfg, + ) + + # Random data, same as regression helper + data = BatchData.new_on_device( + data={ + "a": torch.randn(n_samples, n_forward_steps + 1, *img_shape, device=device), + "b": torch.randn(n_samples, n_forward_steps + 1, *img_shape, device=device), + "c": torch.randn(n_samples, n_forward_steps + 1, *img_shape, device=device), + }, + time=xr.DataArray( + np.zeros((n_samples, n_forward_steps + 1)), + dims=["sample", "time"], + ), + labels=None, + epoch=0, + horizontal_dims=["lat", "lon"], + ) + data = data.scatter_spatial(img_shape) + + # Build Optimization on train_stepper.modules + opt_config = OptimizationConfig( + optimizer_type="Adam", + lr=1e-3, + use_gradient_accumulation=True, + enable_automatic_mixed_precision=False, + ) + + optimization = opt_config.build(train_stepper.modules, max_epochs=1) + optimization.set_mode(train_stepper.modules) + + # Forward + backward through the *training* API + # train_output: TrainOutput = train_stepper.train_on_batch(data, optimization) + + # --- Manual version of train_on_batch, but WITHOUT step_weights() --- + train_stepper._init_for_epoch(data.epoch) + metrics: dict[str, torch.Tensor] = {} + + input_data = data.get_start(train_stepper._prognostic_names, train_stepper.n_ic_timesteps) + target_data = train_stepper._stepper.get_forward_data( + data, compute_derived_variables=False + ) + data = train_stepper._stepper.forcing_deriver(data) + + optimization.set_mode(train_stepper._stepper.modules) + + output_list = train_stepper._accumulate_loss( + input_data=input_data, + data=data, + target_data=target_data, + optimization=optimization, + metrics=metrics, + ) + + regularizer_loss = train_stepper._stepper.get_regularizer_loss() + if torch.any(regularizer_loss > 0): + optimization.accumulate_loss(regularizer_loss) + + loss = optimization.get_accumulated_loss() + + grads: dict[str, torch.Tensor] = {} + for i, wrapped in enumerate(train_stepper.modules): + module = getattr(wrapped, "module", wrapped) + for name, p in module.named_parameters(): + if p.grad is not None: + grads[f"module_{i}.{name}"] = p.grad.detach().cpu().clone() + + + return loss.detach().cpu(), grads + + +@pytest.mark.parallel +def test_stepper_backward_with_optimization(): + """ + Test compares gradients after backward step with and without spatial parallelism. + + Since each rank holds the entire global model's parameters, there is no need to gather. + This test will need to be modified once spatial sharding is implemented for parameters. + """ + DATA_DIR = pathlib.Path(__file__).parent / "testdata" + BASELINE_FILE = DATA_DIR / "csfno_stepper_backward_with_opt_baseline.pt" + dist = Distributed.get_instance() + torch.manual_seed(0) + + img_shape = (20, 40) + n_samples = 2 + + loss, grads = _run_stepper_backward_with_optimization(img_shape, n_samples) + + if not dist.is_root(): + return + + if not BASELINE_FILE.exists(): + BASELINE_FILE.parent.mkdir(parents=True, exist_ok=True) + torch.save( + { + "img_shape": img_shape, + "n_samples": n_samples, + "loss": loss, + "grads": grads, + }, + BASELINE_FILE, + ) + print("Created Baseline file") + return + + baseline = torch.load(BASELINE_FILE, map_location="cpu") + assert tuple(baseline["img_shape"]) == tuple(img_shape) + assert baseline["n_samples"] == n_samples + + baseline_loss = baseline["loss"] + baseline_grads: dict[str, torch.Tensor] = baseline["grads"] + + # Loss check + assert torch.isfinite(loss), "Loss is not finite on this rank" + actual_loss = loss.item() + expected_loss = baseline_loss.item() + rel_loss = abs(actual_loss - expected_loss) / max(abs(expected_loss), 1e-12) + assert rel_loss < 1e-6, ( + f"Loss deviates from baseline: " + f"actual={actual_loss:.8e}, expected={expected_loss:.8e}, " + f"rel_diff={rel_loss:.3e}" + ) + + # Grad check + assert set(grads.keys()) == set(baseline_grads.keys()) + for name in sorted(grads.keys()): + g = grads[name] + g_ref = baseline_grads[name] + assert g.shape == g_ref.shape, f"Shape mismatch for grad '{name}'" + diff = (g - g_ref).abs() + max_abs = diff.max().item() + max_rel = (diff / g_ref.abs().clamp_min(1e-12)).max().item() + assert torch.allclose(g, g_ref, rtol=1e-6, atol=1e-8), ( + f"Gradient for '{name}' deviates from baseline: " + f"max_abs={max_abs:.3e}, max_rel={max_rel:.3e}" + ) diff --git a/fme/core/distributed/parallel_tests/testdata/backward_step_baseline.pt b/fme/core/distributed/parallel_tests/testdata/backward_step_baseline.pt new file mode 100644 index 000000000..b876c651e Binary files /dev/null and b/fme/core/distributed/parallel_tests/testdata/backward_step_baseline.pt differ diff --git a/scripts/testing/test_spatial.sh b/scripts/testing/test_spatial.sh new file mode 100755 index 000000000..3ada4026c --- /dev/null +++ b/scripts/testing/test_spatial.sh @@ -0,0 +1,32 @@ +#!/usr/bin/env bash +set -euo pipefail +set -x + +H_ARG=${1:-} +W_ARG=${2:-} + +H=${H_ARG:-${FME_DISTRIBUTED_H:-2}} +W=${W_ARG:-${FME_DISTRIBUTED_W:-2}} + +NP=$((H * W)) + +dir=fme/core/distributed/parallel_tests +# dir=fme/ace/stepper/ +tests=test_backward_step.py::test_spatial_parallel_backward_step +# tests=test_step.py::test_stepper_backward_with_optimization +# tests=test_single_module_csfno.py::test_stepper_train_on_batch_with_optimization_regression +pytest_cmd="pytest -s $dir/$tests" +file="testdata/backward_step_baseline.pt" +# file="testdata/csfno_stepper_backward_with_opt_baseline.pt" + +if [ -f "$dir/$file" ]; then + rm "$dir/$file" +fi + +$pytest_cmd + +export FME_DISTRIBUTED_BACKEND=model +export FME_DISTRIBUTED_H=$H +export FME_DISTRIBUTED_W=$W +torchrun --standalone --nnodes=1 --nproc-per-node=$NP -m $pytest_cmd +