From 059de3c6a4cb2510bf46b220932e3a12c93c6a6b Mon Sep 17 00:00:00 2001 From: misko Date: Fri, 24 Oct 2025 22:44:29 +0000 Subject: [PATCH 01/13] update gputils --- src/fairchem/core/common/gp_utils.py | 330 ++++++++++++--------------- tests/core/common/test_gp_utils.py | 135 ++++------- 2 files changed, 182 insertions(+), 283 deletions(-) diff --git a/src/fairchem/core/common/gp_utils.py b/src/fairchem/core/common/gp_utils.py index 48d2226cb8..3a3652cf69 100644 --- a/src/fairchem/core/common/gp_utils.py +++ b/src/fairchem/core/common/gp_utils.py @@ -9,12 +9,11 @@ import contextlib import logging -from typing import Any +import threading -import numpy as np import torch from torch import distributed as dist -from torch.distributed.nn.functional import all_reduce +from torch.distributed.nn.functional import all_reduce, reduce_scatter """ Functions to support graph parallel training. @@ -27,6 +26,32 @@ _GRAPH_PARALLEL_GROUP = None _DATA_PARALLEL_GROUP = None +_tls = threading.local() + + +def pad_input(input: torch.Tensor, padded_size: int): + # pad using functional + # if input.shape[0]!=padded_size: + # input=torch.nn.functional.pad(input,(0,0,0,0,0,1)).contiguous() + + # pad using manual tensor cat + if input.shape[0] != padded_size: + input = torch.cat( + [ + input, + torch.zeros( + (padded_size - input.shape[0], *input.shape[1:]), + device=input.device, + dtype=input.dtype, + ), + ], + dim=0, + ) + + assert input.shape[0] == padded_size + + return input + def ensure_div(a: int, b: int) -> None: assert a % b == 0 @@ -149,190 +174,153 @@ def get_gp_world_size() -> int: ########## DIST METHODS ########## -def pad_tensor( - tensor: torch.Tensor, dim: int = -1, target_size: int | None = None -) -> torch.Tensor: - size = tensor.size(dim) - if target_size is None: - world_size = get_gp_world_size() - pad_size = 0 if size % world_size == 0 else world_size - size % world_size - else: - pad_size = target_size - size - if pad_size == 0: - return tensor - pad_shape = list(tensor.shape) - pad_shape[dim] = pad_size - padding = torch.empty(pad_shape, device=tensor.device, dtype=tensor.dtype) - return torch.cat([tensor, padding], dim=dim) - - -def trim_tensor(tensor: torch.Tensor, sizes: torch.Tensor | None = None, dim: int = 0): - size = tensor.size(dim) - world_size = get_gp_world_size() - if size % world_size == 0: - return tensor, sizes - trim_size = size - size % world_size - if dim == 0: - tensor = tensor[:trim_size] - elif dim == 1: - tensor = tensor[:, :trim_size] - else: - raise ValueError - if sizes is not None: - sizes[-1] = sizes[-1] - size % world_size - return tensor, sizes - - -def _tensor_to_split_partitions(tensor: torch.Tensor, dim: int = -1): - group = get_gp_group() - num_parts = dist.get_world_size(group=group) - return [len(part) for part in np.array_split(np.zeros(tensor.size(dim)), num_parts)] - - -def _split_tensor( - tensor: torch.Tensor, - dim: int = -1, - contiguous_chunks: bool = False, -): - tensor_list = torch.split(tensor, _tensor_to_split_partitions(tensor, dim), dim=dim) - if contiguous_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - return tensor_list - - -def _reduce(ctx: Any, input: torch.Tensor) -> torch.Tensor: - group = get_gp_group() - if ctx: - ctx.mark_dirty(input) - dist.all_reduce(input, group=group) - return input - - -def _split(input: torch.Tensor, dim: int = -1) -> torch.Tensor: - rank = get_gp_rank() - input_list = _split_tensor(input, dim=dim) - return input_list[rank].clone().contiguous() - - -def _gather(input: torch.Tensor, dim: 
int = -1) -> torch.Tensor:
-    group = get_gp_group()
-    rank = get_gp_rank()
-    world_size = dist.get_world_size(group=group)
-    if world_size == 1:
-        return input
-    tensor_list = [torch.empty_like(input) for _ in range(world_size)]
-    tensor_list[rank] = input
-    dist.all_gather(tensor_list, input, group=group)
-    return torch.cat(tensor_list, dim=dim).contiguous()
-
-
-def _gather_with_padding(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
-    group = get_gp_group()
-    rank = get_gp_rank()
-    world_size = dist.get_world_size(group=group)
-    if world_size == 1:
-        return input
-
-    # Gather sizes
-    size_list = [
-        torch.empty(1, device=input.device, dtype=torch.long) for _ in range(world_size)
-    ]
-    size = torch.tensor([input.size(dim)], device=input.device, dtype=torch.long)
-    size_list[rank] = size
-    dist.all_gather(size_list, size, group=group)
-
-    # Gather the inputs
-    max_size = int(max([size.item() for size in size_list]))
-    input = pad_tensor(input, dim, max_size)
-    shape = list(input.shape)
-    shape[dim] = max_size
-    tensor_list = [
-        torch.empty(shape, device=input.device, dtype=input.dtype)
-        for _ in range(world_size)
-    ]
-
-    dist.all_gather(tensor_list, input, group=group)
-    tensor_list[rank] = input  # pop back in our local copy (requires grad)
-
-    # Trim and cat
-    return torch.cat(
-        [tensor.narrow(dim, 0, size) for tensor, size in zip(tensor_list, size_list)],
-        dim=dim,
-    ).contiguous()
-
+def size_list_fn(size, parts):
+    return [size // parts + (1 if idx < size % parts else 0) for idx in range(parts)]

-class CopyToModelParallelRegion(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
-        return input

-    @staticmethod
-    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
-        return _reduce(None, grad_output)
+def reduce_from_model_parallel_region(input: torch.Tensor) -> torch.Tensor:
+    assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!"
+    return ReduceFromModelParallelRegion.apply(input)


 class ReduceFromModelParallelRegion(torch.autograd.Function):
     @staticmethod
+    @torch.compiler.disable
     def forward(ctx, input: torch.Tensor) -> torch.Tensor:
         # return _reduce(ctx, input) # this operates in place
         return all_reduce(input, group=get_gp_group())  # this operates out of place

     @staticmethod
+    @torch.compiler.disable
     def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
         return grad_output


+def scatter_to_model_parallel_region(input: torch.Tensor) -> torch.Tensor:
+    assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!"
+ return ScatterToModelParallelRegion.apply(input) + + # this returns the values in place class ScatterToModelParallelRegion(torch.autograd.Function): @staticmethod + @torch.compiler.disable def forward(ctx, input: torch.Tensor, dim: int = -1) -> torch.Tensor: - result = _split(input, dim) - ctx.save_for_backward(torch.tensor(dim)) - return result + ctx.split_sizes = size_list_fn(input.shape[0], get_gp_world_size()) + return input.split(ctx.split_sizes)[get_gp_rank()] @staticmethod + @torch.compiler.disable def backward(ctx, grad_output: torch.Tensor): - (dim,) = ctx.saved_tensors - return _gather_with_padding(grad_output.clone(), dim.item()), None + return gather_from_model_parallel_region_sum_grad( + grad_output, sum(ctx.split_sizes), False + ) -class GatherFromModelParallelRegion(torch.autograd.Function): +def gather_from_model_parallel_region_sum_grad( + input: torch.Tensor, + natoms: int, + gloo_backend: bool, +) -> torch.Tensor: + assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!" + world_size = get_gp_world_size() + size_list = size_list_fn(natoms, world_size) + + input = pad_input( + input, natoms // world_size + (1 if natoms % world_size != 0 else 0) + ) + + if gloo_backend: + tensor_list_w_padding = GatherFromModelParallelRegionSumGradPaddedGLOO.apply( + input, False + ) + else: + # tensor_list_w_padding = all_gather(input, group=get_gp_group()) + tensor_list_w_padding = GatherFromModelParallelRegionSumGradPadded.apply( + input, False + ) + + return torch.cat( + [ + t.narrow(0, 0, s) if t.shape[0] != s else t + for t, s in zip(tensor_list_w_padding, size_list) + ], + dim=0, + ) + + +def gather_from_model_parallel_region_sum_grad_async( + input: torch.Tensor, natoms: int +) -> torch.Tensor: + assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!" 
+ world_size = get_gp_world_size() + + input = pad_input( + input, natoms // world_size + (1 if natoms % world_size != 0 else 0) + ) + + tensor_list_w_padding = GatherFromModelParallelRegionSumGradPadded.apply( + input, True + ) + handle = _tls.handle + return tensor_list_w_padding, handle + + +class GatherFromModelParallelRegionSumGradPadded(torch.autograd.Function): @staticmethod - def forward(ctx, input: torch.Tensor, dim: int = -1) -> torch.Tensor: - ctx.save_for_backward(torch.tensor(dim)) - return _gather_with_padding(input, dim) + @torch.compiler.disable + def forward(ctx, input: torch.Tensor, async_op: bool) -> torch.Tensor: + ctx.rank = get_gp_rank() + ctx.group = get_gp_group() + tensor_list = [torch.empty_like(input) for _ in range(get_gp_world_size())] + if async_op: + _tls.handle = dist.all_gather( + tensor_list, input, group=ctx.group, async_op=async_op + ) + else: + dist.all_gather(tensor_list, input, group=ctx.group, async_op=async_op) + return tuple(tensor_list) @staticmethod - def backward(ctx, grad_output: torch.Tensor): - (dim,) = ctx.saved_tensors - result = _split(grad_output, dim.item()) - return result, None + @torch.compiler.disable + def backward(ctx, *grad_outputs): + local_grad_output = grad_outputs[ctx.rank] + output_tensor = torch.empty_like(local_grad_output) + return reduce_scatter(output_tensor, grad_outputs, group=ctx.group), None -class GatherFromModelParallelRegionSumGrad(torch.autograd.Function): +class GatherFromModelParallelRegionSumGradPaddedGLOO(torch.autograd.Function): @staticmethod - def forward(ctx, input: torch.Tensor, dim: int = -1) -> torch.Tensor: - ctx.save_for_backward(torch.tensor(dim)) - return _gather_with_padding(input, dim) + @torch.compiler.disable + def forward(ctx, input: torch.Tensor, async_op: bool) -> torch.Tensor: + ctx.rank = get_gp_rank() + ctx.group = get_gp_group() + ctx.shape = input.shape + tensor_list = [torch.empty_like(input) for _ in range(get_gp_world_size())] + if async_op: + _tls.handle = dist.all_gather( + tensor_list, input, group=ctx.group, async_op=async_op + ) + else: + dist.all_gather(tensor_list, input, group=ctx.group, async_op=async_op) + return tuple(tensor_list) @staticmethod - def backward(ctx, grad_output: torch.Tensor): - (dim,) = ctx.saved_tensors - group = get_gp_group() - # use dist internal # does not work - # reduced_grad_output = grad_output.clone() - # dist.all_reduce( - # reduced_grad_output, group=group - # ) # This is an inplace operation - # grad_output = reduced_grad_output - - # use functional version instead - grad_output = all_reduce(grad_output, group=group) - - result = _split(grad_output, dim.item()) + @torch.compiler.disable + def backward(ctx, *grad_outputs): + grad_output = all_reduce(torch.cat(grad_outputs, dim=0), group=ctx.group) + ctx.padded_size = grad_outputs[0].shape[0] + result = grad_output[ + ctx.padded_size * ctx.rank : ctx.padded_size * ctx.rank + ctx.shape[0] + ] return result, None +def scale_backward_grad(input: torch.Tensor) -> torch.Tensor: + assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!" + return ScaleBackwardGrad.apply(input) + + # Leave forward untouched but upscale the gradient by a factor of gp_group_size # DDP reduces a mean across the loss, if we have gp_group_size=2 and 6 ranks # that means we do (a_1+a_2+a_3+b_1+b_2+b_3)/6 in ddp mean. 
This gets us the
@@ -344,45 +332,11 @@ def backward(ctx, grad_output: torch.Tensor):
 # avoid overhead communication
 class ScaleBackwardGrad(torch.autograd.Function):
     @staticmethod
+    @torch.compiler.disable
     def forward(ctx, input: torch.Tensor) -> torch.Tensor:
         return input

     @staticmethod
+    @torch.compiler.disable
     def backward(ctx, grad_output: torch.Tensor):
         return dist.get_world_size(get_gp_group()) * grad_output
-
-
-def copy_to_model_parallel_region(input: torch.Tensor) -> torch.Tensor:
-    assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!"
-    return CopyToModelParallelRegion.apply(input)
-
-
-def reduce_from_model_parallel_region(input: torch.Tensor) -> torch.Tensor:
-    assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!"
-    return ReduceFromModelParallelRegion.apply(input)
-
-
-def scatter_to_model_parallel_region(
-    input: torch.Tensor, dim: int = -1
-) -> torch.Tensor:
-    assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!"
-    return ScatterToModelParallelRegion.apply(input, dim)
-
-
-def gather_from_model_parallel_region(
-    input: torch.Tensor, dim: int = -1
-) -> torch.Tensor:
-    assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!"
-    return GatherFromModelParallelRegion.apply(input, dim)
-
-
-def gather_from_model_parallel_region_sum_grad(
-    input: torch.Tensor, dim: int = -1
-) -> torch.Tensor:
-    assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!"
-    return GatherFromModelParallelRegionSumGrad.apply(input, dim)
-
-
-def scale_backward_grad(input: torch.Tensor) -> torch.Tensor:
-    assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!"
- return ScaleBackwardGrad.apply(input) diff --git a/tests/core/common/test_gp_utils.py b/tests/core/common/test_gp_utils.py index de4d81d1b5..f396d302eb 100644 --- a/tests/core/common/test_gp_utils.py +++ b/tests/core/common/test_gp_utils.py @@ -14,9 +14,9 @@ from fairchem.core.common import gp_utils from fairchem.core.common.gp_utils import ( - gather_from_model_parallel_region, gather_from_model_parallel_region_sum_grad, scatter_to_model_parallel_region, + size_list_fn, ) from fairchem.core.common.test_utils import ( PGConfig, @@ -83,9 +83,9 @@ def test_scatter_tensors( assert torch.equal(out, expected_out) -def scatter_gather_fn(input: torch.Tensor, dim: int = 0): - x = scatter_to_model_parallel_region(input, dim) - return gather_from_model_parallel_region(x, dim) +def scatter_gather_fn(input: torch.Tensor): + x = scatter_to_model_parallel_region(input) + return gather_from_model_parallel_region_sum_grad(x, input.shape[0], True) @pytest.mark.parametrize( @@ -133,7 +133,7 @@ def test_gather_tensors( def scatter_bwd_test(): rank = dist.get_rank() x_full = torch.tensor([2, 3, 5, 7], requires_grad=True, dtype=torch.float) - x = scatter_to_model_parallel_region(x_full, 0) + x = scatter_to_model_parallel_region(x_full) energy_part = x.prod() ** 2 @@ -202,68 +202,13 @@ def test_scatter_bwd(): compare_and_assert_dict(expected_output[results["gp_rank"]], results) -def gather_bwd_test(rank=-1): - if rank < 0: - rank = dist.get_rank() - x = torch.tensor([rank + 2], requires_grad=True, dtype=torch.float) - x_full = gather_from_model_parallel_region(x, 0) - else: - x = torch.tensor([rank + 2], requires_grad=True, dtype=torch.float) - x_other = torch.tensor([(1 - rank) + 2], requires_grad=True, dtype=torch.float) - x_full = torch.cat([x, x_other]) if rank == 0 else torch.cat([x_other, x]) - - energy_part = (x_full.prod() + rank + 1) ** 2 - - forces_part = torch.autograd.grad( - [energy_part], - [x], - create_graph=True, - )[0] - - dforces_dinput_part = torch.autograd.grad( - [forces_part], - [x], - create_graph=True, - )[0] - - return { - "gp_rank": rank, - "energy": energy_part.detach(), - "forces": forces_part.detach(), - "dforces_dinput": dforces_dinput_part.detach(), - } - - -def test_gather_bwd(): - # A | B - # E_0 = (A*B +1)^2 , E_1 = (A*B+2)^2 - # = 49 = 64 - # dL_0/dA = 2*A*B^2+2*B = 42 - # dL_1/dB = 2*A^2*B+4*A = 32 - # dL_0/dB and dL_1/dA are not used! see test_gather_sum_bwd!! 
- # d^2L_1/dA^2 = 2*B^2 = 18 - # d^2L_1/dB^2 = 2*A^2 = 8 - - non_gp_results_by_gp_rank = {0: gather_bwd_test(0), 1: gather_bwd_test(1)} - - config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) - all_rank_results = spawn_multi_process( - config, - gather_bwd_test, - init_pg_and_rank_and_launch_test, - ) - - for rank_results in all_rank_results: - compare_and_assert_dict( - non_gp_results_by_gp_rank[rank_results["gp_rank"]], rank_results - ) - - def gather_sum_bwd_test(rank=-1): if rank < 0: rank = dist.get_rank() x = torch.tensor([rank + 2], requires_grad=True, dtype=torch.float) - x_full = gather_from_model_parallel_region_sum_grad(x, 0) + x_full = gather_from_model_parallel_region_sum_grad( + x, gp_utils.get_gp_world_size(), True + ) energy = (x_full.prod() + rank + 1) ** 2 # sum energy = gp_utils.reduce_from_model_parallel_region(energy) @@ -345,7 +290,7 @@ def scatter_prod_reduce(all_inputs): x_full = all_inputs.clone() - x = scatter_to_model_parallel_region(x_full, dim=0) + 0 + x = scatter_to_model_parallel_region(x_full) + 0 # BE VERY CAREFUL, inside of this context do not use any variables # in other partitions, their gradient will not propagate! if rank == 0: @@ -384,28 +329,13 @@ def test_scatter_prod_reduce(): ).all(), f"Failed closeness check for {key}" -def layer(x, target_rank): - rank = dist.get_rank() - - x_full = gather_from_model_parallel_region(x, 0) - x_prod = x_full.prod() - # backward graphs need to be same operation wise - # otherwise might miss a dist sync - if rank == target_rank: - x = x * 0 + x_prod - else: - x = x * 0 + x_prod * 0.0 + (rank + 1) - return x - - def embeddings_and_graph_init(atomic_numbers, edge_index): if gp_utils.initialized(): - node_partition = gp_utils.scatter_to_model_parallel_region( - torch.arange(len(atomic_numbers)).to(atomic_numbers.device) - ) - assert ( - node_partition.numel() > 0 - ), "Looks like there is no atoms in this graph paralell partition. 
Cannot proceed" + node_partition = torch.split( + torch.arange(atomic_numbers.shape[0]), + size_list_fn(atomic_numbers.shape[0], gp_utils.get_gp_world_size()), + )[gp_utils.get_gp_rank()] + edge_partition = torch.where( torch.logical_and( edge_index[1] >= node_partition.min(), @@ -416,32 +346,33 @@ def embeddings_and_graph_init(atomic_numbers, edge_index): graph_dict = { "node_offset": node_partition.min().item(), "edge_index": edge_index[:, edge_partition], + "natoms": atomic_numbers.shape[0], } - node_embeddings = atomic_numbers[node_partition] else: graph_dict = { "node_offset": 0, "edge_index": edge_index, + "natoms": atomic_numbers.shape[0], } - node_embeddings = atomic_numbers - return node_embeddings, graph_dict + return atomic_numbers, graph_dict # test for one rank to return a product and rest return 0 -def simple_layer(x, edge_index, node_offset, n=3): +def simple_layer(x, edge_index, node_offset, natoms, n=3): + x_source = x[edge_index[0]] + x_target = x[edge_index[1]] if gp_utils.initialized(): - x_full = gp_utils.gather_from_model_parallel_region_sum_grad(x, dim=0) - x_source = x_full[edge_index[0]] - x_target = x_full[edge_index[1]] dp_rank = gp_utils.get_dp_rank() + local_atoms = size_list_fn(natoms, gp_utils.get_gp_world_size())[ + gp_utils.get_gp_rank() + ] else: - x_source = x[edge_index[0]] - x_target = x[edge_index[1]] if dist.is_initialized(): dp_rank = dist.get_rank() else: dp_rank = 0.0 + local_atoms = x.shape[0] # make sure different ddp ranks have different outputs # similar to seeing diffferent data batches @@ -451,14 +382,19 @@ def simple_layer(x, edge_index, node_offset, n=3): edge_embeddings = (x_source + 1).pow(n) * (x_target + 1).pow(n) new_node_embedding = torch.zeros( - (x.shape[0],) + edge_embeddings.shape[1:], + (local_atoms,) + edge_embeddings.shape[1:], dtype=edge_embeddings.dtype, device=edge_embeddings.device, ) new_node_embedding.index_add_(0, edge_index[1] - node_offset, edge_embeddings) - return new_node_embedding + if gp_utils.initialized(): + return gp_utils.gather_from_model_parallel_region_sum_grad( + new_node_embedding, natoms, gloo_backend=True + ) + else: + return new_node_embedding class SimpleNet(nn.Module): @@ -479,12 +415,21 @@ def forward(self, atomic_numbers, edge_index): all_node_embeddings[-1], graph_dict["edge_index"], node_offset=graph_dict["node_offset"], + natoms=graph_dict["natoms"], n=self.n, ) ) final_node_embeddings = all_node_embeddings[-1] + if gp_utils.initialized(): + local_atoms = size_list_fn( + graph_dict["natoms"], gp_utils.get_gp_world_size() + )[gp_utils.get_gp_rank()] + final_node_embeddings = final_node_embeddings[ + graph_dict["node_offset"] : graph_dict["node_offset"] + local_atoms + ] + # only 1 system energy_part = torch.zeros( 1, device=atomic_numbers.device, dtype=atomic_numbers.dtype From 6e8c41b1b03600699cbda31e455f9995cb6b923c Mon Sep 17 00:00:00 2001 From: misko Date: Sat, 25 Oct 2025 00:24:26 +0000 Subject: [PATCH 02/13] fix tests --- src/fairchem/core/common/gp_utils.py | 44 +++++++++++++++++++ src/fairchem/core/models/uma/escn_md.py | 14 +++++- src/fairchem/core/models/uma/escn_md_block.py | 12 ++++- 3 files changed, 66 insertions(+), 4 deletions(-) diff --git a/src/fairchem/core/common/gp_utils.py b/src/fairchem/core/common/gp_utils.py index 3a3652cf69..08e99ecad4 100644 --- a/src/fairchem/core/common/gp_utils.py +++ b/src/fairchem/core/common/gp_utils.py @@ -217,6 +217,29 @@ def backward(ctx, grad_output: torch.Tensor): ) +def gather_from_model_parallel_region( + input: torch.Tensor, + natoms: int, 
+) -> torch.Tensor: + assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!" + world_size = get_gp_world_size() + size_list = size_list_fn(natoms, world_size) + + input = pad_input( + input, natoms // world_size + (1 if natoms % world_size != 0 else 0) + ) + + tensor_list_w_padding = GatherFromModelParallelRegionGradPadded.apply(input, False) + + return torch.cat( + [ + t.narrow(0, 0, s) if t.shape[0] != s else t + for t, s in zip(tensor_list_w_padding, size_list) + ], + dim=0, + ) + + def gather_from_model_parallel_region_sum_grad( input: torch.Tensor, natoms: int, @@ -266,6 +289,27 @@ def gather_from_model_parallel_region_sum_grad_async( return tensor_list_w_padding, handle +class GatherFromModelParallelRegionGradPadded(torch.autograd.Function): + @staticmethod + @torch.compiler.disable + def forward(ctx, input: torch.Tensor, async_op: bool) -> torch.Tensor: + ctx.rank = get_gp_rank() + ctx.group = get_gp_group() + tensor_list = [torch.empty_like(input) for _ in range(get_gp_world_size())] + if async_op: + _tls.handle = dist.all_gather( + tensor_list, input, group=ctx.group, async_op=async_op + ) + else: + dist.all_gather(tensor_list, input, group=ctx.group, async_op=async_op) + return tuple(tensor_list) + + @staticmethod + @torch.compiler.disable + def backward(ctx, *grad_outputs): + return grad_outputs[ctx.rank], None + + class GatherFromModelParallelRegionSumGradPadded(torch.autograd.Function): @staticmethod @torch.compiler.disable diff --git a/src/fairchem/core/models/uma/escn_md.py b/src/fairchem/core/models/uma/escn_md.py index 0555d6daee..3895f235b2 100644 --- a/src/fairchem/core/models/uma/escn_md.py +++ b/src/fairchem/core/models/uma/escn_md.py @@ -13,6 +13,7 @@ import torch import torch.nn as nn +from torch import distributed as dist from torch.profiler import record_function from fairchem.core.common import gp_utils @@ -405,6 +406,8 @@ def _generate_graph(self, data_dict): @conditional_grad(torch.enable_grad()) def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: + gloo_backend = (not gp_utils.initialized()) or dist.get_backend() == "gloo" + data_dict["atomic_numbers"] = data_dict["atomic_numbers"].long() data_dict["atomic_numbers_full"] = data_dict["atomic_numbers"] data_dict["batch_full"] = data_dict["batch"] @@ -505,6 +508,8 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: wigner_and_M_mapping, wigner_and_M_mapping_inv, edge_envelope, + total_atoms=data_dict["atomic_numbers_full"].shape[0], + gloo_backend=gloo_backend, sys_node_embedding=sys_node_embedding, node_offset=graph_dict["node_offset"], ) @@ -648,8 +653,11 @@ def forward( outputs[energy_key] = {"energy": energy} if self.wrap_property else energy embeddings = emb["node_embedding"].detach() + # TODO we should remove this for MD runs if gp_utils.initialized(): - embeddings = gp_utils.gather_from_model_parallel_region(embeddings, dim=0) + embeddings = gp_utils.gather_from_model_parallel_region( + embeddings, data["atomic_numbers_full"].shape[0] + ) outputs["embeddings"] = ( {"embeddings": embeddings} if self.wrap_property else embeddings @@ -781,7 +789,9 @@ def forward(self, data_dict: AtomicData, emb: dict[str, torch.Tensor]): forces = forces.narrow(1, 1, 3) forces = forces.view(-1, 3).contiguous() if gp_utils.initialized(): - forces = gp_utils.gather_from_model_parallel_region(forces, dim=0) + forces = gp_utils.gather_from_model_parallel_region( + forces, data_dict["atomic_numbers_full"].shape[0] + ) return {"forces": 
forces} diff --git a/src/fairchem/core/models/uma/escn_md_block.py b/src/fairchem/core/models/uma/escn_md_block.py index 8c3efa46e2..7de43eca3b 100644 --- a/src/fairchem/core/models/uma/escn_md_block.py +++ b/src/fairchem/core/models/uma/escn_md_block.py @@ -125,11 +125,15 @@ def forward( wigner_and_M_mapping, wigner_and_M_mapping_inv, edge_envelope, + total_atoms, + gloo_backend, node_offset: int = 0, ): # we perform the all gather upfront once during each forward call so we don't need to repeat this multiple times during activation checkpointing. if gp_utils.initialized(): - x_full = gp_utils.gather_from_model_parallel_region_sum_grad(x, dim=0) + x_full = gp_utils.gather_from_model_parallel_region_sum_grad( + x, total_atoms, gloo_backend=gloo_backend + ) else: x_full = x @@ -387,6 +391,8 @@ def forward( wigner_and_M_mapping, wigner_and_M_mapping_inv, edge_envelope, + total_atoms, + gloo_backend, sys_node_embedding=None, node_offset: int = 0, ): @@ -405,7 +411,9 @@ def forward( wigner_and_M_mapping, wigner_and_M_mapping_inv, edge_envelope, - node_offset, + total_atoms=total_atoms, + gloo_backend=gloo_backend, + node_offset=node_offset, ) x = x + x_res From 36596168b0dc630d88a28c9476be78fe533db775 Mon Sep 17 00:00:00 2001 From: misko Date: Thu, 30 Oct 2025 23:59:55 +0000 Subject: [PATCH 03/13] increase atol --- tests/core/units/mlip_unit/test_predict.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/core/units/mlip_unit/test_predict.py b/tests/core/units/mlip_unit/test_predict.py index 606ab16bc1..9f7e2cc744 100644 --- a/tests/core/units/mlip_unit/test_predict.py +++ b/tests/core/units/mlip_unit/test_predict.py @@ -14,7 +14,7 @@ from tests.conftest import seed_everywhere FORCE_TOL = 1e-4 -ATOL = 1e-5 +ATOL = 2e-5 def get_fcc_carbon_xtal( @@ -179,6 +179,7 @@ def test_parallel_predict_unit(workers, device): # Thus out-of-plane component is simply the x-component of the forces. # --------------------------------------------------------------------------- + def _random_rotation_matrix(rng: np.random.Generator) -> np.ndarray: """Generate a 3D rotation matrix from two angles in [0, 2π). 
@@ -210,8 +211,7 @@ def test_rotational_invariance_out_of_plane(mol_name): atoms.info.update({"charge": 0, "spin": 1}) atoms.calc = calc - orig_positions = atoms.get_positions().copy()\ - + orig_positions = atoms.get_positions().copy() n_rot = 50 # fewer rotations for speed for _ in range(n_rot): R = _random_rotation_matrix(rng) @@ -220,8 +220,7 @@ def test_rotational_invariance_out_of_plane(mol_name): rot_forces = atoms.get_forces() # Unrotate forces back to original frame (covariant transformation) unrot_forces = rot_forces @ R - assert (np.abs(unrot_forces[:,0]) Date: Tue, 11 Nov 2025 01:35:28 +0000 Subject: [PATCH 04/13] remove async --- src/fairchem/core/models/uma/escn_md.py | 4 ---- src/fairchem/core/models/uma/escn_md_block.py | 7 +------ 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/src/fairchem/core/models/uma/escn_md.py b/src/fairchem/core/models/uma/escn_md.py index b72ca3d102..e3a506097a 100644 --- a/src/fairchem/core/models/uma/escn_md.py +++ b/src/fairchem/core/models/uma/escn_md.py @@ -13,7 +13,6 @@ import torch import torch.nn as nn -from torch import distributed as dist from torch.profiler import record_function from fairchem.core.common import gp_utils @@ -406,8 +405,6 @@ def _generate_graph(self, data_dict): @conditional_grad(torch.enable_grad()) def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: - gloo_backend = (not gp_utils.initialized()) or dist.get_backend() == "gloo" - data_dict["atomic_numbers"] = data_dict["atomic_numbers"].long() data_dict["atomic_numbers_full"] = data_dict["atomic_numbers"] data_dict["batch_full"] = data_dict["batch"] @@ -509,7 +506,6 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: wigner_and_M_mapping_inv, edge_envelope, total_atoms=data_dict["atomic_numbers_full"].shape[0], - gloo_backend=gloo_backend, sys_node_embedding=sys_node_embedding, node_offset=graph_dict["node_offset"], ) diff --git a/src/fairchem/core/models/uma/escn_md_block.py b/src/fairchem/core/models/uma/escn_md_block.py index 7de43eca3b..96c0b8daaf 100644 --- a/src/fairchem/core/models/uma/escn_md_block.py +++ b/src/fairchem/core/models/uma/escn_md_block.py @@ -126,14 +126,11 @@ def forward( wigner_and_M_mapping_inv, edge_envelope, total_atoms, - gloo_backend, node_offset: int = 0, ): # we perform the all gather upfront once during each forward call so we don't need to repeat this multiple times during activation checkpointing. 
if gp_utils.initialized(): - x_full = gp_utils.gather_from_model_parallel_region_sum_grad( - x, total_atoms, gloo_backend=gloo_backend - ) + x_full = gp_utils.gather_from_model_parallel_region_sum_grad(x, total_atoms) else: x_full = x @@ -392,7 +389,6 @@ def forward( wigner_and_M_mapping_inv, edge_envelope, total_atoms, - gloo_backend, sys_node_embedding=None, node_offset: int = 0, ): @@ -412,7 +408,6 @@ def forward( wigner_and_M_mapping_inv, edge_envelope, total_atoms=total_atoms, - gloo_backend=gloo_backend, node_offset=node_offset, ) x = x + x_res From 32eddbb967bf518eaef376a16d3756f90f7b2814 Mon Sep 17 00:00:00 2001 From: misko Date: Tue, 11 Nov 2025 01:41:01 +0000 Subject: [PATCH 05/13] types --- src/fairchem/core/common/gp_utils.py | 61 ++++++---------------------- 1 file changed, 13 insertions(+), 48 deletions(-) diff --git a/src/fairchem/core/common/gp_utils.py b/src/fairchem/core/common/gp_utils.py index 08e99ecad4..ae7672e908 100644 --- a/src/fairchem/core/common/gp_utils.py +++ b/src/fairchem/core/common/gp_utils.py @@ -174,7 +174,7 @@ def get_gp_world_size() -> int: ########## DIST METHODS ########## -def size_list_fn(size, parts): +def size_list_fn(size: int, parts: int) -> list[int]: return [size // parts + (1 if idx < size % parts else 0) for idx in range(parts)] @@ -243,7 +243,6 @@ def gather_from_model_parallel_region( def gather_from_model_parallel_region_sum_grad( input: torch.Tensor, natoms: int, - gloo_backend: bool, ) -> torch.Tensor: assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!" world_size = get_gp_world_size() @@ -253,15 +252,13 @@ def gather_from_model_parallel_region_sum_grad( input, natoms // world_size + (1 if natoms % world_size != 0 else 0) ) - if gloo_backend: + if dist.get_backend() == "gloo": tensor_list_w_padding = GatherFromModelParallelRegionSumGradPaddedGLOO.apply( - input, False + input ) else: # tensor_list_w_padding = all_gather(input, group=get_gp_group()) - tensor_list_w_padding = GatherFromModelParallelRegionSumGradPadded.apply( - input, False - ) + tensor_list_w_padding = GatherFromModelParallelRegionSumGradPadded.apply(input) return torch.cat( [ @@ -272,57 +269,30 @@ def gather_from_model_parallel_region_sum_grad( ) -def gather_from_model_parallel_region_sum_grad_async( - input: torch.Tensor, natoms: int -) -> torch.Tensor: - assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!" 
- world_size = get_gp_world_size() - - input = pad_input( - input, natoms // world_size + (1 if natoms % world_size != 0 else 0) - ) - - tensor_list_w_padding = GatherFromModelParallelRegionSumGradPadded.apply( - input, True - ) - handle = _tls.handle - return tensor_list_w_padding, handle - - class GatherFromModelParallelRegionGradPadded(torch.autograd.Function): @staticmethod @torch.compiler.disable - def forward(ctx, input: torch.Tensor, async_op: bool) -> torch.Tensor: + def forward(ctx, input: torch.Tensor) -> torch.Tensor: ctx.rank = get_gp_rank() ctx.group = get_gp_group() tensor_list = [torch.empty_like(input) for _ in range(get_gp_world_size())] - if async_op: - _tls.handle = dist.all_gather( - tensor_list, input, group=ctx.group, async_op=async_op - ) - else: - dist.all_gather(tensor_list, input, group=ctx.group, async_op=async_op) + dist.all_gather(tensor_list, input, group=ctx.group) return tuple(tensor_list) @staticmethod @torch.compiler.disable def backward(ctx, *grad_outputs): - return grad_outputs[ctx.rank], None + return grad_outputs[ctx.rank] class GatherFromModelParallelRegionSumGradPadded(torch.autograd.Function): @staticmethod @torch.compiler.disable - def forward(ctx, input: torch.Tensor, async_op: bool) -> torch.Tensor: + def forward(ctx, input: torch.Tensor) -> torch.Tensor: ctx.rank = get_gp_rank() ctx.group = get_gp_group() tensor_list = [torch.empty_like(input) for _ in range(get_gp_world_size())] - if async_op: - _tls.handle = dist.all_gather( - tensor_list, input, group=ctx.group, async_op=async_op - ) - else: - dist.all_gather(tensor_list, input, group=ctx.group, async_op=async_op) + dist.all_gather(tensor_list, input, group=ctx.group) return tuple(tensor_list) @staticmethod @@ -330,23 +300,18 @@ def forward(ctx, input: torch.Tensor, async_op: bool) -> torch.Tensor: def backward(ctx, *grad_outputs): local_grad_output = grad_outputs[ctx.rank] output_tensor = torch.empty_like(local_grad_output) - return reduce_scatter(output_tensor, grad_outputs, group=ctx.group), None + return reduce_scatter(output_tensor, grad_outputs, group=ctx.group) class GatherFromModelParallelRegionSumGradPaddedGLOO(torch.autograd.Function): @staticmethod @torch.compiler.disable - def forward(ctx, input: torch.Tensor, async_op: bool) -> torch.Tensor: + def forward(ctx, input: torch.Tensor) -> torch.Tensor: ctx.rank = get_gp_rank() ctx.group = get_gp_group() ctx.shape = input.shape tensor_list = [torch.empty_like(input) for _ in range(get_gp_world_size())] - if async_op: - _tls.handle = dist.all_gather( - tensor_list, input, group=ctx.group, async_op=async_op - ) - else: - dist.all_gather(tensor_list, input, group=ctx.group, async_op=async_op) + dist.all_gather(tensor_list, input, group=ctx.group) return tuple(tensor_list) @staticmethod @@ -357,7 +322,7 @@ def backward(ctx, *grad_outputs): result = grad_output[ ctx.padded_size * ctx.rank : ctx.padded_size * ctx.rank + ctx.shape[0] ] - return result, None + return result def scale_backward_grad(input: torch.Tensor) -> torch.Tensor: From 832f6a2e22ea46313e51d2261a726a881c25ce85 Mon Sep 17 00:00:00 2001 From: misko Date: Tue, 11 Nov 2025 18:33:18 +0000 Subject: [PATCH 06/13] fix tests --- tests/core/common/test_gp_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/core/common/test_gp_utils.py b/tests/core/common/test_gp_utils.py index f396d302eb..0bb00a5906 100644 --- a/tests/core/common/test_gp_utils.py +++ b/tests/core/common/test_gp_utils.py @@ -85,7 +85,7 @@ def test_scatter_tensors( def 
scatter_gather_fn(input: torch.Tensor): x = scatter_to_model_parallel_region(input) - return gather_from_model_parallel_region_sum_grad(x, input.shape[0], True) + return gather_from_model_parallel_region_sum_grad(x, input.shape[0]) @pytest.mark.parametrize( @@ -207,7 +207,7 @@ def gather_sum_bwd_test(rank=-1): rank = dist.get_rank() x = torch.tensor([rank + 2], requires_grad=True, dtype=torch.float) x_full = gather_from_model_parallel_region_sum_grad( - x, gp_utils.get_gp_world_size(), True + x, gp_utils.get_gp_world_size() ) energy = (x_full.prod() + rank + 1) ** 2 # sum @@ -391,7 +391,7 @@ def simple_layer(x, edge_index, node_offset, natoms, n=3): if gp_utils.initialized(): return gp_utils.gather_from_model_parallel_region_sum_grad( - new_node_embedding, natoms, gloo_backend=True + new_node_embedding, natoms ) else: return new_node_embedding From 620e646251a360c8579b6d12e702113e2f3651fe Mon Sep 17 00:00:00 2001 From: misko Date: Tue, 11 Nov 2025 20:13:05 +0000 Subject: [PATCH 07/13] fix tests --- src/fairchem/core/common/gp_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fairchem/core/common/gp_utils.py b/src/fairchem/core/common/gp_utils.py index ae7672e908..cd87e19352 100644 --- a/src/fairchem/core/common/gp_utils.py +++ b/src/fairchem/core/common/gp_utils.py @@ -213,7 +213,7 @@ def forward(ctx, input: torch.Tensor, dim: int = -1) -> torch.Tensor: @torch.compiler.disable def backward(ctx, grad_output: torch.Tensor): return gather_from_model_parallel_region_sum_grad( - grad_output, sum(ctx.split_sizes), False + grad_output, sum(ctx.split_sizes) ) From 343c535eba84889736dc565baeb22c7468b1edf2 Mon Sep 17 00:00:00 2001 From: misko Date: Tue, 11 Nov 2025 20:17:24 +0000 Subject: [PATCH 08/13] rename total_atoms --- src/fairchem/core/models/uma/escn_md.py | 4 +++- src/fairchem/core/models/uma/escn_md_block.py | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/fairchem/core/models/uma/escn_md.py b/src/fairchem/core/models/uma/escn_md.py index e3a506097a..7cf0fa4470 100644 --- a/src/fairchem/core/models/uma/escn_md.py +++ b/src/fairchem/core/models/uma/escn_md.py @@ -505,7 +505,9 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: wigner_and_M_mapping, wigner_and_M_mapping_inv, edge_envelope, - total_atoms=data_dict["atomic_numbers_full"].shape[0], + total_atoms_across_gp_ranks=data_dict["atomic_numbers_full"].shape[ + 0 + ], sys_node_embedding=sys_node_embedding, node_offset=graph_dict["node_offset"], ) diff --git a/src/fairchem/core/models/uma/escn_md_block.py b/src/fairchem/core/models/uma/escn_md_block.py index 96c0b8daaf..596dae0806 100644 --- a/src/fairchem/core/models/uma/escn_md_block.py +++ b/src/fairchem/core/models/uma/escn_md_block.py @@ -125,12 +125,14 @@ def forward( wigner_and_M_mapping, wigner_and_M_mapping_inv, edge_envelope, - total_atoms, + total_atoms_across_gp_ranks, node_offset: int = 0, ): # we perform the all gather upfront once during each forward call so we don't need to repeat this multiple times during activation checkpointing. 
if gp_utils.initialized(): - x_full = gp_utils.gather_from_model_parallel_region_sum_grad(x, total_atoms) + x_full = gp_utils.gather_from_model_parallel_region_sum_grad( + x, total_atoms_across_gp_ranks + ) else: x_full = x From d8c0197d01cecacd12255b235a644c14568dfb7c Mon Sep 17 00:00:00 2001 From: misko Date: Tue, 11 Nov 2025 20:23:53 +0000 Subject: [PATCH 09/13] typo --- src/fairchem/core/models/uma/escn_md_block.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fairchem/core/models/uma/escn_md_block.py b/src/fairchem/core/models/uma/escn_md_block.py index 596dae0806..b84b3f1008 100644 --- a/src/fairchem/core/models/uma/escn_md_block.py +++ b/src/fairchem/core/models/uma/escn_md_block.py @@ -390,7 +390,7 @@ def forward( wigner_and_M_mapping, wigner_and_M_mapping_inv, edge_envelope, - total_atoms, + total_atoms_across_gp_ranks, sys_node_embedding=None, node_offset: int = 0, ): @@ -409,7 +409,7 @@ def forward( wigner_and_M_mapping, wigner_and_M_mapping_inv, edge_envelope, - total_atoms=total_atoms, + total_atoms_across_gp_ranks=total_atoms_across_gp_ranks, node_offset=node_offset, ) x = x + x_res From 883de3f7d48bc831def382ac059befd83a47fb23 Mon Sep 17 00:00:00 2001 From: misko Date: Tue, 11 Nov 2025 21:06:07 +0000 Subject: [PATCH 10/13] missing one more --- src/fairchem/core/common/gp_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fairchem/core/common/gp_utils.py b/src/fairchem/core/common/gp_utils.py index cd87e19352..a9fd5aaf14 100644 --- a/src/fairchem/core/common/gp_utils.py +++ b/src/fairchem/core/common/gp_utils.py @@ -229,7 +229,7 @@ def gather_from_model_parallel_region( input, natoms // world_size + (1 if natoms % world_size != 0 else 0) ) - tensor_list_w_padding = GatherFromModelParallelRegionGradPadded.apply(input, False) + tensor_list_w_padding = GatherFromModelParallelRegionGradPadded.apply(input) return torch.cat( [ From 9be51251f487756b080b6a147d0f7765adb8d4af Mon Sep 17 00:00:00 2001 From: misko Date: Wed, 12 Nov 2025 00:35:59 +0000 Subject: [PATCH 11/13] remove extra gloo --- src/fairchem/core/common/gp_utils.py | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/src/fairchem/core/common/gp_utils.py b/src/fairchem/core/common/gp_utils.py index a9fd5aaf14..1e4e736c16 100644 --- a/src/fairchem/core/common/gp_utils.py +++ b/src/fairchem/core/common/gp_utils.py @@ -253,9 +253,7 @@ def gather_from_model_parallel_region_sum_grad( ) if dist.get_backend() == "gloo": - tensor_list_w_padding = GatherFromModelParallelRegionSumGradPaddedGLOO.apply( - input - ) + tensor_list_w_padding = GatherFromModelParallelRegionSumGradPadded.apply(input) else: # tensor_list_w_padding = all_gather(input, group=get_gp_group()) tensor_list_w_padding = GatherFromModelParallelRegionSumGradPadded.apply(input) @@ -298,25 +296,11 @@ def forward(ctx, input: torch.Tensor) -> torch.Tensor: @staticmethod @torch.compiler.disable def backward(ctx, *grad_outputs): - local_grad_output = grad_outputs[ctx.rank] - output_tensor = torch.empty_like(local_grad_output) - return reduce_scatter(output_tensor, grad_outputs, group=ctx.group) + if dist.get_backend() != "gloo": + local_grad_output = grad_outputs[ctx.rank] + output_tensor = torch.empty_like(local_grad_output) + return reduce_scatter(output_tensor, grad_outputs, group=ctx.group) - -class GatherFromModelParallelRegionSumGradPaddedGLOO(torch.autograd.Function): - @staticmethod - @torch.compiler.disable - def forward(ctx, input: torch.Tensor) -> torch.Tensor: - 
ctx.rank = get_gp_rank() - ctx.group = get_gp_group() - ctx.shape = input.shape - tensor_list = [torch.empty_like(input) for _ in range(get_gp_world_size())] - dist.all_gather(tensor_list, input, group=ctx.group) - return tuple(tensor_list) - - @staticmethod - @torch.compiler.disable - def backward(ctx, *grad_outputs): grad_output = all_reduce(torch.cat(grad_outputs, dim=0), group=ctx.group) ctx.padded_size = grad_outputs[0].shape[0] result = grad_output[ From b7a0b0934d6641f31339145b3d7c0c3b0879d32b Mon Sep 17 00:00:00 2001 From: misko Date: Wed, 12 Nov 2025 00:36:17 +0000 Subject: [PATCH 12/13] remove extra gloo --- src/fairchem/core/common/gp_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/fairchem/core/common/gp_utils.py b/src/fairchem/core/common/gp_utils.py index 1e4e736c16..260dbcde28 100644 --- a/src/fairchem/core/common/gp_utils.py +++ b/src/fairchem/core/common/gp_utils.py @@ -252,11 +252,7 @@ def gather_from_model_parallel_region_sum_grad( input, natoms // world_size + (1 if natoms % world_size != 0 else 0) ) - if dist.get_backend() == "gloo": - tensor_list_w_padding = GatherFromModelParallelRegionSumGradPadded.apply(input) - else: - # tensor_list_w_padding = all_gather(input, group=get_gp_group()) - tensor_list_w_padding = GatherFromModelParallelRegionSumGradPadded.apply(input) + tensor_list_w_padding = GatherFromModelParallelRegionSumGradPadded.apply(input) return torch.cat( [ From 81ca6d0d1a9716d6ac653f29a10264e3a2c87e69 Mon Sep 17 00:00:00 2001 From: misko Date: Wed, 12 Nov 2025 00:49:30 +0000 Subject: [PATCH 13/13] fix tests --- src/fairchem/core/common/gp_utils.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/fairchem/core/common/gp_utils.py b/src/fairchem/core/common/gp_utils.py index 260dbcde28..1ac1d758aa 100644 --- a/src/fairchem/core/common/gp_utils.py +++ b/src/fairchem/core/common/gp_utils.py @@ -285,6 +285,8 @@ class GatherFromModelParallelRegionSumGradPadded(torch.autograd.Function): def forward(ctx, input: torch.Tensor) -> torch.Tensor: ctx.rank = get_gp_rank() ctx.group = get_gp_group() + if dist.get_backend() == "gloo": + ctx.shape = input.shape tensor_list = [torch.empty_like(input) for _ in range(get_gp_world_size())] dist.all_gather(tensor_list, input, group=ctx.group) return tuple(tensor_list) @@ -292,17 +294,16 @@ def forward(ctx, input: torch.Tensor) -> torch.Tensor: @staticmethod @torch.compiler.disable def backward(ctx, *grad_outputs): - if dist.get_backend() != "gloo": - local_grad_output = grad_outputs[ctx.rank] - output_tensor = torch.empty_like(local_grad_output) - return reduce_scatter(output_tensor, grad_outputs, group=ctx.group) - - grad_output = all_reduce(torch.cat(grad_outputs, dim=0), group=ctx.group) - ctx.padded_size = grad_outputs[0].shape[0] - result = grad_output[ - ctx.padded_size * ctx.rank : ctx.padded_size * ctx.rank + ctx.shape[0] - ] - return result + if dist.get_backend() == "gloo": + grad_output = all_reduce(torch.cat(grad_outputs, dim=0), group=ctx.group) + ctx.padded_size = grad_outputs[0].shape[0] + result = grad_output[ + ctx.padded_size * ctx.rank : ctx.padded_size * ctx.rank + ctx.shape[0] + ] + return result + local_grad_output = grad_outputs[ctx.rank] + output_tensor = torch.empty_like(local_grad_output) + return reduce_scatter(output_tensor, grad_outputs, group=ctx.group) def scale_backward_grad(input: torch.Tensor) -> torch.Tensor:
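
The gather path introduced in this series is a fixed-size padded all_gather followed by a per-rank trim: every rank pads its shard of rows to a common size, the padded shards are gathered, and the padding is sliced off again before concatenation. A minimal single-process sketch of that round trip, with plain list operations standing in for dist.all_gather (illustrative only, not part of the patch):

import torch

def size_list_fn(size: int, parts: int) -> list[int]:
    # same splitting rule as the patch: the first (size % parts) ranks get one extra row
    return [size // parts + (1 if idx < size % parts else 0) for idx in range(parts)]

def pad_input(x: torch.Tensor, padded_size: int) -> torch.Tensor:
    # zero-pad the leading dimension up to the common per-rank size
    if x.shape[0] != padded_size:
        pad = torch.zeros((padded_size - x.shape[0], *x.shape[1:]), dtype=x.dtype)
        x = torch.cat([x, pad], dim=0)
    return x

natoms, world_size = 7, 3
size_list = size_list_fn(natoms, world_size)                            # [3, 2, 2]
padded = natoms // world_size + (1 if natoms % world_size != 0 else 0)  # 3 rows per rank

x = torch.arange(natoms, dtype=torch.float32).unsqueeze(1)  # full per-atom tensor
shards = torch.split(x, size_list, dim=0)                   # what scatter hands each rank

gathered = [pad_input(s, padded) for s in shards]           # stands in for dist.all_gather

# trim the padding off each gathered shard and concatenate, as the patch does
restored = torch.cat(
    [t.narrow(0, 0, s) for t, s in zip(gathered, size_list)], dim=0
)
assert torch.equal(restored, x)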