-
Notifications
You must be signed in to change notification settings - Fork 38
add backward pass for spatial parallelism #993
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
4da9aee
08647d8
2f85d7a
bee34f4
85bf4ec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,245 @@ | ||
| """ | ||
| Parallel regression tests for the SingleModuleStepper with NoiseConditionedSFNO. | ||
|
|
||
| These tests verify that the forward pass and loss computation produce identical | ||
| results regardless of spatial decomposition (nproc=1 vs model-parallel). | ||
| """ | ||
|
|
||
| import dataclasses | ||
| import datetime | ||
| import os | ||
| from collections.abc import Mapping | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
| import torch | ||
| import xarray as xr | ||
|
|
||
| from fme.ace.data_loading.batch_data import BatchData | ||
| from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNOBuilder | ||
| from fme.ace.stepper.single_module import ( | ||
| StepperConfig, | ||
| TrainOutput, | ||
| TrainStepper, | ||
| TrainStepperConfig, | ||
| ) | ||
| from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates | ||
| from fme.core.dataset_info import DatasetInfo | ||
| from fme.core.device import get_device | ||
| from fme.core.distributed.distributed import Distributed | ||
| from fme.core.loss import StepLossConfig | ||
| from fme.core.normalizer import NetworkAndLossNormalizationConfig, NormalizationConfig | ||
| from fme.core.optimization import NullOptimization, OptimizationConfig | ||
| from fme.core.registry.module import ModuleSelector | ||
| from fme.core.step import SingleModuleStepConfig, StepSelector | ||
| from fme.core.testing.regression import validate_tensor_dict | ||
| from fme.core.typing_ import EnsembleTensorDict | ||
|
|
||
| DIR = os.path.abspath(os.path.dirname(__file__)) | ||
| TIMESTEP = datetime.timedelta(hours=6) | ||
|
|
||
|
|
||
| def get_dataset_info( | ||
| img_shape=(5, 5), | ||
| ) -> DatasetInfo: | ||
| horizontal_coordinate = LatLonCoordinates( | ||
| lat=torch.zeros(img_shape[-2]), | ||
| lon=torch.zeros(img_shape[-1]), | ||
| ) | ||
| vertical_coordinate = HybridSigmaPressureCoordinate( | ||
| ak=torch.arange(7), bk=torch.arange(7) | ||
| ) | ||
| return DatasetInfo( | ||
| horizontal_coordinates=horizontal_coordinate, | ||
| vertical_coordinate=vertical_coordinate, | ||
| timestep=TIMESTEP, | ||
| ) | ||
|
|
||
|
|
||
| def _get_train_stepper( | ||
| stepper_config: StepperConfig, | ||
| dataset_info: DatasetInfo, | ||
| **train_config_kwargs, | ||
| ) -> TrainStepper: | ||
| train_config = TrainStepperConfig(**train_config_kwargs) | ||
| return train_config.get_train_stepper(stepper_config, dataset_info) | ||
|
|
||
|
|
||
| def get_regression_stepper_and_data() -> ( | ||
| tuple[TrainStepper, BatchData, tuple[int, int]] | ||
| ): | ||
| in_names = ["a", "b"] | ||
| out_names = ["b", "c"] | ||
| n_forward_steps = 2 | ||
| n_samples = 3 | ||
| img_shape = (9, 18) | ||
| device = get_device() | ||
|
|
||
| all_names = list(set(in_names + out_names)) | ||
|
|
||
| loss = StepLossConfig(type="AreaWeightedMSE") | ||
|
|
||
| config = StepperConfig( | ||
| step=StepSelector( | ||
| type="single_module", | ||
| config=dataclasses.asdict( | ||
| SingleModuleStepConfig( | ||
| builder=ModuleSelector( | ||
| type="NoiseConditionedSFNO", | ||
| config=dataclasses.asdict( | ||
| NoiseConditionedSFNOBuilder( | ||
| embed_dim=16, | ||
| num_layers=2, | ||
| noise_embed_dim=16, | ||
| noise_type="isotropic", | ||
| ) | ||
| ), | ||
| ), | ||
| in_names=in_names, | ||
| out_names=out_names, | ||
| normalization=NetworkAndLossNormalizationConfig( | ||
| network=NormalizationConfig( | ||
| means={n: 0.1 for n in all_names}, | ||
| stds={n: 1.1 for n in all_names}, | ||
| ), | ||
| ), | ||
| ocean=None, | ||
| ) | ||
| ), | ||
| ), | ||
| ) | ||
|
|
||
| dataset_info = get_dataset_info(img_shape=img_shape) | ||
| train_stepper = _get_train_stepper(config, dataset_info, loss=loss) | ||
| data = BatchData.new_on_device( | ||
| data={ | ||
| "a": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device), | ||
| "b": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device), | ||
| "c": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device), | ||
| }, | ||
| time=xr.DataArray( | ||
| np.zeros((n_samples, n_forward_steps + 1)), | ||
| dims=["sample", "time"], | ||
| ), | ||
| labels=None, | ||
| epoch=0, | ||
| horizontal_dims=["lat", "lon"], | ||
| ) | ||
| data = data.scatter_spatial(img_shape) | ||
| return train_stepper, data, img_shape | ||
|
|
||
|
|
||
| def flatten_dict( | ||
| d: Mapping[str, Mapping[str, torch.Tensor]], | ||
| ) -> dict[str, torch.Tensor]: | ||
| return_dict = {} | ||
| for k, v in d.items(): | ||
| for k2, v2 in v.items(): | ||
| return_dict[f"{k}.{k2}"] = v2 | ||
| return return_dict | ||
|
|
||
|
|
||
| def _get_train_output_tensor_dict(data: TrainOutput) -> dict[str, torch.Tensor]: | ||
| return_dict = {} | ||
| for k, v in data.metrics.items(): | ||
| return_dict[f"metrics.{k}"] = v | ||
| for k, v in data.gen_data.items(): | ||
| return_dict[f"gen_data.{k}"] = v | ||
| for k, v in data.target_data.items(): | ||
| assert v.shape[1] == 1 | ||
| return_dict[f"target_data.{k}"] = v | ||
| return return_dict | ||
|
|
||
|
|
||
| def get_train_outputs_tensor_dict( | ||
| step_1: TrainOutput, step_2: TrainOutput | ||
| ) -> dict[str, torch.Tensor]: | ||
| return flatten_dict( | ||
| { | ||
| "step_1": _get_train_output_tensor_dict(step_1), | ||
| "step_2": _get_train_output_tensor_dict(step_2), | ||
| } | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parallel | ||
| def test_stepper_train_on_batch_regression(): | ||
| torch.manual_seed(0) | ||
| train_stepper, data, img_shape = get_regression_stepper_and_data() | ||
| optimization = NullOptimization() | ||
| result1 = train_stepper.train_on_batch(data, optimization) | ||
| result2 = train_stepper.train_on_batch(data, optimization) | ||
| dist = Distributed.get_instance() | ||
| for result in [result1, result2]: | ||
| result.gen_data = EnsembleTensorDict( | ||
| dist.gather_spatial(dict(result.gen_data), img_shape) | ||
| ) | ||
| result.target_data = EnsembleTensorDict( | ||
| dist.gather_spatial(dict(result.target_data), img_shape) | ||
| ) | ||
| output_dict = get_train_outputs_tensor_dict(result1, result2) | ||
| validate_tensor_dict( | ||
| output_dict, | ||
| os.path.join( | ||
| DIR, | ||
| "testdata/csfno_stepper_train_on_batch_regression.pt", | ||
| ), | ||
| atol=1e-4, | ||
| rtol=1e-4, | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parallel | ||
| def test_stepper_train_on_batch_with_optimization_regression(): | ||
| torch.manual_seed(0) | ||
| train_stepper, data, img_shape = get_regression_stepper_and_data() | ||
| optimization = OptimizationConfig( | ||
| optimizer_type="Adam", | ||
| lr=0.0001, | ||
| ).build(train_stepper.modules, max_epochs=1) | ||
| result1 = train_stepper.train_on_batch(data, optimization) | ||
| result2 = train_stepper.train_on_batch(data, optimization) | ||
| dist = Distributed.get_instance() | ||
| for result in [result1, result2]: | ||
| result.gen_data = EnsembleTensorDict( | ||
| dist.gather_spatial(dict(result.gen_data), img_shape) | ||
| ) | ||
| result.target_data = EnsembleTensorDict( | ||
| dist.gather_spatial(dict(result.target_data), img_shape) | ||
| ) | ||
| output_dict = get_train_outputs_tensor_dict(result1, result2) | ||
| validate_tensor_dict( | ||
| output_dict, | ||
| os.path.join( | ||
| DIR, | ||
| "testdata/csfno_stepper_train_on_batch_with_optimization_regression.pt", | ||
| ), | ||
| atol=1e-2, | ||
| rtol=1e-2, | ||
|
Comment on lines
+217
to
+218
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Seems too low, and thus I have low confidence in all of this for now |
||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parallel | ||
| def test_stepper_predict_regression(): | ||
| torch.manual_seed(0) | ||
| train_stepper, data, img_shape = get_regression_stepper_and_data() | ||
| stepper = train_stepper._stepper | ||
| initial_condition = data.get_start( | ||
| prognostic_names=["b"], | ||
| n_ic_timesteps=1, | ||
| ) | ||
| output, next_state = stepper.predict( | ||
| initial_condition, data, compute_derived_variables=True | ||
| ) | ||
| dist = Distributed.get_instance() | ||
| output_data = dist.gather_spatial(dict(output.data), img_shape) | ||
| next_state_data = dist.gather_spatial( | ||
| dict(next_state.as_batch_data().data), img_shape | ||
| ) | ||
| output_dict = flatten_dict({"output": output_data, "next_state": next_state_data}) | ||
| validate_tensor_dict( | ||
| output_dict, | ||
| os.path.join(DIR, "testdata/csfno_stepper_predict_regression.pt"), | ||
| atol=1e-4, | ||
| rtol=1e-4, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,7 @@ | |
| import torch.distributed | ||
| import torch.nn as nn | ||
| import torch_harmonics.distributed as thd | ||
| from torch.amp import custom_bwd, custom_fwd | ||
| from torch.nn import SyncBatchNorm | ||
| from torch.nn.parallel import DistributedDataParallel | ||
|
|
||
|
|
@@ -42,6 +43,35 @@ | |
| T = TypeVar("T") | ||
|
|
||
|
|
||
| class _AutogradAllReduce(torch.autograd.Function): | ||
| """Autograd-aware all-reduce (sum) for spatial parallelism. | ||
| Forward: all-reduce (sum) the input across the given process group. | ||
| Backward: identity — gradients pass through without communication. | ||
| This makes ``spatial_reduce_sum`` differentiable so that gradients | ||
| flow correctly through the loss computation path:: | ||
| AreaWeightedMSELoss → area_weighted_mean → weighted_mean | ||
| → spatial_reduce_sum (uses this function) | ||
| Without this, the raw ``torch.distributed.all_reduce`` would break | ||
| the autograd graph because it is an in-place, non-differentiable op. | ||
| """ | ||
|
|
||
| @staticmethod | ||
| @custom_fwd(device_type="cuda") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is custom_fwd needed? Don't you need this on CPU? Is this why CPU and GPU are giving different results?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No idea, I saw this in the makani/physicsnemo repos and I copied it blindly (testing with and without it didn't really make any difference, or at least I didn't notice one)
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. They don’t have spatial parallelism on CPU, so it wouldn’t cause issues for them, but we should remove it. I did in my branch incorporating this code. |
||
| def forward( | ||
| ctx, | ||
| input: torch.Tensor, | ||
| group: torch.distributed.ProcessGroup, | ||
| ) -> torch.Tensor: | ||
| output = input.clone() | ||
| torch.distributed.all_reduce(output, group=group) | ||
| return output | ||
|
|
||
| @staticmethod | ||
| @custom_bwd(device_type="cuda") | ||
| def backward(ctx, grad_output: torch.Tensor): | ||
| return grad_output.clone(), None | ||
|
Comment on lines
+66
to
+73
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The cloning is probably unnecessary here |
||
|
|
||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The identity in backward may become an issue... we may need an all-reduce in backward, but that's kind of what the hook below is doing |
||
|
|
||
| class ModelTorchDistributed(DistributedBackend): | ||
| """Distributed backend with spatial model parallelism. | ||
|
|
||
|
|
@@ -307,31 +337,73 @@ def _device_ids(self) -> list[int] | None: | |
| def wrap_module(self, module: torch.nn.Module) -> torch.nn.Module: | ||
| """Wrap with DDP over the **data** process group. | ||
|
|
||
| For now, we assume spatial communication is expected to be handled | ||
| inside the model layers themselves. If we need to change course, we | ||
| can revisit... | ||
| Spatial model parallelism is handled by: | ||
| - Forward: communication inside model layers (distributed SHT/iSHT) | ||
| - Backward: gradient hooks registered here that all-reduce across | ||
| spatial ranks, so every rank sees the global-mean gradient. | ||
|
|
||
| ``broadcast_buffers=False`` is required because the SHT/iSHT layers | ||
| store precomputed Legendre polynomial buffers. DDP's default | ||
| buffer broadcast modifies these in-place between forward calls, | ||
| which breaks autograd's tensor-version tracking. | ||
| """ | ||
| if any(p.requires_grad for p in module.parameters()): | ||
| if using_gpu(): | ||
| output_device = [self._device_id] | ||
| else: | ||
| output_device = None | ||
| return DistributedDataParallel( | ||
| wrapped = DistributedDataParallel( | ||
| SyncBatchNorm.convert_sync_batchnorm(module), | ||
| device_ids=self._device_ids, | ||
| output_device=output_device, | ||
| process_group=self._data_group, | ||
| broadcast_buffers=False, | ||
| ) | ||
| self._register_spatial_grad_hooks(wrapped) | ||
| return wrapped | ||
| return DummyWrapper(module) | ||
|
|
||
| def _register_spatial_grad_hooks(self, module: torch.nn.Module) -> None: | ||
| """All-reduce gradients across spatial ranks after each backward. | ||
|
|
||
| Each spatial rank only sees its local slice of the input, so its | ||
| gradient is a partial sum. This hook sums those partials so | ||
| that every rank applies the same weight update. | ||
|
|
||
| The hook fires via ``register_hook`` on each parameter, which is | ||
| invoked with the per-backward gradient tensor before it is | ||
| accumulated into ``.grad`` and before DDP's data-parallel | ||
| all-reduce. The two reductions commute (orthogonal groups), so | ||
| ordering does not matter. | ||
| """ | ||
| if self._h_size <= 1 and self._w_size <= 1: | ||
| return | ||
| spatial_group = self._spatial_group | ||
|
|
||
| def _hook(grad: torch.Tensor) -> torch.Tensor: | ||
| if grad is None: | ||
| return grad | ||
|
|
||
| reduced = grad.contiguous().clone() | ||
| torch.distributed.all_reduce(reduced, group=spatial_group) | ||
|
|
||
| # If we want mean gradient instead of sum, we want: | ||
| # reduced /= (self._h_size * self._w_size) | ||
|
Comment on lines
+391
to
+392
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can't quite wrap my head around which one we really need tbh, and I think this is linked with how we do losses here... |
||
|
|
||
| return reduced | ||
|
|
||
| for p in module.parameters(): | ||
| if p.requires_grad: | ||
| p.register_hook(_hook) | ||
|
|
||
| def barrier(self): | ||
| """Global barrier across all ranks.""" | ||
| logger.debug("Barrier on rank %d", self._rank) | ||
| torch.distributed.barrier(device_ids=self._device_ids) | ||
|
|
||
| def spatial_reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor: | ||
| if self._h_size > 1 or self._w_size > 1: | ||
| torch.distributed.all_reduce(tensor, group=self._spatial_group) | ||
| return _AutogradAllReduce.apply(tensor, self._spatial_group) | ||
| return tensor | ||
|
|
||
| def weighted_mean( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a lot of duplicated testing logic, do we need a new testing file or can we make the existing test in test_single_module.py parallel-enabled / put this parallel-enabled test in that file to use its helpers?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That test uses the legacy SFNO --- could we change it to use CSFNO?