Add spatial parallel regression tests and fix differentiable reduce #995
Conversation
Introduces test_regression.py with a RegressionCase base class that defines initialize (data + module), reduce (default: spatial-aware sum), and lr (for SGD). The test does forward -> backward -> SGD step -> forward and compares all outputs against a single-rank baseline, catching gradient bugs that Adam-based tests mask. Includes a linear (Conv2d 1x1) case that currently catches the known gradient hook bug under spatial parallelism. Co-Authored-By: Claude Opus 4.6 <[email protected]>
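The forward -> backward -> SGD step -> forward pattern described above can be sketched as follows. This is an illustrative stand-in, not the PR's code: `run_case` and the inline `reduce_fn` are hypothetical names, and the Conv2d shapes merely mirror the linear case. Plain SGD applies gradients directly, so any per-rank gradient discrepancy shows up in the post-step outputs, whereas Adam's normalization can mask it.

```python
import torch

def run_case(module, data, reduce_fn, lr=0.1):
    """Hypothetical sketch: one SGD step, comparing outputs before/after."""
    opt = torch.optim.SGD(module.parameters(), lr=lr)
    out_before = module(data)
    loss = reduce_fn(out_before)
    loss.backward()           # a gradient bug corrupts this step directly...
    opt.step()                # ...because SGD applies raw gradients,
    opt.zero_grad()           # unlike Adam's normalized updates
    out_after = module(data)
    return out_before.detach(), loss.detach(), out_after.detach()

torch.manual_seed(0)
module = torch.nn.Conv2d(4, 4, kernel_size=1)  # mirrors the linear case
data = torch.randn(1, 4, 8, 16)                # (batch, channels, H, W)
before, loss, after = run_case(module, data, lambda t: t.sum())
```

In the real test each of the three tensors (pre-step output, loss, post-step output) would be compared against a stored single-rank baseline.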
Add _AutogradAllReduce, a torch.autograd.Function that wraps all_reduce with an identity backward, replacing the in-place all_reduce in spatial_reduce_sum that broke the autograd graph. Add spatial gradient hooks in wrap_module that all-reduce parameter gradients across spatial ranks after backward, so each rank applies the same weight update. Also set broadcast_buffers=False in DDP to prevent corruption of SHT/iSHT Legendre polynomial buffers. Based on work by mahf708 and peterdschwartz in E3SM-Project/ace PR #993. Co-Authored-By: mahf708 <[email protected]> Co-Authored-By: peterdschwartz <[email protected]> Co-Authored-By: Claude Opus 4.6 <[email protected]>
```python
@staticmethod
@custom_fwd(device_type="cuda")
```
Is this strictly a GPU test? Does it get skipped on CPU?
Ah yes, I had fixed this locally but not pushed. Now it's pushed.
```python
"""1x1 Conv2d applied channel-wise; verifies basic gradient flow."""

n_channels: int = 4
_img_shape: tuple[int, int] = (8, 16)
```
[nit] can you just have this as img_shape and not have a property?
> torchrun. Generating baselines under the same backend you test against does not validate cross-backend correctness.
[nit] I don't understand this comment to the agent; we don't allow the agent to generate or delete the baseline .pt, right?
We do, or at least I do: if the file is deleted (which you can let it do on a per-command basis) and the test is re-run, the file gets re-generated.
Also, if Claude adds a new test, it needs to generate the baseline pt file.
```python
# If we want mean gradient instead of sum, we want:
# reduced /= (self._h_size * self._w_size)
```
We always want the sum, right? Doesn't each rank get a partial gradient that is already area-weighted, so summing them should be the correct mean?
Not always, right? Some averages aren't weighted... or maybe I misread the code. For the production ones, like AreaWeightedMSE, we are covered.
I kept it in the other PR because I wasn't 100% sure which losses are used and what general case we should support.
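The claim in this thread (that per-rank partial gradients of a globally normalized weighted mean sum to the full-domain gradient) can be checked numerically without any distributed setup. This is an illustrative toy, not code from the PR: `theta`, `w`, and the slice-based "ranks" are hypothetical, and the key assumption is that each rank normalizes by the global weight sum.

```python
import torch

torch.manual_seed(0)
w = torch.rand(8)                            # per-point area weights
data = torch.rand(8)
theta = torch.randn((), requires_grad=True)  # a shared parameter

def loss(sl):
    # Area-weighted MSE over a slice of the domain, normalized by the
    # GLOBAL weight sum, as a spatially decomposed weighted mean is.
    return (w[sl] * (theta * data[sl]) ** 2).sum() / w.sum()

# Full-domain gradient vs. the sum of two "ranks'" partial gradients.
(g_full,) = torch.autograd.grad(loss(slice(None)), theta)
g_sum = sum(torch.autograd.grad(loss(sl), theta)[0]
            for sl in (slice(0, 4), slice(4, 8)))
```

By linearity of differentiation the two agree exactly; losses that normalize by a per-rank (local) weight sum instead would need the mean-style correction discussed above.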
custom_fwd/custom_bwd are for AMP dtype management and not needed here since all_reduce and identity don't require dtype casting. The clone() in backward is also unnecessary since there's no in-place mutation. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Remove _img_shape indirection in _LinearCase dataclass - a plain attribute satisfies the abstract property. Remove the commented-out mean gradient alternative since sum is always correct here (each rank's gradient is already a partial sum of area-weighted values). Co-Authored-By: Claude Opus 4.6 <[email protected]>
Fix spatial parallelism backward pass and add regression test infrastructure.
`spatial_reduce_sum` used an in-place `all_reduce` that broke the autograd graph, preventing gradients from flowing through the loss computation path. Spatial ranks also did not aggregate parameter gradients, so each rank applied different weight updates.
Changes:
- `fme.core.distributed.model_torch_distributed._AutogradAllReduce`: new `torch.autograd.Function` wrapping `all_reduce` with identity backward, making `spatial_reduce_sum` differentiable
- `fme.core.distributed.model_torch_distributed.ModelTorchDistributed.wrap_module`: register per-parameter gradient hooks that all-reduce across spatial ranks; set `broadcast_buffers=False` to protect SHT/iSHT Legendre buffers
- `fme.core.distributed.parallel_tests.test_regression`: new parameterized regression test framework (`RegressionCase` base class) that validates forward → backward → forward correctness across spatial decompositions
- `AGENTS.md`: document parallel test commands, env vars, and baseline workflow