diff --git a/src/fairchem/core/common/gp_utils.py b/src/fairchem/core/common/gp_utils.py index 48d2226cb8..3a3652cf69 100644 --- a/src/fairchem/core/common/gp_utils.py +++ b/src/fairchem/core/common/gp_utils.py @@ -9,12 +9,11 @@ import contextlib import logging -from typing import Any +import threading -import numpy as np import torch from torch import distributed as dist -from torch.distributed.nn.functional import all_reduce +from torch.distributed.nn.functional import all_reduce, reduce_scatter """ Functions to support graph parallel training. @@ -27,6 +26,32 @@ _GRAPH_PARALLEL_GROUP = None _DATA_PARALLEL_GROUP = None +_tls = threading.local() + + +def pad_input(input: torch.Tensor, padded_size: int): + # pad using functional + # if input.shape[0]!=padded_size: + # input=torch.nn.functional.pad(input,(0,0,0,0,0,1)).contiguous() + + # pad using manual tensor cat + if input.shape[0] != padded_size: + input = torch.cat( + [ + input, + torch.zeros( + (padded_size - input.shape[0], *input.shape[1:]), + device=input.device, + dtype=input.dtype, + ), + ], + dim=0, + ) + + assert input.shape[0] == padded_size + + return input + def ensure_div(a: int, b: int) -> None: assert a % b == 0 @@ -149,190 +174,153 @@ def get_gp_world_size() -> int: ########## DIST METHODS ########## -def pad_tensor( - tensor: torch.Tensor, dim: int = -1, target_size: int | None = None -) -> torch.Tensor: - size = tensor.size(dim) - if target_size is None: - world_size = get_gp_world_size() - pad_size = 0 if size % world_size == 0 else world_size - size % world_size - else: - pad_size = target_size - size - if pad_size == 0: - return tensor - pad_shape = list(tensor.shape) - pad_shape[dim] = pad_size - padding = torch.empty(pad_shape, device=tensor.device, dtype=tensor.dtype) - return torch.cat([tensor, padding], dim=dim) - - -def trim_tensor(tensor: torch.Tensor, sizes: torch.Tensor | None = None, dim: int = 0): - size = tensor.size(dim) - world_size = get_gp_world_size() - if size % world_size == 0: - return tensor, sizes - trim_size = size - size % world_size - if dim == 0: - tensor = tensor[:trim_size] - elif dim == 1: - tensor = tensor[:, :trim_size] - else: - raise ValueError - if sizes is not None: - sizes[-1] = sizes[-1] - size % world_size - return tensor, sizes - - -def _tensor_to_split_partitions(tensor: torch.Tensor, dim: int = -1): - group = get_gp_group() - num_parts = dist.get_world_size(group=group) - return [len(part) for part in np.array_split(np.zeros(tensor.size(dim)), num_parts)] - - -def _split_tensor( - tensor: torch.Tensor, - dim: int = -1, - contiguous_chunks: bool = False, -): - tensor_list = torch.split(tensor, _tensor_to_split_partitions(tensor, dim), dim=dim) - if contiguous_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - return tensor_list - - -def _reduce(ctx: Any, input: torch.Tensor) -> torch.Tensor: - group = get_gp_group() - if ctx: - ctx.mark_dirty(input) - dist.all_reduce(input, group=group) - return input - - -def _split(input: torch.Tensor, dim: int = -1) -> torch.Tensor: - rank = get_gp_rank() - input_list = _split_tensor(input, dim=dim) - return input_list[rank].clone().contiguous() - - -def _gather(input: torch.Tensor, dim: int = -1) -> torch.Tensor: - group = get_gp_group() - rank = get_gp_rank() - world_size = dist.get_world_size(group=group) - if world_size == 1: - return input - tensor_list = [torch.empty_like(input) for _ in range(world_size)] - tensor_list[rank] = input - dist.all_gather(tensor_list, input, group=group) - return torch.cat(tensor_list, 
dim=dim).contiguous() - - -def _gather_with_padding(input: torch.Tensor, dim: int = -1) -> torch.Tensor: - group = get_gp_group() - rank = get_gp_rank() - world_size = dist.get_world_size(group=group) - if world_size == 1: - return input - - # Gather sizes - size_list = [ - torch.empty(1, device=input.device, dtype=torch.long) for _ in range(world_size) - ] - size = torch.tensor([input.size(dim)], device=input.device, dtype=torch.long) - size_list[rank] = size - dist.all_gather(size_list, size, group=group) - - # Gather the inputs - max_size = int(max([size.item() for size in size_list])) - input = pad_tensor(input, dim, max_size) - shape = list(input.shape) - shape[dim] = max_size - tensor_list = [ - torch.empty(shape, device=input.device, dtype=input.dtype) - for _ in range(world_size) - ] - - dist.all_gather(tensor_list, input, group=group) - tensor_list[rank] = input # pop back in our local copy (requires grad) - - # Trim and cat - return torch.cat( - [tensor.narrow(dim, 0, size) for tensor, size in zip(tensor_list, size_list)], - dim=dim, - ).contiguous() - +def size_list_fn(size, parts): + return [size // parts + (1 if idx < size % parts else 0) for idx in range(parts)] -class CopyToModelParallelRegion(torch.autograd.Function): - @staticmethod - def forward(ctx, input: torch.Tensor) -> torch.Tensor: - return input - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: - return _reduce(None, grad_output) +def reduce_from_model_parallel_region(input: torch.Tensor) -> torch.Tensor: + assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!" + return ReduceFromModelParallelRegion.apply(input) class ReduceFromModelParallelRegion(torch.autograd.Function): @staticmethod + @torch.compiler.disable def forward(ctx, input: torch.Tensor) -> torch.Tensor: # return _reduce(ctx, input) # this operates in place return all_reduce(input, group=get_gp_group()) # this operats out of place @staticmethod + @torch.compiler.disable def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: return grad_output +def scatter_to_model_parallel_region(input: torch.Tensor) -> torch.Tensor: + assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!" + return ScatterToModelParallelRegion.apply(input) + + # this returns the values in place class ScatterToModelParallelRegion(torch.autograd.Function): @staticmethod + @torch.compiler.disable def forward(ctx, input: torch.Tensor, dim: int = -1) -> torch.Tensor: - result = _split(input, dim) - ctx.save_for_backward(torch.tensor(dim)) - return result + ctx.split_sizes = size_list_fn(input.shape[0], get_gp_world_size()) + return input.split(ctx.split_sizes)[get_gp_rank()] @staticmethod + @torch.compiler.disable def backward(ctx, grad_output: torch.Tensor): - (dim,) = ctx.saved_tensors - return _gather_with_padding(grad_output.clone(), dim.item()), None + return gather_from_model_parallel_region_sum_grad( + grad_output, sum(ctx.split_sizes), False + ) -class GatherFromModelParallelRegion(torch.autograd.Function): +def gather_from_model_parallel_region_sum_grad( + input: torch.Tensor, + natoms: int, + gloo_backend: bool, +) -> torch.Tensor: + assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!" 
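+ # Every rank pads its local slice to the same ceil(natoms / world_size) length so the
+ # all_gather below can use equal-sized buffers; the padding is trimmed again when the
+ # gathered pieces are concatenated against size_list.
+ # e.g. natoms=7, world_size=2 -> size_list [4, 3], padded per-rank length 4.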
+ world_size = get_gp_world_size() + size_list = size_list_fn(natoms, world_size) + + input = pad_input( + input, natoms // world_size + (1 if natoms % world_size != 0 else 0) + ) + + if gloo_backend: + tensor_list_w_padding = GatherFromModelParallelRegionSumGradPaddedGLOO.apply( + input, False + ) + else: + # tensor_list_w_padding = all_gather(input, group=get_gp_group()) + tensor_list_w_padding = GatherFromModelParallelRegionSumGradPadded.apply( + input, False + ) + + return torch.cat( + [ + t.narrow(0, 0, s) if t.shape[0] != s else t + for t, s in zip(tensor_list_w_padding, size_list) + ], + dim=0, + ) + + +def gather_from_model_parallel_region_sum_grad_async( + input: torch.Tensor, natoms: int +) -> torch.Tensor: + assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!" + world_size = get_gp_world_size() + + input = pad_input( + input, natoms // world_size + (1 if natoms % world_size != 0 else 0) + ) + + tensor_list_w_padding = GatherFromModelParallelRegionSumGradPadded.apply( + input, True + ) + handle = _tls.handle + return tensor_list_w_padding, handle + + +class GatherFromModelParallelRegionSumGradPadded(torch.autograd.Function): @staticmethod - def forward(ctx, input: torch.Tensor, dim: int = -1) -> torch.Tensor: - ctx.save_for_backward(torch.tensor(dim)) - return _gather_with_padding(input, dim) + @torch.compiler.disable + def forward(ctx, input: torch.Tensor, async_op: bool) -> torch.Tensor: + ctx.rank = get_gp_rank() + ctx.group = get_gp_group() + tensor_list = [torch.empty_like(input) for _ in range(get_gp_world_size())] + if async_op: + _tls.handle = dist.all_gather( + tensor_list, input, group=ctx.group, async_op=async_op + ) + else: + dist.all_gather(tensor_list, input, group=ctx.group, async_op=async_op) + return tuple(tensor_list) @staticmethod - def backward(ctx, grad_output: torch.Tensor): - (dim,) = ctx.saved_tensors - result = _split(grad_output, dim.item()) - return result, None + @torch.compiler.disable + def backward(ctx, *grad_outputs): + local_grad_output = grad_outputs[ctx.rank] + output_tensor = torch.empty_like(local_grad_output) + return reduce_scatter(output_tensor, grad_outputs, group=ctx.group), None -class GatherFromModelParallelRegionSumGrad(torch.autograd.Function): +class GatherFromModelParallelRegionSumGradPaddedGLOO(torch.autograd.Function): @staticmethod - def forward(ctx, input: torch.Tensor, dim: int = -1) -> torch.Tensor: - ctx.save_for_backward(torch.tensor(dim)) - return _gather_with_padding(input, dim) + @torch.compiler.disable + def forward(ctx, input: torch.Tensor, async_op: bool) -> torch.Tensor: + ctx.rank = get_gp_rank() + ctx.group = get_gp_group() + ctx.shape = input.shape + tensor_list = [torch.empty_like(input) for _ in range(get_gp_world_size())] + if async_op: + _tls.handle = dist.all_gather( + tensor_list, input, group=ctx.group, async_op=async_op + ) + else: + dist.all_gather(tensor_list, input, group=ctx.group, async_op=async_op) + return tuple(tensor_list) @staticmethod - def backward(ctx, grad_output: torch.Tensor): - (dim,) = ctx.saved_tensors - group = get_gp_group() - # use dist internal # does not work - # reduced_grad_output = grad_output.clone() - # dist.all_reduce( - # reduced_grad_output, group=group - # ) # This is an inplace operation - # grad_output = reduced_grad_output - - # use functional version instead - grad_output = all_reduce(grad_output, group=group) - - result = _split(grad_output, dim.item()) + @torch.compiler.disable + def backward(ctx, 
*grad_outputs): + grad_output = all_reduce(torch.cat(grad_outputs, dim=0), group=ctx.group) + ctx.padded_size = grad_outputs[0].shape[0] + result = grad_output[ + ctx.padded_size * ctx.rank : ctx.padded_size * ctx.rank + ctx.shape[0] + ] return result, None +def scale_backward_grad(input: torch.Tensor) -> torch.Tensor: + assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!" + return ScaleBackwardGrad.apply(input) + + # Leave forward untouched but upscale the gradient by a factor of gp_group_size # DDP reduces a mean across the loss, if we have gp_group_size=2 and 6 ranks # that means we do (a_1+a_2+a_3+b_1+b_2+b_3)/6 in ddp mean. This gets us the @@ -344,45 +332,11 @@ def backward(ctx, grad_output: torch.Tensor): # avoid over head communication class ScaleBackwardGrad(torch.autograd.Function): @staticmethod + @torch.compiler.disable def forward(ctx, input: torch.Tensor) -> torch.Tensor: return input @staticmethod + @torch.compiler.disable def backward(ctx, grad_output: torch.Tensor): return dist.get_world_size(get_gp_group()) * grad_output - - -def copy_to_model_parallel_region(input: torch.Tensor) -> torch.Tensor: - assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!" - return CopyToModelParallelRegion.apply(input) - - -def reduce_from_model_parallel_region(input: torch.Tensor) -> torch.Tensor: - assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!" - return ReduceFromModelParallelRegion.apply(input) - - -def scatter_to_model_parallel_region( - input: torch.Tensor, dim: int = -1 -) -> torch.Tensor: - assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!" - return ScatterToModelParallelRegion.apply(input, dim) - - -def gather_from_model_parallel_region( - input: torch.Tensor, dim: int = -1 -) -> torch.Tensor: - assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!" - return GatherFromModelParallelRegion.apply(input, dim) - - -def gather_from_model_parallel_region_sum_grad( - input: torch.Tensor, dim: int = -1 -) -> torch.Tensor: - assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!" - return GatherFromModelParallelRegionSumGrad.apply(input, dim) - - -def scale_backward_grad(input: torch.Tensor) -> torch.Tensor: - assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!" 
- return ScaleBackwardGrad.apply(input) diff --git a/src/fairchem/core/models/uma/escn_md.py b/src/fairchem/core/models/uma/escn_md.py index 4e48d3790c..8b3a6c3273 100644 --- a/src/fairchem/core/models/uma/escn_md.py +++ b/src/fairchem/core/models/uma/escn_md.py @@ -13,6 +13,7 @@ import torch import torch.nn as nn +from torch import distributed as dist from torch.profiler import record_function from fairchem.core.common import gp_utils @@ -408,6 +409,8 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: data_dict["atomic_numbers"] = data_dict["atomic_numbers"].long() data_dict["atomic_numbers_full"] = data_dict["atomic_numbers"] data_dict["batch_full"] = data_dict["batch"] + natoms = data_dict["atomic_numbers_full"].shape[0] + gloo_backend = (not gp_utils.initialized()) or dist.get_backend() == "gloo" csd_mixed_emb = self.csd_embedding( charge=data_dict["charge"], @@ -455,6 +458,7 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: x_message[:, 0, :] = self.sphere_embedding(data_dict["atomic_numbers"]) sys_node_embedding = csd_mixed_emb[data_dict["batch"]] + sys_node_embedding_full = csd_mixed_emb[data_dict["batch_full"]] x_message[:, 0, :] = x_message[:, 0, :] + sys_node_embedding ### @@ -487,7 +491,9 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: graph_dict["edge_distance"], graph_dict["edge_index"], wigner_and_M_mapping_inv, + natoms, graph_dict["node_offset"], + gloo_backend=gloo_backend, ) ############################################################### @@ -502,8 +508,9 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: graph_dict["edge_index"], wigner_and_M_mapping, wigner_and_M_mapping_inv, - sys_node_embedding=sys_node_embedding, + sys_node_embedding=sys_node_embedding_full, node_offset=graph_dict["node_offset"], + gloo_backend=gloo_backend, ) # Final layer norm @@ -513,6 +520,7 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: "displacement": displacement, "orig_cell": orig_cell, "batch": data_dict["batch"], + "node_offset": graph_dict["node_offset"], } return out @@ -629,39 +637,49 @@ def forward( stress_key = "stress" outputs = {} - _input = emb["node_embedding"].narrow(1, 0, 1).squeeze(1) + _input = ( + emb["node_embedding"] + .narrow(1, 0, 1) + .squeeze(1) + .narrow(0, emb["node_offset"], data["batch"].shape[0]) + ) _output = self.energy_block(_input) + node_energy = _output.view(-1, 1, 1) - energy_part = torch.zeros( + + total_energies = torch.zeros( len(data["natoms"]), device=data["pos"].device, dtype=node_energy.dtype ) - energy_part.index_add_(0, data["batch"], node_energy.view(-1)) + total_energies.index_add_(0, data["batch"], node_energy.view(-1)) if gp_utils.initialized(): - energy = gp_utils.reduce_from_model_parallel_region(energy_part) - else: - energy = energy_part + # TODO optimize here for MD, we already have all embeddings can skip all_gather + total_energies = gp_utils.reduce_from_model_parallel_region(total_energies) - outputs[energy_key] = {"energy": energy} if self.wrap_property else energy + outputs[energy_key] = ( + {"energy": total_energies} if self.wrap_property else total_energies + ) - embeddings = emb["node_embedding"].detach() - if gp_utils.initialized(): - embeddings = gp_utils.gather_from_model_parallel_region(embeddings, dim=0) + if not gp_utils.initialized(): + embeddings = emb["node_embedding"].detach() - outputs["embeddings"] = ( - {"embeddings": embeddings} if self.wrap_property else embeddings - ) + outputs["embeddings"] = ( + 
{"embeddings": embeddings} if self.wrap_property else embeddings + ) if self.regress_stress: grads = torch.autograd.grad( - [energy_part.sum()], + [node_energy.sum()], [data["pos_original"], emb["displacement"]], create_graph=self.training, ) if gp_utils.initialized(): + reduced_grad = gp_utils.reduce_from_model_parallel_region( + torch.cat([grads[0].view(-1), grads[1].view(-1)]) + ).split([grads[0].numel(), grads[1].numel()]) grads = ( - gp_utils.reduce_from_model_parallel_region(grads[0]), - gp_utils.reduce_from_model_parallel_region(grads[1]), + reduced_grad[0].reshape(grads[0].shape), + reduced_grad[1].reshape(grads[1].shape), ) forces = torch.neg(grads[0]) @@ -679,7 +697,7 @@ def forward( forces = ( -1 * torch.autograd.grad( - energy_part.sum(), data["pos"], create_graph=self.training + node_energy.sum(), data["pos"], create_graph=self.training )[0] ) if gp_utils.initialized(): @@ -710,17 +728,13 @@ def forward( emb["node_embedding"].narrow(1, 0, 1).squeeze(1) ).view(-1, 1, 1) - energy_part = torch.zeros( + energy = torch.zeros( len(data_dict["natoms"]), device=node_energy.device, dtype=node_energy.dtype, ) - energy_part.index_add_(0, data_dict["batch"], node_energy.view(-1)) - if gp_utils.initialized(): - energy = gp_utils.reduce_from_model_parallel_region(energy_part) - else: - energy = energy_part + energy.index_add_(0, data_dict["batch_full"], node_energy.view(-1)) if self.reduce == "sum": return {"energy": energy} @@ -777,8 +791,6 @@ def forward(self, data_dict: AtomicData, emb: dict[str, torch.Tensor]): forces = self.linear(emb["node_embedding"].narrow(1, 0, 4)) forces = forces.narrow(1, 1, 3) forces = forces.view(-1, 3).contiguous() - if gp_utils.initialized(): - forces = gp_utils.gather_from_model_parallel_region(forces, dim=0) return {"forces": forces} diff --git a/src/fairchem/core/models/uma/escn_md_block.py b/src/fairchem/core/models/uma/escn_md_block.py index 30e6fe119c..837d6b03e5 100644 --- a/src/fairchem/core/models/uma/escn_md_block.py +++ b/src/fairchem/core/models/uma/escn_md_block.py @@ -20,6 +20,7 @@ GateActivation, SeparableS2Activation_M, ) +from fairchem.core.models.uma.nn.embedding_dev import size_list_fn from fairchem.core.models.uma.nn.layer_norm import ( get_normalization_layer, ) @@ -125,17 +126,21 @@ def forward( wigner_and_M_mapping, wigner_and_M_mapping_inv, node_offset: int = 0, + gloo_backend: bool = True, ): - # we perform the all gather upfront once during each forward call so we don't need to repeat this multiple times during activation checkpointing. 
- if gp_utils.initialized(): - x_full = gp_utils.gather_from_model_parallel_region_sum_grad(x, dim=0) - else: - x_full = x + full_natoms = x.shape[0] + local_natoms = ( + x.shape[0] + if not gp_utils.initialized() + else size_list_fn(x.shape[0], gp_utils.get_gp_world_size())[ + gp_utils.get_gp_rank() + ] + ) if self.activation_checkpoint_chunk_size is None: - return self.forward_chunk( - x_full, - x.shape[0], + x = self.forward_chunk( + x, + local_natoms, x_edge, edge_distance, edge_index, @@ -143,50 +148,61 @@ def forward( wigner_and_M_mapping_inv, node_offset, ) - edge_index_partitions = edge_index.split( - self.activation_checkpoint_chunk_size, dim=1 - ) - wigner_partitions = wigner_and_M_mapping.split( - self.activation_checkpoint_chunk_size, dim=0 - ) - wigner_inv_partitions = wigner_and_M_mapping_inv.split( - self.activation_checkpoint_chunk_size, dim=0 - ) - edge_distance_parititons = edge_distance.split( - self.activation_checkpoint_chunk_size, dim=0 - ) - x_edge_partitions = x_edge.split(self.activation_checkpoint_chunk_size, dim=0) - new_embeddings = [] - # when chunking, we need to keep track of the start index of the chunk and give this information - # to the mole layers - ac_mole_start_idx = 0 - - for idx in range(len(edge_index_partitions)): - new_embeddings.append( - torch.utils.checkpoint.checkpoint( - self.forward_chunk, - x_full, - x.shape[0], - x_edge_partitions[idx], - edge_distance_parititons[idx], - edge_index_partitions[idx], - wigner_partitions[idx], - wigner_inv_partitions[idx], - node_offset, - ac_mole_start_idx, - use_reentrant=False, - ) + else: + edge_index_partitions = edge_index.split( + self.activation_checkpoint_chunk_size, dim=1 ) - ac_mole_start_idx += edge_index_partitions[idx].shape[1] + wigner_partitions = wigner_and_M_mapping.split( + self.activation_checkpoint_chunk_size, dim=0 + ) + wigner_inv_partitions = wigner_and_M_mapping_inv.split( + self.activation_checkpoint_chunk_size, dim=0 + ) + edge_distance_parititons = edge_distance.split( + self.activation_checkpoint_chunk_size, dim=0 + ) + x_edge_partitions = x_edge.split( + self.activation_checkpoint_chunk_size, dim=0 + ) + new_embeddings = [] + # when chunking, we need to keep track of the start index of the chunk and give this information + # to the mole layers + ac_mole_start_idx = 0 + + for idx in range(len(edge_index_partitions)): + new_embeddings.append( + torch.utils.checkpoint.checkpoint( + self.forward_chunk, + x, + local_natoms, + x_edge_partitions[idx], + edge_distance_parititons[idx], + edge_index_partitions[idx], + wigner_partitions[idx], + wigner_inv_partitions[idx], + node_offset, + ac_mole_start_idx, + use_reentrant=False, + ) + ) + ac_mole_start_idx += edge_index_partitions[idx].shape[1] - if len(new_embeddings) > 8: - new_embeddings = [torch.stack(new_embeddings).sum(axis=0)] - return torch.stack(new_embeddings).sum(axis=0) + if len(new_embeddings) > 8: + new_embeddings = [torch.stack(new_embeddings).sum(axis=0)] + + x = torch.stack(new_embeddings).sum(axis=0) + + # we perform the all gather upfront once during each forward call so we don't need to repeat this multiple times during activation checkpointing. 
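+ # gather once per layer, after all chunks: each chunk only accumulates its local rows
+ # (natoms_local), and the next block receives the full node embedding again, so the
+ # collective is not repeated during activation-checkpoint recomputation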
+ if gp_utils.initialized(): + x = gp_utils.gather_from_model_parallel_region_sum_grad( + x, full_natoms, gloo_backend=gloo_backend + ) + return x def forward_chunk( self, x_full, - x_original_shape, + natoms_local, x_edge, edge_distance, edge_index, @@ -226,7 +242,7 @@ def forward_chunk( # Compute the sum of the incoming neighboring messages for each target node new_embedding = torch.zeros( - (x_original_shape,) + x_message.shape[1:], + (natoms_local,) + x_message.shape[1:], dtype=x_message.dtype, device=x_message.device, ) @@ -384,6 +400,7 @@ def forward( wigner_and_M_mapping_inv, sys_node_embedding=None, node_offset: int = 0, + gloo_backend: bool = True, ): x_res = x x = self.norm_1(x) @@ -400,6 +417,7 @@ def forward( wigner_and_M_mapping, wigner_and_M_mapping_inv, node_offset, + gloo_backend=gloo_backend, ) x = x + x_res diff --git a/src/fairchem/core/models/uma/nn/embedding_dev.py b/src/fairchem/core/models/uma/nn/embedding_dev.py index b2750d8679..fa83c08b46 100644 --- a/src/fairchem/core/models/uma/nn/embedding_dev.py +++ b/src/fairchem/core/models/uma/nn/embedding_dev.py @@ -13,9 +13,15 @@ import torch import torch.nn as nn +from fairchem.core.common import gp_utils + from .radial import PolynomialEnvelope, RadialMLP +def size_list_fn(size, parts): + return [size // parts + (1 if idx < size % parts else 0) for idx in range(parts)] + + class EdgeDegreeEmbedding(torch.nn.Module): """ @@ -118,10 +124,12 @@ def forward( edge_distance, edge_index, wigner_and_M_mapping_inv, + natoms, node_offset=0, + gloo_backend: bool = True, ): if self.activation_checkpoint_chunk_size is None: - return self.forward_chunk( + x = self.forward_chunk( x, x_edge, edge_distance, @@ -129,28 +137,35 @@ def forward( wigner_and_M_mapping_inv, node_offset, ) + else: + edge_index_partitions = edge_index.split( + self.activation_checkpoint_chunk_size, dim=1 + ) + wigner_inv_partitions = wigner_and_M_mapping_inv.split( + self.activation_checkpoint_chunk_size, dim=0 + ) + edge_distance_parititons = edge_distance.split( + self.activation_checkpoint_chunk_size, dim=0 + ) + x_edge_partitions = x_edge.split( + self.activation_checkpoint_chunk_size, dim=0 + ) - edge_index_partitions = edge_index.split( - self.activation_checkpoint_chunk_size, dim=1 - ) - wigner_inv_partitions = wigner_and_M_mapping_inv.split( - self.activation_checkpoint_chunk_size, dim=0 - ) - edge_distance_parititons = edge_distance.split( - self.activation_checkpoint_chunk_size, dim=0 - ) - x_edge_partitions = x_edge.split(self.activation_checkpoint_chunk_size, dim=0) + for idx in range(len(edge_index_partitions)): + x = torch.utils.checkpoint.checkpoint( + self.forward_chunk, + x, + x_edge_partitions[idx], + edge_distance_parititons[idx], + edge_index_partitions[idx], + wigner_inv_partitions[idx], + node_offset, + use_reentrant=False, + ) - for idx in range(len(edge_index_partitions)): - x = torch.utils.checkpoint.checkpoint( - self.forward_chunk, - x, - x_edge_partitions[idx], - edge_distance_parititons[idx], - edge_index_partitions[idx], - wigner_inv_partitions[idx], - node_offset, - use_reentrant=False, + if gp_utils.initialized(): + x = gp_utils.gather_from_model_parallel_region_sum_grad( + x, natoms, gloo_backend=gloo_backend ) return x diff --git a/tests/core/common/test_gp_utils.py b/tests/core/common/test_gp_utils.py index de4d81d1b5..64eef7be78 100644 --- a/tests/core/common/test_gp_utils.py +++ b/tests/core/common/test_gp_utils.py @@ -14,9 +14,8 @@ from fairchem.core.common import gp_utils from fairchem.core.common.gp_utils import ( - 
gather_from_model_parallel_region, - gather_from_model_parallel_region_sum_grad, scatter_to_model_parallel_region, + gather_from_model_parallel_region_sum_grad ) from fairchem.core.common.test_utils import ( PGConfig, @@ -24,6 +23,8 @@ spawn_multi_process, ) +from fairchem.core.common.gp_utils import size_list_fn + def _dummy_call(x): return x @@ -83,9 +84,9 @@ def test_scatter_tensors( assert torch.equal(out, expected_out) -def scatter_gather_fn(input: torch.Tensor, dim: int = 0): - x = scatter_to_model_parallel_region(input, dim) - return gather_from_model_parallel_region(x, dim) +def scatter_gather_fn(input: torch.Tensor): + x = scatter_to_model_parallel_region(input) + return gather_from_model_parallel_region_sum_grad(x, input.shape[0], True) @pytest.mark.parametrize( @@ -133,7 +134,7 @@ def test_gather_tensors( def scatter_bwd_test(): rank = dist.get_rank() x_full = torch.tensor([2, 3, 5, 7], requires_grad=True, dtype=torch.float) - x = scatter_to_model_parallel_region(x_full, 0) + x = scatter_to_model_parallel_region(x_full) energy_part = x.prod() ** 2 @@ -202,68 +203,12 @@ def test_scatter_bwd(): compare_and_assert_dict(expected_output[results["gp_rank"]], results) -def gather_bwd_test(rank=-1): - if rank < 0: - rank = dist.get_rank() - x = torch.tensor([rank + 2], requires_grad=True, dtype=torch.float) - x_full = gather_from_model_parallel_region(x, 0) - else: - x = torch.tensor([rank + 2], requires_grad=True, dtype=torch.float) - x_other = torch.tensor([(1 - rank) + 2], requires_grad=True, dtype=torch.float) - x_full = torch.cat([x, x_other]) if rank == 0 else torch.cat([x_other, x]) - - energy_part = (x_full.prod() + rank + 1) ** 2 - - forces_part = torch.autograd.grad( - [energy_part], - [x], - create_graph=True, - )[0] - - dforces_dinput_part = torch.autograd.grad( - [forces_part], - [x], - create_graph=True, - )[0] - - return { - "gp_rank": rank, - "energy": energy_part.detach(), - "forces": forces_part.detach(), - "dforces_dinput": dforces_dinput_part.detach(), - } - - -def test_gather_bwd(): - # A | B - # E_0 = (A*B +1)^2 , E_1 = (A*B+2)^2 - # = 49 = 64 - # dL_0/dA = 2*A*B^2+2*B = 42 - # dL_1/dB = 2*A^2*B+4*A = 32 - # dL_0/dB and dL_1/dA are not used! see test_gather_sum_bwd!! - # d^2L_1/dA^2 = 2*B^2 = 18 - # d^2L_1/dB^2 = 2*A^2 = 8 - - non_gp_results_by_gp_rank = {0: gather_bwd_test(0), 1: gather_bwd_test(1)} - - config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) - all_rank_results = spawn_multi_process( - config, - gather_bwd_test, - init_pg_and_rank_and_launch_test, - ) - - for rank_results in all_rank_results: - compare_and_assert_dict( - non_gp_results_by_gp_rank[rank_results["gp_rank"]], rank_results - ) - def gather_sum_bwd_test(rank=-1): if rank < 0: rank = dist.get_rank() x = torch.tensor([rank + 2], requires_grad=True, dtype=torch.float) - x_full = gather_from_model_parallel_region_sum_grad(x, 0) + x_full = gather_from_model_parallel_region_sum_grad(x, gp_utils.get_gp_world_size(),True) energy = (x_full.prod() + rank + 1) ** 2 # sum energy = gp_utils.reduce_from_model_parallel_region(energy) @@ -345,7 +290,7 @@ def scatter_prod_reduce(all_inputs): x_full = all_inputs.clone() - x = scatter_to_model_parallel_region(x_full, dim=0) + 0 + x = scatter_to_model_parallel_region(x_full) + 0 # BE VERY CAREFUL, inside of this context do not use any variables # in other partitions, their gradient will not propagate! 
if rank == 0: @@ -383,29 +328,13 @@ def test_scatter_prod_reduce(): output_tensor[key], expected_output[key] ).all(), f"Failed closeness check for {key}" - -def layer(x, target_rank): - rank = dist.get_rank() - - x_full = gather_from_model_parallel_region(x, 0) - x_prod = x_full.prod() - # backward graphs need to be same operation wise - # otherwise might miss a dist sync - if rank == target_rank: - x = x * 0 + x_prod - else: - x = x * 0 + x_prod * 0.0 + (rank + 1) - return x - - def embeddings_and_graph_init(atomic_numbers, edge_index): if gp_utils.initialized(): - node_partition = gp_utils.scatter_to_model_parallel_region( - torch.arange(len(atomic_numbers)).to(atomic_numbers.device) - ) - assert ( - node_partition.numel() > 0 - ), "Looks like there is no atoms in this graph paralell partition. Cannot proceed" + node_partition = torch.split( + torch.arange(atomic_numbers.shape[0]), + size_list_fn(atomic_numbers.shape[0], gp_utils.get_gp_world_size()) + )[gp_utils.get_gp_rank()] + edge_partition = torch.where( torch.logical_and( edge_index[1] >= node_partition.min(), @@ -416,32 +345,31 @@ def embeddings_and_graph_init(atomic_numbers, edge_index): graph_dict = { "node_offset": node_partition.min().item(), "edge_index": edge_index[:, edge_partition], + "natoms":atomic_numbers.shape[0] } - node_embeddings = atomic_numbers[node_partition] else: graph_dict = { "node_offset": 0, "edge_index": edge_index, + "natoms":atomic_numbers.shape[0] } - node_embeddings = atomic_numbers - return node_embeddings, graph_dict + return atomic_numbers, graph_dict # test for one rank to return a product and rest return 0 -def simple_layer(x, edge_index, node_offset, n=3): +def simple_layer(x, edge_index, node_offset, natoms, n=3): + x_source = x[edge_index[0]] + x_target = x[edge_index[1]] if gp_utils.initialized(): - x_full = gp_utils.gather_from_model_parallel_region_sum_grad(x, dim=0) - x_source = x_full[edge_index[0]] - x_target = x_full[edge_index[1]] dp_rank = gp_utils.get_dp_rank() + local_atoms=size_list_fn(natoms, gp_utils.get_gp_world_size())[gp_utils.get_gp_rank()] else: - x_source = x[edge_index[0]] - x_target = x[edge_index[1]] if dist.is_initialized(): dp_rank = dist.get_rank() else: dp_rank = 0.0 + local_atoms=x.shape[0] # make sure different ddp ranks have different outputs # similar to seeing diffferent data batches @@ -451,14 +379,19 @@ def simple_layer(x, edge_index, node_offset, n=3): edge_embeddings = (x_source + 1).pow(n) * (x_target + 1).pow(n) new_node_embedding = torch.zeros( - (x.shape[0],) + edge_embeddings.shape[1:], + (local_atoms,) + edge_embeddings.shape[1:], dtype=edge_embeddings.dtype, device=edge_embeddings.device, ) new_node_embedding.index_add_(0, edge_index[1] - node_offset, edge_embeddings) - return new_node_embedding + if gp_utils.initialized(): + return gp_utils.gather_from_model_parallel_region_sum_grad( + new_node_embedding, natoms, gloo_backend=True + ) + else: + return new_node_embedding class SimpleNet(nn.Module): @@ -471,7 +404,7 @@ def forward(self, atomic_numbers, edge_index): node_embeddings, graph_dict = embeddings_and_graph_init( atomic_numbers, edge_index ) - + all_node_embeddings = [node_embeddings] # store for debugging for _ in range(self.nlayers): all_node_embeddings.append( @@ -479,12 +412,18 @@ def forward(self, atomic_numbers, edge_index): all_node_embeddings[-1], graph_dict["edge_index"], node_offset=graph_dict["node_offset"], + natoms=graph_dict["natoms"], n=self.n, ) ) final_node_embeddings = all_node_embeddings[-1] + if gp_utils.initialized(): + 
local_atoms=size_list_fn(graph_dict['natoms'], gp_utils.get_gp_world_size())[gp_utils.get_gp_rank()] + final_node_embeddings=final_node_embeddings[graph_dict["node_offset"]:graph_dict["node_offset"]+local_atoms ] + + # only 1 system energy_part = torch.zeros( 1, device=atomic_numbers.device, dtype=atomic_numbers.dtype