
add backward pass for spatial parallelism #993

Open
mahf708 wants to merge 5 commits into ai2cm:main from E3SM-Project:spatial-parallel-training
Conversation

mahf708 (Contributor) commented Mar 19, 2026

add backward pass for spatial parallelism with three changes:

  • an autograd version of spatial_reduce_sum to handle the loss path
  • hook up the parameters with a grad hook
  • set broadcast_buffers to False, since DDP otherwise corrupts/mutates buffers in place

there might be some assumptions about which types of loss we support so far, and a careful audit will be required. An alternative to this formulation would be explicit and clear spatial awareness in the loss calculation elsewhere ... but that may be invasive?

  • Tests added
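For concreteness, the first change (a differentiable spatial_reduce_sum) might look roughly like the sketch below. This is an illustrative reconstruction, not the PR's exact code: the class and argument names are assumptions, and the custom_fwd/custom_bwd AMP decorators discussed later in the thread are omitted.

```python
import torch
import torch.distributed as dist


class _AllReduceSum(torch.autograd.Function):
    """Differentiable all-reduce sum (illustrative sketch)."""

    @staticmethod
    def forward(ctx, input: torch.Tensor, group=None) -> torch.Tensor:
        # Clone first: an in-place all_reduce on `input` would mutate a
        # tensor that the autograd graph still references.
        output = input.clone()
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(output, group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Identity backward: the forward sum already gives every rank the
        # same value, so the incoming gradient is the same on all ranks.
        return grad_output, None


# Single-process smoke run: the reduce is a no-op without an initialized
# process group, but the op stays differentiable end to end.
x = torch.ones(3, requires_grad=True)
_AllReduceSum.apply(x, None).sum().backward()
```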

Comment on lines +217 to +218
atol=1e-2,
rtol=1e-2,
Contributor Author
these tolerances seem too loose, so I have low confidence in all of this for now

Comment on lines +65 to +72
output = input.clone()
torch.distributed.all_reduce(output, group=group)
return output

@staticmethod
@custom_bwd(device_type="cuda")
def backward(ctx, grad_output: torch.Tensor):
return grad_output.clone(), None
Contributor Author

the cloning is probably unnecessary here

Comment on lines +390 to +391
# If we want mean gradient instead of sum, we want:
# reduced /= (self._h_size * self._w_size)
mahf708 (Contributor Author) Mar 19, 2026

I can't quite wrap my head around which one we really need tbh, and I think this is linked with how we do losses here...
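The sum-vs-mean question comes down to where the loss is normalized. A small numeric illustration (shard counts and values here are made up): a sum-reduce normalized once by the global element count agrees with a mean of per-rank local means only when all shards have the same size, which is why the choice has to be made consistently with the loss code.

```python
import torch

n_ranks = 4
local_count = 8                              # hypothetical elements per rank
local_sums = [torch.tensor(2.0)] * n_ranks   # stand-in per-rank partial sums

# Option 1: sum-reduce the partial sums, normalize once by the global count.
global_mean = sum(local_sums) / (n_ranks * local_count)

# Option 2: each rank takes a local mean, then the reduction divides by the
# rank count (the commented-out `reduced /= ...` in the diff).
mean_of_means = sum(s / local_count for s in local_sums) / n_ranks

# Equal here because every shard has the same size; with uneven shards
# (e.g. latitude bands covering different areas) the two diverge.
```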

@custom_bwd(device_type="cuda")
def backward(ctx, grad_output: torch.Tensor):
return grad_output.clone(), None

Contributor Author

identity in backward may become an issue... we may need an all-reduce in backward, but that's kind of what the hook below is doing

backwards step

Add test that verifies consistency between NonDistribute and
TorchModelDistributed for loss and gradient calculation using simple
SHT/iSHT transforms
if not dist.is_root():
return

if not BASELINE_FILE.exists():
Contributor

Please use the regression helper in fme/core/testing/regression.py for this, to reduce duplication of this kind of regression logic. It should be general-purpose and support this use case.

Contributor

A "bug" here is that this will pass when no baseline exists.

TIMESTEP = datetime.timedelta(hours=6)


def get_dataset_info(
Contributor

This is a lot of duplicated testing logic, do we need a new testing file or can we make the existing test in test_single_module.py parallel-enabled / put this parallel-enabled test in that file to use its helpers?

Contributor Author

That test uses the legacy SFNO --- could we change it to use CSFNO?

Contributor

I'd like to avoid new ways of executing tests, if we need something can we put this in the Makefile, or if it's temporary for this PR can you leave it uncommitted? I think the current make targets support this in two executions (in serial and then in parallel), using the TEST_PATH argument.

mcgibbon (Contributor) left a comment

It seems like the core contribution of this script is that it adds proper gradient reductions to spatial_reduce_sum. It would be nice to see a unit test added that fails before you add those changes and passes after you add those changes. I think we could and should merge a PR with just those changes.

It feels like the large rtol and atol are needed because the test isn't really passing, and that we'll need some finer-grained tests to tell us what part of the code is causing the failure. Splitting the spatial_reduce_sum changes out will let us get them merged without blocking on this, which could take a while and will need a lot more test code to be written/reviewed.

mcgibbon added a commit that referenced this pull request Mar 20, 2026
Add _AutogradAllReduce, a torch.autograd.Function that wraps
all_reduce with an identity backward, replacing the in-place
all_reduce in spatial_reduce_sum that broke the autograd graph.

Add spatial gradient hooks in wrap_module that all-reduce parameter
gradients across spatial ranks after backward, so each rank applies
the same weight update. Also set broadcast_buffers=False in DDP to
prevent corruption of SHT/iSHT Legendre polynomial buffers.

Based on work by mahf708 and peterdschwartz in E3SM-Project/ace PR #993.

Co-Authored-By: mahf708 <naser.mahfouz@pnnl.gov>
Co-Authored-By: peterdschwartz <peterdschwartz83@gmail.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
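The gradient-hook part of the commit message above might be sketched as follows. This is a minimal reconstruction under stated assumptions: the function name and the process-group guard are illustrative, not copied from the branch.

```python
import torch
import torch.distributed as dist


def register_spatial_grad_hooks(module: torch.nn.Module, group=None) -> None:
    """All-reduce each parameter's gradient across spatial ranks after
    backward, so every rank applies the same weight update (sketch)."""

    def hook(grad: torch.Tensor) -> torch.Tensor:
        reduced = grad.clone()
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(reduced, group=group)
        return reduced

    for param in module.parameters():
        if param.requires_grad:
            param.register_hook(hook)


# Single-process usage: the hooks fire, the reduce is a no-op.
model = torch.nn.Linear(4, 2)
register_spatial_grad_hooks(model)
model(torch.ones(1, 4)).sum().backward()
```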
mahf708 (Contributor Author) commented Mar 20, 2026

@mcgibbon note that the large tol is only needed for GPU/CPU inconsistency, but it is ok otherwise (i.e., it is ok if we run everything on cpu or gpu --- that may simplify the hunt or be reassuring? idk)

also, @mcgibbon, thoughts on this potential gotcha: besides the forward ops that we want to be differentiable, etc., I don't think we want this thing to be differentiable, right? If so, we can potentially change this to use a bare all_reduce here. What do you think? I think this is largely ok for forward, but if we ever add backward to the spatial_reduce_sum, it may cause problems?

    def gather_spatial(
        self, data: dict[str, torch.Tensor], img_shape: tuple[int, int]
    ) -> dict[str, torch.Tensor]:
        """Gather local spatial chunks back to global tensors via all-reduce."""
        return {k: self.gather_spatial_tensor(v, img_shape) for k, v in data.items()}

    def gather_spatial_tensor(
        self, tensor: torch.Tensor, img_shape: tuple[int, int]
    ) -> torch.Tensor:
        """Reassemble a spatially-sharded tensor on every rank via all-reduce.

        Args:
            tensor: Local spatial shard.
            img_shape: Global ``(H, W)`` spatial dimensions.
        """
        if img_shape == tensor.shape[-2:]:
            return tensor
        global_shape = (*tensor.shape[:-2], *img_shape)
        slices = self.get_local_slices(img_shape)
        buf = torch.zeros(global_shape, dtype=tensor.dtype, device=tensor.device)
        buf[(..., *slices)] = tensor
        return self.spatial_reduce_sum(buf)

"""

@staticmethod
@custom_fwd(device_type="cuda")
Contributor

Why is custom_fwd needed? Don't you need this on CPU? Is this why CPU and GPU are giving different results?

Contributor Author

No idea, I saw this in the makani/physicsnemo repos and I copied it blindly (testing with it and without didn't really make any difference, or at least I didn't notice it)

Contributor

They don’t have spatial parallelism on cpu, so it wouldn’t cause issues for them, but we should remove it. I did in my branch incorporating this code.

mcgibbon (Contributor):

> @mcgibbon note that the large tol is only needed for GPU/CPU inconsistency, but it is ok otherwise (i.e., it is ok if we run everything on cpu or gpu --- that may simplify the hunt or be reassuring? idk)

We should be able to get pretty close results on both. Just make sure that initializations always happen on CPU, and then get transferred to get_device(), rather than being initialized directly on-device. I marked one location that might be causing this difference with a line comment.
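A sketch of that initialize-on-CPU-then-transfer pattern (the repo's get_device() helper is replaced here by an explicit device argument to keep the example self-contained):

```python
import torch


def init_on_cpu(shape: tuple, device: torch.device) -> torch.Tensor:
    # Draw random weights on CPU so the RNG sequence is identical no matter
    # where the model eventually runs, then transfer to the target device.
    torch.manual_seed(0)
    weights = torch.randn(shape)  # CPU initialization
    return weights.to(device)


a = init_on_cpu((2, 3), torch.device("cpu"))
b = init_on_cpu((2, 3), torch.device("cpu"))
```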

> also, @mcgibbon, thoughts on this potential gotcha: besides the forward ops that we want to be differentiable, etc., I don't think we want this thing to be differentiable, right? If so, we can potentially change this to use a bare all_reduce here. What do you think? I think this is largely ok for forward, but if we ever add backward to the spatial_reduce_sum, it may cause problems?

    def gather_spatial(
        self, data: dict[str, torch.Tensor], img_shape: tuple[int, int]
    ) -> dict[str, torch.Tensor]:
        """Gather local spatial chunks back to global tensors via all-reduce."""
        return {k: self.gather_spatial_tensor(v, img_shape) for k, v in data.items()}

    def gather_spatial_tensor(
        self, tensor: torch.Tensor, img_shape: tuple[int, int]
    ) -> torch.Tensor:
        """Reassemble a spatially-sharded tensor on every rank via all-reduce.

        Args:
            tensor: Local spatial shard.
            img_shape: Global ``(H, W)`` spatial dimensions.
        """
        if img_shape == tensor.shape[-2:]:
            return tensor
        global_shape = (*tensor.shape[:-2], *img_shape)
        slices = self.get_local_slices(img_shape)
        buf = torch.zeros(global_shape, dtype=tensor.dtype, device=tensor.device)
        buf[(..., *slices)] = tensor
        return self.spatial_reduce_sum(buf)

Does your added code make anything slower? I don't really expect so? When we are in non-differentiable places we use a no_grad context to handle not running extra logic. I'd say don't worry about using the differentiable version.
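The no_grad point above can be seen in a two-liner: inside the context no autograd graph is recorded, so a differentiable reduce wrapper costs nothing extra on non-differentiable paths.

```python
import torch

x = torch.ones(3, requires_grad=True)
with torch.no_grad():
    # No graph is recorded here, so even a "differentiable" all-reduce
    # wrapper would add no backward bookkeeping on this path.
    y = (x * 2).sum()
```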

3 participants