Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 245 additions & 0 deletions fme/ace/stepper/test_single_module_csfno.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
"""
Parallel regression tests for the SingleModuleStepper with NoiseConditionedSFNO.

These tests verify that the forward pass and loss computation produce identical
results regardless of spatial decomposition (nproc=1 vs model-parallel).
"""

import dataclasses
import datetime
import os
from collections.abc import Mapping

import numpy as np
import pytest
import torch
import xarray as xr

from fme.ace.data_loading.batch_data import BatchData
from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNOBuilder
from fme.ace.stepper.single_module import (
StepperConfig,
TrainOutput,
TrainStepper,
TrainStepperConfig,
)
from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates
from fme.core.dataset_info import DatasetInfo
from fme.core.device import get_device
from fme.core.distributed.distributed import Distributed
from fme.core.loss import StepLossConfig
from fme.core.normalizer import NetworkAndLossNormalizationConfig, NormalizationConfig
from fme.core.optimization import NullOptimization, OptimizationConfig
from fme.core.registry.module import ModuleSelector
from fme.core.step import SingleModuleStepConfig, StepSelector
from fme.core.testing.regression import validate_tensor_dict
from fme.core.typing_ import EnsembleTensorDict

DIR = os.path.abspath(os.path.dirname(__file__))
TIMESTEP = datetime.timedelta(hours=6)


def get_dataset_info(
    img_shape=(5, 5),
) -> DatasetInfo:
    """Build a minimal DatasetInfo for stepper tests.

    Args:
        img_shape: (n_lat, n_lon) grid shape. Coordinate values are
            all-zero placeholders; only the sizes matter for these tests.

    Returns:
        A DatasetInfo with lat/lon coordinates of the requested shape,
        a 7-level hybrid sigma-pressure vertical coordinate, and the
        module-level 6-hour TIMESTEP.
    """
    horizontal_coordinate = LatLonCoordinates(
        lat=torch.zeros(img_shape[-2]),
        lon=torch.zeros(img_shape[-1]),
    )
    vertical_coordinate = HybridSigmaPressureCoordinate(
        ak=torch.arange(7), bk=torch.arange(7)
    )
    return DatasetInfo(
        horizontal_coordinates=horizontal_coordinate,
        vertical_coordinate=vertical_coordinate,
        timestep=TIMESTEP,
    )


def _get_train_stepper(
    stepper_config: StepperConfig,
    dataset_info: DatasetInfo,
    **train_config_kwargs,
) -> TrainStepper:
    """Build a TrainStepper from a stepper config and TrainStepperConfig kwargs."""
    config = TrainStepperConfig(**train_config_kwargs)
    return config.get_train_stepper(stepper_config, dataset_info)


def get_regression_stepper_and_data() -> (
    tuple[TrainStepper, BatchData, tuple[int, int]]
):
    """Build a small NoiseConditionedSFNO train stepper and random batch data.

    Returns:
        The train stepper, the batch data already scattered across spatial
        ranks, and the full (unscattered) image shape needed to gather
        outputs back together for comparison.
    """
    in_names = ["a", "b"]
    out_names = ["b", "c"]
    n_forward_steps = 2
    n_samples = 3
    img_shape = (9, 18)
    all_names = list(set(in_names + out_names))

    loss = StepLossConfig(type="AreaWeightedMSE")

    # Small CSFNO so the regression test stays fast.
    builder = NoiseConditionedSFNOBuilder(
        embed_dim=16,
        num_layers=2,
        noise_embed_dim=16,
        noise_type="isotropic",
    )
    step_config = SingleModuleStepConfig(
        builder=ModuleSelector(
            type="NoiseConditionedSFNO",
            config=dataclasses.asdict(builder),
        ),
        in_names=in_names,
        out_names=out_names,
        normalization=NetworkAndLossNormalizationConfig(
            network=NormalizationConfig(
                means={n: 0.1 for n in all_names},
                stds={n: 1.1 for n in all_names},
            ),
        ),
        ocean=None,
    )
    config = StepperConfig(
        step=StepSelector(
            type="single_module",
            config=dataclasses.asdict(step_config),
        ),
    )

    dataset_info = get_dataset_info(img_shape=img_shape)
    train_stepper = _get_train_stepper(config, dataset_info, loss=loss)

    device = get_device()
    full_shape = (n_samples, n_forward_steps + 1, *img_shape)
    # Note: randn is called for "a", "b", "c" in this order, which fixes
    # the RNG consumption for the regression comparison.
    data = BatchData.new_on_device(
        data={
            name: torch.randn(*full_shape).to(device) for name in ["a", "b", "c"]
        },
        time=xr.DataArray(
            np.zeros((n_samples, n_forward_steps + 1)),
            dims=["sample", "time"],
        ),
        labels=None,
        epoch=0,
        horizontal_dims=["lat", "lon"],
    )
    return train_stepper, data.scatter_spatial(img_shape), img_shape


def flatten_dict(
    d: Mapping[str, Mapping[str, torch.Tensor]],
) -> dict[str, torch.Tensor]:
    """Flatten a two-level mapping into a single dict with 'outer.inner' keys."""
    return {
        f"{outer}.{inner}": tensor
        for outer, sub in d.items()
        for inner, tensor in sub.items()
    }


def _get_train_output_tensor_dict(data: TrainOutput) -> dict[str, torch.Tensor]:
    """Flatten a TrainOutput's metrics, gen_data, and target_data into one dict."""
    out: dict[str, torch.Tensor] = {}
    for name, value in data.metrics.items():
        out[f"metrics.{name}"] = value
    for name, value in data.gen_data.items():
        out[f"gen_data.{name}"] = value
    for name, value in data.target_data.items():
        # Target data is expected to carry a single ensemble member.
        assert value.shape[1] == 1
        out[f"target_data.{name}"] = value
    return out


def get_train_outputs_tensor_dict(
    step_1: TrainOutput, step_2: TrainOutput
) -> dict[str, torch.Tensor]:
    """Combine two consecutive TrainOutputs into one flat, prefixed tensor dict."""
    per_step = {
        "step_1": _get_train_output_tensor_dict(step_1),
        "step_2": _get_train_output_tensor_dict(step_2),
    }
    return flatten_dict(per_step)


@pytest.mark.parallel
def test_stepper_train_on_batch_regression():
    """Two successive train_on_batch calls must match stored regression data.

    Run under the parallel marker so results can be compared between
    nproc=1 and model-parallel decompositions.
    """
    torch.manual_seed(0)
    stepper, batch, img_shape = get_regression_stepper_and_data()
    no_opt = NullOptimization()
    # Two sequential steps; order matters for RNG and regression data.
    outputs = [stepper.train_on_batch(batch, no_opt) for _ in range(2)]
    dist = Distributed.get_instance()
    for out in outputs:
        # Gather spatially-sharded tensors back to the full grid so the
        # comparison is decomposition-independent.
        out.gen_data = EnsembleTensorDict(
            dist.gather_spatial(dict(out.gen_data), img_shape)
        )
        out.target_data = EnsembleTensorDict(
            dist.gather_spatial(dict(out.target_data), img_shape)
        )
    validate_tensor_dict(
        get_train_outputs_tensor_dict(outputs[0], outputs[1]),
        os.path.join(
            DIR,
            "testdata/csfno_stepper_train_on_batch_regression.pt",
        ),
        atol=1e-4,
        rtol=1e-4,
    )


@pytest.mark.parallel
def test_stepper_train_on_batch_with_optimization_regression():
    """train_on_batch with a live Adam optimizer must match regression data.

    Unlike the NullOptimization variant, the second step sees weights
    updated by the first, exercising the spatial gradient all-reduce path.
    """
    torch.manual_seed(0)
    train_stepper, data, img_shape = get_regression_stepper_and_data()
    optimization = OptimizationConfig(
        optimizer_type="Adam",
        lr=0.0001,
    ).build(train_stepper.modules, max_epochs=1)
    result1 = train_stepper.train_on_batch(data, optimization)
    result2 = train_stepper.train_on_batch(data, optimization)
    dist = Distributed.get_instance()
    for result in [result1, result2]:
        # Gather spatially-sharded tensors back to the full grid so the
        # comparison is decomposition-independent.
        result.gen_data = EnsembleTensorDict(
            dist.gather_spatial(dict(result.gen_data), img_shape)
        )
        result.target_data = EnsembleTensorDict(
            dist.gather_spatial(dict(result.target_data), img_shape)
        )
    output_dict = get_train_outputs_tensor_dict(result1, result2)
    validate_tensor_dict(
        output_dict,
        os.path.join(
            DIR,
            "testdata/csfno_stepper_train_on_batch_with_optimization_regression.pt",
        ),
        # NOTE(review): these tolerances are much looser than the 1e-4 used
        # elsewhere; tighten once the parallel gradient path is trusted.
        atol=1e-2,
        rtol=1e-2,
    )


@pytest.mark.parallel
def test_stepper_predict_regression():
    """predict() outputs must match stored regression data after gathering
    spatial shards back to the full grid."""
    torch.manual_seed(0)
    train_stepper, batch, img_shape = get_regression_stepper_and_data()
    stepper = train_stepper._stepper
    ic = batch.get_start(
        prognostic_names=["b"],
        n_ic_timesteps=1,
    )
    output, next_state = stepper.predict(
        ic, batch, compute_derived_variables=True
    )
    dist = Distributed.get_instance()
    gathered = {
        "output": dist.gather_spatial(dict(output.data), img_shape),
        "next_state": dist.gather_spatial(
            dict(next_state.as_batch_data().data), img_shape
        ),
    }
    validate_tensor_dict(
        flatten_dict(gathered),
        os.path.join(DIR, "testdata/csfno_stepper_predict_regression.pt"),
        atol=1e-4,
        rtol=1e-4,
    )
Binary file not shown.
Binary file not shown.
Binary file not shown.
82 changes: 77 additions & 5 deletions fme/core/distributed/model_torch_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import torch.distributed
import torch.nn as nn
import torch_harmonics.distributed as thd
from torch.amp import custom_bwd, custom_fwd
from torch.nn import SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel

Expand All @@ -42,6 +43,35 @@
T = TypeVar("T")


class _AutogradAllReduce(torch.autograd.Function):
    """Autograd-aware all-reduce (sum) for spatial parallelism.

    Forward: all-reduce (sum) the input across the given process group.
    Backward: identity — gradients pass through without communication.

    This makes ``spatial_reduce_sum`` differentiable so that gradients
    flow correctly through the loss computation path::

        AreaWeightedMSELoss → area_weighted_mean → weighted_mean
        → spatial_reduce_sum (uses this function)

    Without this, the raw ``torch.distributed.all_reduce`` would break
    the autograd graph because it is an in-place, non-differentiable op.

    ``torch.amp.custom_fwd``/``custom_bwd`` are deliberately NOT used
    here: they are CUDA-specific, and this code must also run with
    spatial parallelism on CPU (see review discussion).
    """

    @staticmethod
    def forward(
        ctx,
        input: torch.Tensor,
        group: torch.distributed.ProcessGroup,
    ) -> torch.Tensor:
        # Clone first so the in-place all_reduce does not mutate the
        # caller's tensor, which would corrupt the autograd graph.
        output = input.clone()
        torch.distributed.all_reduce(output, group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Identity backward: d(sum_i x_i)/dx_local = 1 on every rank, so
        # the incoming gradient passes through unchanged (no clone needed
        # — we do not modify it). Second return is None for the
        # non-tensor ``group`` argument.
        return grad_output, None
Comment on lines +66 to +73
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the cloning is probably unnecessary here


Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

identity in backward may become an issue... we may need an all-reduce in backward, but that's kind of what the hook below is doing


class ModelTorchDistributed(DistributedBackend):
"""Distributed backend with spatial model parallelism.

Expand Down Expand Up @@ -307,31 +337,73 @@ def _device_ids(self) -> list[int] | None:
def wrap_module(self, module: torch.nn.Module) -> torch.nn.Module:
"""Wrap with DDP over the **data** process group.

For now, we assume spatial communication is expected to be handled
inside the model layers themselves. If we need to change course, we
can revisit...
Spatial model parallelism is handled by:
- Forward: communication inside model layers (distributed SHT/iSHT)
- Backward: gradient hooks registered here that all-reduce across
spatial ranks, so every rank sees the global-mean gradient.

``broadcast_buffers=False`` is required because the SHT/iSHT layers
store precomputed Legendre polynomial buffers. DDP's default
buffer broadcast modifies these in-place between forward calls,
which breaks autograd's tensor-version tracking.
"""
if any(p.requires_grad for p in module.parameters()):
if using_gpu():
output_device = [self._device_id]
else:
output_device = None
return DistributedDataParallel(
wrapped = DistributedDataParallel(
SyncBatchNorm.convert_sync_batchnorm(module),
device_ids=self._device_ids,
output_device=output_device,
process_group=self._data_group,
broadcast_buffers=False,
)
self._register_spatial_grad_hooks(wrapped)
return wrapped
return DummyWrapper(module)

def _register_spatial_grad_hooks(self, module: torch.nn.Module) -> None:
"""All-reduce gradients across spatial ranks after each backward.

Each spatial rank only sees its local slice of the input, so its
gradient is a partial sum. This hook sums those partials so
that every rank applies the same weight update.

The hook fires via ``register_hook`` on each parameter, which is
invoked with the per-backward gradient tensor before it is
accumulated into ``.grad`` and before DDP's data-parallel
all-reduce. The two reductions commute (orthogonal groups), so
ordering does not matter.
"""
if self._h_size <= 1 and self._w_size <= 1:
return
spatial_group = self._spatial_group

def _hook(grad: torch.Tensor) -> torch.Tensor:
if grad is None:
return grad

reduced = grad.contiguous().clone()
torch.distributed.all_reduce(reduced, group=spatial_group)

# If we want mean gradient instead of sum, we want:
# reduced /= (self._h_size * self._w_size)
Comment on lines +391 to +392
Copy link
Contributor Author

@mahf708 mahf708 Mar 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't quite wrap my head around which one we really need tbh, and I think this is linked with how we do losses here...


return reduced

for p in module.parameters():
if p.requires_grad:
p.register_hook(_hook)

    def barrier(self):
        """Global barrier across all ranks.

        Blocks until every rank in the default (world) process group
        arrives — not just the data or spatial subgroup.
        """
        logger.debug("Barrier on rank %d", self._rank)
        # device_ids pins the barrier to this rank's GPU for the NCCL
        # backend; it is presumably None on CPU backends — see the
        # _device_ids property.
        torch.distributed.barrier(device_ids=self._device_ids)

def spatial_reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor:
if self._h_size > 1 or self._w_size > 1:
torch.distributed.all_reduce(tensor, group=self._spatial_group)
return _AutogradAllReduce.apply(tensor, self._spatial_group)
return tensor

def weighted_mean(
Expand Down