add backward pass for spatial parallelism #993
Conversation
use graph-aware all_reduce inside spatial mean
    atol=1e-2,
    rtol=1e-2,
These tolerances seem too loose, and thus I have low confidence in all of this for now.
    output = input.clone()
    torch.distributed.all_reduce(output, group=group)
    return output

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output: torch.Tensor):
        return grad_output.clone(), None
the cloning is probably unnecessary here
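Pulling those hunks together, the Function under discussion looks roughly like the sketch below. The single-process fallback is my addition so it runs outside a process group, the `custom_fwd`/`custom_bwd` decorators are dropped, and backward returns `grad_output` without the clone, per the comment above:

```python
import torch
import torch.distributed as dist


class _AutogradAllReduce(torch.autograd.Function):
    """All-reduce (sum) in forward, identity in backward (sketch)."""

    @staticmethod
    def forward(ctx, input: torch.Tensor, group=None) -> torch.Tensor:
        output = input.clone()
        # Only run the collective when a process group is active, so this
        # sketch also works in a single process (assumption for the example).
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(output, group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Identity backward; per the review, no clone should be needed here.
        return grad_output, None


def spatial_reduce_sum(x: torch.Tensor, group=None) -> torch.Tensor:
    return _AutogradAllReduce.apply(x, group)
```

Because forward goes through `Function.apply`, the all-reduce stays inside the autograd graph instead of being an opaque in-place mutation.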
    # If we want mean gradient instead of sum, we want:
    # reduced /= (self._h_size * self._w_size)
I can't quite wrap my head around which one we really need tbh, and I think this is linked with how we do losses here...
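One way to make the sum-vs-mean question concrete is a toy single-process check (`h_size`/`w_size` mirror the `self._h_size * self._w_size` in the diff; the shard layout is invented for illustration): when the loss is a global mean and every shard has the same size, summing per-shard mean losses and dividing by the shard count reproduces the global-mean gradient, which is what the commented-out division would do.

```python
import torch

# Toy single-process check: with equal-size shards, dividing the summed
# per-shard mean-loss gradients by the number of shards (h_size * w_size)
# recovers the gradient of the global mean.
h_size, w_size = 2, 2
x = torch.arange(8.0, requires_grad=True)
shards = x.view(h_size * w_size, 2)  # 4 "ranks", 2 elements each

# Global-mean loss: every element gets gradient 1/8.
(g_global,) = torch.autograd.grad(x.mean(), x)

# Sum of per-shard mean losses, divided by the shard count:
loss_local = sum(s.mean() for s in shards) / (h_size * w_size)
(g_local,) = torch.autograd.grad(loss_local, x)
```

If the loss were instead a global *sum*, no division would be needed, which is presumably why this is tangled up with how the losses are defined.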
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output: torch.Tensor):
        return grad_output.clone(), None
identity in backward may become an issue... we may need an all-reduce in backward, but that's kind of what the hook below is doing
backwards step
Add test that verifies consistency between NonDistribute and TorchModelDistributed for loss and gradient calculation using simple SHT/iSHT transforms
    if not dist.is_root():
        return

    if not BASELINE_FILE.exists():
Please use the regression helper in fme/core/testing/regression.py for this, to reduce duplication of this kind of regression logic. It should be general-purpose and support this use case.
A "bug" here is that this will pass when no baseline exists.
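One pattern that avoids the silent pass is to write the baseline and fail on the first run, so the comparison only ever succeeds against a committed baseline. A minimal sketch, assuming a JSON-serializable result (the helper name is hypothetical; the review asks for this to live in `fme/core/testing/regression.py`):

```python
import json
from pathlib import Path


def check_against_baseline(result: dict, baseline_file: Path) -> None:
    """Regression check that does NOT silently pass when no baseline exists.

    Hypothetical helper for illustration: on a missing baseline it writes
    one and fails, so CI cannot go green without a committed baseline.
    """
    if not baseline_file.exists():
        baseline_file.write_text(json.dumps(result))
        raise AssertionError(
            f"No baseline at {baseline_file}; wrote one. Re-run to compare."
        )
    baseline = json.loads(baseline_file.read_text())
    assert baseline == result, f"result {result} != baseline {baseline}"
```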
    TIMESTEP = datetime.timedelta(hours=6)


    def get_dataset_info(
This is a lot of duplicated testing logic, do we need a new testing file or can we make the existing test in test_single_module.py parallel-enabled / put this parallel-enabled test in that file to use its helpers?
That test uses the legacy SFNO --- could we change it to use CSFNO?
I'd like to avoid new ways of executing tests, if we need something can we put this in the Makefile, or if it's temporary for this PR can you leave it uncommitted? I think the current make targets support this in two executions (in serial and then in parallel), using the TEST_PATH argument.
mcgibbon
left a comment
It seems like the core contribution of this script is that it adds proper gradient reductions to spatial_reduce_sum. It would be nice to see a unit test added that fails before you add those changes and passes after you add those changes. I think we could and should merge a PR with just those changes.
It feels like the large rtol and atol are needed because the test isn't really passing, and that we'll need some finer-grained tests to tell us what part of the code is causing the failure. Splitting the spatial_reduce_sum changes out will let us get them merged without blocking on this, which could take a while and will need a lot more test code to be written/reviewed.
Add _AutogradAllReduce, a torch.autograd.Function that wraps all_reduce with an identity backward, replacing the in-place all_reduce in spatial_reduce_sum that broke the autograd graph. Add spatial gradient hooks in wrap_module that all-reduce parameter gradients across spatial ranks after backward, so each rank applies the same weight update. Also set broadcast_buffers=False in DDP to prevent corruption of SHT/iSHT Legendre polynomial buffers.

Based on work by mahf708 and peterdschwartz in E3SM-Project/ace PR #993.

Co-Authored-By: mahf708 <naser.mahfouz@pnnl.gov>
Co-Authored-By: peterdschwartz <peterdschwartz83@gmail.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
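The gradient-hook part of that change can be sketched as follows. The function name, group handling, and the averaging step are assumptions on my part (the commit only says gradients are all-reduced so ranks apply the same update); in a single process the hook is a no-op:

```python
import torch
import torch.distributed as dist


def add_spatial_grad_hooks(
    module: torch.nn.Module, spatial_group=None, spatial_world_size: int = 1
) -> torch.nn.Module:
    """Sketch: all-reduce parameter gradients across spatial ranks after
    backward so every rank applies the same weight update."""

    def _hook(grad: torch.Tensor) -> torch.Tensor:
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(grad, group=spatial_group)
            # Assumed averaging step: sum -> mean across spatial ranks.
            grad = grad / spatial_world_size
        return grad

    for param in module.parameters():
        if param.requires_grad:
            param.register_hook(_hook)
    return module
```

`Tensor.register_hook` fires once the gradient for that parameter is computed, so the reduction happens after backward without touching the model's forward path.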
@mcgibbon note that the large tol is only needed for GPU/CPU inconsistency, but it is ok otherwise (i.e., it is ok if we run everything on cpu or gpu --- that may simplify the hunt or be reassuring? idk)

also, @mcgibbon, thoughts on this potential gotcha: besides the forward ops that we want to be differentiable, etc., I don't think we want this thing to be differentiable, right? If so, we can potentially change this to use a bare all_reduce here. What do you think? I think this is largely ok for forward, but if we ever add backward to the spatial_reduce_sum, it may cause problems?

    def gather_spatial(
        self, data: dict[str, torch.Tensor], img_shape: tuple[int, int]
    ) -> dict[str, torch.Tensor]:
        """Gather local spatial chunks back to global tensors via all-reduce."""
        return {k: self.gather_spatial_tensor(v, img_shape) for k, v in data.items()}

    def gather_spatial_tensor(
        self, tensor: torch.Tensor, img_shape: tuple[int, int]
    ) -> torch.Tensor:
        """Reassemble a spatially-sharded tensor on every rank via all-reduce.

        Args:
            tensor: Local spatial shard.
            img_shape: Global ``(H, W)`` spatial dimensions.
        """
        if img_shape == tensor.shape[-2:]:
            return tensor
        global_shape = (*tensor.shape[:-2], *img_shape)
        slices = self.get_local_slices(img_shape)
        buf = torch.zeros(global_shape, dtype=tensor.dtype, device=tensor.device)
        buf[(..., *slices)] = tensor
        return self.spatial_reduce_sum(buf)
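A toy, single-process illustration of the scatter-into-zeros-then-sum pattern that gather_spatial_tensor relies on (the shard layout is invented for the example; summing the buffers stands in for the all-reduce across ranks):

```python
import torch

# Each "rank" writes its shard into a zero-filled global buffer; summing
# the buffers (the role played by dist.all_reduce) reassembles the field.
H, W = 4, 4
full = torch.arange(H * W, dtype=torch.float32).view(H, W)

# Pretend two ranks each hold half the rows.
shards = [(slice(0, 2), full[0:2]), (slice(2, 4), full[2:4])]

buffers = []
for rows, shard in shards:
    buf = torch.zeros(H, W)
    buf[rows] = shard
    buffers.append(buf)

reassembled = sum(buffers)
```

Because the shards are disjoint, the sum is exact; overlapping halo regions would break this and need masking before the reduce.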
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
Why is custom_fwd needed? Don't you need this on CPU? Is this why CPU and GPU are giving different results?
No idea, I saw this in the makani/physicsnemo repos and I copied it blindly (testing with it and without didn't really make any difference, or at least I didn't notice it)
They don’t have spatial parallelism on cpu, so it wouldn’t cause issues for them, but we should remove it. I did in my branch incorporating this code.
We should be able to get pretty close results on both. Just make sure that initializations always happen on CPU, and then get transferred to get_device(), rather than being initialized directly on-device. I marked one location that might be causing this difference with a line comment.
Does your added code make anything slower? I don't really expect so? When we are in non-differentiable places we use a no_grad context to handle not running extra logic. I'd say don't worry about using the differentiable version.
add backward pass for spatial parallelism with three changes:
there might be some assumptions about what types of loss we are supporting thus far, and careful auditing will be required. An alternative to this formulation would be explicit and clear spatial awareness in the loss calculation elsewhere ... but that may be invasive?