
Add spatial parallel regression tests and fix differentiable reduce#995

Merged
mcgibbon merged 8 commits into main from feature/parallel-regression-coverage
Mar 20, 2026
Conversation


@mcgibbon mcgibbon commented Mar 20, 2026

Fix spatial parallelism backward pass and add regression test infrastructure.

spatial_reduce_sum used an in-place all_reduce that broke the autograd graph,
preventing gradients from flowing through the loss computation path. Spatial ranks
also did not aggregate parameter gradients, so each rank applied different weight
updates.

Changes:

  • fme.core.distributed.model_torch_distributed._AutogradAllReduce: new
    torch.autograd.Function wrapping all_reduce with identity backward, making
    spatial_reduce_sum differentiable
  • fme.core.distributed.model_torch_distributed.ModelTorchDistributed.wrap_module:
    register per-parameter gradient hooks that all-reduce across spatial ranks;
    set broadcast_buffers=False to protect SHT/iSHT Legendre buffers
  • fme.core.distributed.parallel_tests.test_regression: new parameterized
    regression test framework (RegressionCase base class) that validates
forward → backward → SGD step → forward correctness across spatial decompositions
  • AGENTS.md: document parallel test commands, env vars, and baseline workflow

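The core of the fix is replacing the in-place all_reduce with a custom autograd Function whose backward is the identity. A minimal single-process sketch of the pattern (names here are illustrative, not the actual _AutogradAllReduce code; the gloo world_size=1 setup just demonstrates that gradients flow through the reduce):

```python
import os
import torch
import torch.distributed as dist

class _AllReduceSum(torch.autograd.Function):
    """Sketch of a differentiable all-reduce: forward sums the tensor
    across ranks; backward passes the incoming gradient through unchanged
    (identity), keeping the autograd graph intact."""

    @staticmethod
    def forward(ctx, x):
        # Reduce a clone so the input tensor is never mutated in place,
        # which is what broke the autograd graph in the original code.
        y = x.clone()
        dist.all_reduce(y, op=dist.ReduceOp.SUM)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

# Single-process demo (world_size=1, so the all-reduce is a no-op copy;
# the point is that backward() still reaches the input).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.ones(3, requires_grad=True)
y = _AllReduceSum.apply(x)
y.sum().backward()
grad = x.grad.clone()  # all ones: gradient flowed through the reduce
dist.destroy_process_group()
```

With the original in-place all_reduce, `x.grad` would never be populated because the reduce mutated a tensor autograd was tracking.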
mcgibbon and others added 3 commits March 20, 2026 16:26
Introduces test_regression.py with a RegressionCase base class that
defines initialize (data + module), reduce (default: spatial-aware sum),
and lr (for SGD). The test does forward -> backward -> SGD step ->
forward and compares all outputs against a single-rank baseline,
catching gradient bugs that Adam-based tests mask.

Includes a linear (Conv2d 1x1) case that currently catches the known
gradient hook bug under spatial parallelism.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
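The forward → backward → SGD step → forward cycle described above can be sketched single-rank as follows (function and variable names are illustrative, not the actual RegressionCase API; the real test compares every intermediate output against a stored baseline rather than just checking loss decrease):

```python
import torch

def run_regression_steps(module, x, target, lr=1e-4):
    """Illustrative sketch: run the same forward/backward/step/forward
    sequence the regression test exercises, returning outputs before
    and after the single SGD update."""
    opt = torch.optim.SGD(module.parameters(), lr=lr)
    out1 = module(x)
    loss = ((out1 - target) ** 2).sum()
    loss.backward()          # a broken reduce would leave grads None here
    opt.step()               # rank-inconsistent grads would diverge here
    opt.zero_grad()
    with torch.no_grad():
        out2 = module(x)     # second forward reflects the weight update
    return out1.detach(), out2

torch.manual_seed(0)
m = torch.nn.Conv2d(4, 4, kernel_size=1)  # the linear 1x1 case
x = torch.randn(1, 4, 8, 16)
t = torch.zeros_like(x)
before, after = run_regression_steps(m, x, t)
err_before = ((before - t) ** 2).sum().item()
err_after = ((after - t) ** 2).sum().item()
```

Using plain SGD matters: Adam's per-parameter normalization can mask small gradient discrepancies between ranks that a raw SGD step exposes immediately.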
Add _AutogradAllReduce, a torch.autograd.Function that wraps
all_reduce with an identity backward, replacing the in-place
all_reduce in spatial_reduce_sum that broke the autograd graph.

Add spatial gradient hooks in wrap_module that all-reduce parameter
gradients across spatial ranks after backward, so each rank applies
the same weight update. Also set broadcast_buffers=False in DDP to
prevent corruption of SHT/iSHT Legendre polynomial buffers.

Based on work by mahf708 and peterdschwartz in E3SM-Project/ace PR #993.

Co-Authored-By: mahf708 <[email protected]>
Co-Authored-By: peterdschwartz <[email protected]>
Co-Authored-By: Claude Opus 4.6 <[email protected]>
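The per-parameter gradient hooks can be sketched as below (a hypothetical helper, not the actual wrap_module code; world_size=1 gloo is used only so the demo runs in one process):

```python
import os
import torch
import torch.distributed as dist

def add_spatial_grad_hooks(module: torch.nn.Module, group=None):
    """Hypothetical sketch: after backward computes each parameter's
    local gradient, all-reduce (sum) it across the spatial process
    group so every rank applies the same weight update."""
    def _hook(grad):
        reduced = grad.clone()
        dist.all_reduce(reduced, op=dist.ReduceOp.SUM, group=group)
        return reduced
    for p in module.parameters():
        if p.requires_grad:
            p.register_hook(_hook)

# Single-process demo: the hook fires during backward; with one rank
# the reduce leaves the gradient numerically unchanged.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29512")
dist.init_process_group("gloo", rank=0, world_size=1)

lin = torch.nn.Linear(2, 1, bias=False)
add_spatial_grad_hooks(lin)
lin(torch.ones(1, 2)).sum().backward()
grad = lin.weight.grad.clone()  # d(sum(Wx))/dW = x = [[1, 1]]
dist.destroy_process_group()
```

Summing (not averaging) is correct here because each spatial rank holds only a partial sum of the already area-weighted gradient. The related DDP change, broadcast_buffers=False, prevents DDP from overwriting each rank's SHT/iSHT Legendre buffers with rank 0's copy.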
@climate-ci-github climate-ci-github changed the title Feature/parallel regression coverage Add spatial parallel regression tests and fix differentiable reduce Mar 20, 2026
@mcgibbon mcgibbon marked this pull request as ready for review March 20, 2026 17:17
"""

@staticmethod
@custom_fwd(device_type="cuda")
Contributor

Is this strictly a GPU test? Does it get skipped on CPU?

Contributor Author

Ah yes I had fixed this locally but not pushed. Now it's pushed.

"""1x1 Conv2d applied channel-wise; verifies basic gradient flow."""

n_channels: int = 4
_img_shape: tuple[int, int] = (8, 16)
Contributor

[nit] can you just have this as img_shape and not have a property?

Contributor Author

Done.

Comment on lines +45 to +46
torchrun. Generating baselines under the same backend you test against does
not validate cross-backend correctness.
Contributor

[nit] I don't understand this comment to the agent, we don't allow the agent to generate or delete the baseline .pt right?

Contributor Author

We do, or at least I do - if the file is deleted (which you can let it do on a per-command basis) and the test is re-run, the file gets re-generated.

Contributor Author

Also, if Claude adds a new test, it needs to generate the baseline pt file.

Comment on lines +388 to +389
# If we want mean gradient instead of sum, we want:
# reduced /= (self._h_size * self._w_size)
Contributor

We always want the sum, right? Doesn't each rank get a partial gradient that is already area weighted, so summing them gives the correct mean?

Contributor Author

Yes, deleted.

Contributor

Not always, right? Some averages aren't weighted... or maybe I misread the code. For the production ones, like AreaWeightedMSE, we are covered

Contributor

I kept it in the other PR because I wasn't 100% sure which losses are actually used and what general case we should support.

mcgibbon and others added 2 commits March 20, 2026 19:10
custom_fwd/custom_bwd are for AMP dtype management and not needed
here since all_reduce and identity don't require dtype casting.
The clone() in backward is also unnecessary since there's no
in-place mutation.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Remove _img_shape indirection in _LinearCase dataclass - a plain
attribute satisfies the abstract property. Remove the commented-out
mean gradient alternative since sum is always correct here (each
rank's gradient is already a partial sum of area-weighted values).

Co-Authored-By: Claude Opus 4.6 <[email protected]>
@mcgibbon mcgibbon enabled auto-merge (squash) March 20, 2026 20:11
@mcgibbon mcgibbon merged commit 7943fe4 into main Mar 20, 2026
7 checks passed
@mcgibbon mcgibbon deleted the feature/parallel-regression-coverage branch March 20, 2026 20:51