diff --git a/configs/uma/benchmark/uma-speed.yaml b/configs/uma/benchmark/uma-speed.yaml index 5777fe5b23..5a7a71c0c0 100644 --- a/configs/uma/benchmark/uma-speed.yaml +++ b/configs/uma/benchmark/uma-speed.yaml @@ -13,9 +13,9 @@ runner: inference_settings: _target_: fairchem.core.units.mlip_unit.api.inference.InferenceSettings tf32: True - activation_checkpointing: True + activation_checkpointing: False merge_mole: True compile: False - wigner_cuda: True + wigner_cuda: False external_graph_gen: True internal_graph_gen_version: 2 \ No newline at end of file diff --git a/src/fairchem/core/common/gp_utils.py b/src/fairchem/core/common/gp_utils.py index 48d2226cb8..c660c8336e 100644 --- a/src/fairchem/core/common/gp_utils.py +++ b/src/fairchem/core/common/gp_utils.py @@ -11,7 +11,6 @@ import logging from typing import Any -import numpy as np import torch from torch import distributed as dist from torch.distributed.nn.functional import all_reduce @@ -183,23 +182,6 @@ def trim_tensor(tensor: torch.Tensor, sizes: torch.Tensor | None = None, dim: in return tensor, sizes -def _tensor_to_split_partitions(tensor: torch.Tensor, dim: int = -1): - group = get_gp_group() - num_parts = dist.get_world_size(group=group) - return [len(part) for part in np.array_split(np.zeros(tensor.size(dim)), num_parts)] - - -def _split_tensor( - tensor: torch.Tensor, - dim: int = -1, - contiguous_chunks: bool = False, -): - tensor_list = torch.split(tensor, _tensor_to_split_partitions(tensor, dim), dim=dim) - if contiguous_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - return tensor_list - - def _reduce(ctx: Any, input: torch.Tensor) -> torch.Tensor: group = get_gp_group() if ctx: @@ -210,55 +192,74 @@ def _reduce(ctx: Any, input: torch.Tensor) -> torch.Tensor: def _split(input: torch.Tensor, dim: int = -1) -> torch.Tensor: rank = get_gp_rank() - input_list = _split_tensor(input, dim=dim) - return input_list[rank].clone().contiguous() + group = get_gp_group() + world_size = dist.get_world_size(group=group) + sizes = [ + input.shape[dim] // world_size + + (1 if idx < input.shape[dim] % world_size else 0) + for idx in range(world_size) + ] + return torch.split(input, sizes, dim=dim)[rank] -def _gather(input: torch.Tensor, dim: int = -1) -> torch.Tensor: +def size_list_fn(natoms, world_size): + return [ + natoms // world_size + (1 if idx < natoms % world_size else 0) + for idx in range(world_size) + ] + + +def _gather_with_padding_gloo( + input: torch.Tensor, + size_list, +) -> torch.Tensor: group = get_gp_group() rank = get_gp_rank() world_size = dist.get_world_size(group=group) if world_size == 1: return input - tensor_list = [torch.empty_like(input) for _ in range(world_size)] - tensor_list[rank] = input + + # gloo does not support all_gather with different sized tensors + slice_size = max(size_list) + all_atoms = torch.zeros( + (slice_size * world_size,) + input.shape[1:], + device=input.device, + dtype=input.dtype, + ) + tensor_list = list(all_atoms.split(slice_size, dim=0)) + if input.shape[0] < slice_size: + input = pad_tensor(input, 0, slice_size) + dist.all_gather(tensor_list, input, group=group) - return torch.cat(tensor_list, dim=dim).contiguous() + tensor_list[rank] = input # pop back in our local copy (requires grad) + + tensor_list = [ + tensor.narrow(0, 0, size) for tensor, size in zip(tensor_list, size_list) + ] + return torch.cat( + tensor_list, + dim=0, + ) -def _gather_with_padding(input: torch.Tensor, dim: int = -1) -> torch.Tensor: +def _gather_with_padding(input: torch.Tensor, 
size_list) -> torch.Tensor:
     group = get_gp_group()
     rank = get_gp_rank()
     world_size = dist.get_world_size(group=group)
     if world_size == 1:
         return input
 
-    # Gather sizes
-    size_list = [
-        torch.empty(1, device=input.device, dtype=torch.long) for _ in range(world_size)
-    ]
-    size = torch.tensor([input.size(dim)], device=input.device, dtype=torch.long)
-    size_list[rank] = size
-    dist.all_gather(size_list, size, group=group)
-
-    # Gather the inputs
-    max_size = int(max([size.item() for size in size_list]))
-    input = pad_tensor(input, dim, max_size)
-    shape = list(input.shape)
-    shape[dim] = max_size
-    tensor_list = [
-        torch.empty(shape, device=input.device, dtype=input.dtype)
-        for _ in range(world_size)
-    ]
+    all_atoms = torch.zeros(
+        (sum(size_list),) + input.shape[1:], device=input.device, dtype=input.dtype
+    )
+    tensor_list = list(all_atoms.split(size_list, dim=0))
     dist.all_gather(tensor_list, input, group=group)
     tensor_list[rank] = input  # pop back in our local copy (requires grad)
 
-    # Trim and cat
-    return torch.cat(
-        [tensor.narrow(dim, 0, size) for tensor, size in zip(tensor_list, size_list)],
-        dim=dim,
-    ).contiguous()
+    node_offset = sum(size_list[:rank])
+    all_atoms[node_offset : node_offset + input.shape[0]] = input
+    return all_atoms
 
 
 class CopyToModelParallelRegion(torch.autograd.Function):
@@ -298,26 +299,26 @@ def backward(ctx, grad_output: torch.Tensor):
 class GatherFromModelParallelRegion(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, input: torch.Tensor, dim: int = -1) -> torch.Tensor:
-        ctx.save_for_backward(torch.tensor(dim))
-        return _gather_with_padding(input, dim)
+    def forward(ctx, input: torch.Tensor, size_list) -> torch.Tensor:
+        if dist.get_backend() == "gloo":
+            return _gather_with_padding_gloo(input, size_list)
+        return _gather_with_padding(input, size_list)
 
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor):
-        (dim,) = ctx.saved_tensors
-        result = _split(grad_output, dim.item())
-        return result, None
+        result = _split(grad_output, 0)
+        return result, None, None
 
 
 class GatherFromModelParallelRegionSumGrad(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, input: torch.Tensor, dim: int = -1) -> torch.Tensor:
-        ctx.save_for_backward(torch.tensor(dim))
-        return _gather_with_padding(input, dim)
+    def forward(ctx, input: torch.Tensor, size_list: list[int]) -> torch.Tensor:
+        if dist.get_backend() == "gloo":
+            return _gather_with_padding_gloo(input, size_list)
+        return _gather_with_padding(input, size_list)
 
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor):
-        (dim,) = ctx.saved_tensors
         group = get_gp_group()
         # use dist internal # does not work
         # reduced_grad_output = grad_output.clone()
         # dist.all_reduce(reduced_grad_output, group=group)
@@ -329,8 +330,8 @@ def backward(ctx, grad_output: torch.Tensor):
         # use functional version instead
         grad_output = all_reduce(grad_output, group=group)
 
-        result = _split(grad_output, dim.item())
-        return result, None
+        result = _split(grad_output, 0)
+        return result, None, None
 
 
 # Leave forward untouched but upscale the gradient by a factor of gp_group_size
@@ -369,18 +370,16 @@ def scatter_to_model_parallel_region(
     return ScatterToModelParallelRegion.apply(input, dim)
 
 
-def gather_from_model_parallel_region(
-    input: torch.Tensor, dim: int = -1
-) -> torch.Tensor:
+def gather_from_model_parallel_region(input: torch.Tensor, size_list) -> torch.Tensor:
     assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!"
-    return GatherFromModelParallelRegion.apply(input, dim)
+    return GatherFromModelParallelRegion.apply(input, size_list)
 
 
 def gather_from_model_parallel_region_sum_grad(
-    input: torch.Tensor, dim: int = -1
+    input: torch.Tensor, size_list
 ) -> torch.Tensor:
     assert initialized(), "Cannot use graph parallel with initializing gp group, must call setup_gp from gp_utils.py!"
-    return GatherFromModelParallelRegionSumGrad.apply(input, dim)
+    return GatherFromModelParallelRegionSumGrad.apply(input, size_list)
 
 
 def scale_backward_grad(input: torch.Tensor) -> torch.Tensor:
diff --git a/src/fairchem/core/models/uma/common/rotation.py b/src/fairchem/core/models/uma/common/rotation.py
index 6381cefa56..01f5577fa4 100644
--- a/src/fairchem/core/models/uma/common/rotation.py
+++ b/src/fairchem/core/models/uma/common/rotation.py
@@ -7,37 +7,49 @@
 from __future__ import annotations
 
-import logging
-
 import torch
 
 
-def init_edge_rot_euler_angles(edge_distance_vec):
-    edge_vec_0 = edge_distance_vec
-    edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0**2, dim=1))
+# TODO: this gives wrong forces in special cases!
+class Safeacos(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        ctx.save_for_backward(x)
+        return torch.acos(x)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (x,) = ctx.saved_tensors
+        norms = x.pow(2)
+        grad_input = -grad_output / torch.sqrt(1 - norms)
+        return torch.where(grad_input.isfinite(), grad_input, 0.0)
 
-    # Make sure the atoms are far enough apart
-    # assert torch.min(edge_vec_0_distance) < 0.0001
-    if len(edge_vec_0_distance) > 0 and torch.min(edge_vec_0_distance) < 0.0001:
-        logging.error(f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}")
 
-    # make unit vectors
-    xyz = edge_vec_0 / (edge_vec_0_distance.view(-1, 1))
+# TODO: this gives wrong forces in special cases!
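+# Gradients: d/dy atan2(y, x) = x / (x**2 + y**2), d/dx = -y / (x**2 + y**2);
+# the norm is clamped at x = y = 0 so the backward returns 0 there instead of NaN.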
+class Safeatan2(torch.autograd.Function): + @staticmethod + def forward(ctx, y, x): + ctx.save_for_backward(y, x) + return torch.atan2(y, x) - # are we standing at the north pole - mask = xyz[:, 1].abs().isclose(xyz.new_ones(1)) + @staticmethod + def backward(ctx, grad_output): + y, x = ctx.saved_tensors + norms = x.pow(2) + y.pow(2) + safe_norms = torch.where(norms == 0.0, 1, norms) + return (x / safe_norms) * grad_output, -(y / safe_norms) * grad_output - # compute alpha and beta + +def init_edge_rot_euler_angles(edge_distance_vec): + xyz = torch.nn.functional.normalize(edge_distance_vec) # latitude (beta) - beta = xyz.new_zeros(xyz.shape[0]) - beta[~mask] = torch.acos(xyz[~mask, 1]) - beta[mask] = torch.acos(xyz[mask, 1]).detach() + beta = Safeacos.apply(xyz[:, 1]) # longitude (alpha) - alpha = torch.zeros_like(beta) - alpha[~mask] = torch.atan2(xyz[~mask, 0], xyz[~mask, 2]) - alpha[mask] = torch.atan2(xyz[mask, 0], xyz[mask, 2]).detach() + alpha = Safeatan2.apply(xyz[:, 0], xyz[:, 2]) # random gamma (roll) gamma = torch.rand_like(alpha) * 2 * torch.pi diff --git a/src/fairchem/core/models/uma/escn_md.py b/src/fairchem/core/models/uma/escn_md.py index a46138bb13..4a9821b3bd 100644 --- a/src/fairchem/core/models/uma/escn_md.py +++ b/src/fairchem/core/models/uma/escn_md.py @@ -40,7 +40,7 @@ get_normalization_layer, ) from fairchem.core.models.uma.nn.mole_utils import MOLEInterface -from fairchem.core.models.uma.nn.radial import GaussianSmearing +from fairchem.core.models.uma.nn.radial import GaussianSmearing, PolynomialEnvelope from fairchem.core.models.uma.nn.so3_layers import SO3_Linear from fairchem.core.models.utils.irreps import cg_change_mat, irreps_sum @@ -237,7 +237,6 @@ def __init__( self.mappingReduced, self.SO3_grid, self.edge_channels_list, - self.cutoff, self.norm_type, self.act_type, self.ff_type, @@ -361,9 +360,9 @@ def _generate_graph(self, data_dict): "pbc" in data_dict ), "Since always_use_pbc is False, pbc conditions must be supplied by the input data" pbc = data_dict["pbc"] - assert ( - pbc.all() or (~pbc).all() - ), "We can only accept pbc that is all true or all false" + assert ( + pbc.all() or (~pbc).all() + ), "We can only accept pbc that is all true or all false" logging.debug(f"Using radius graph gen version {self.radius_pbc_version}") graph_dict = generate_graph( data_dict, @@ -498,6 +497,10 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: ) self.log_MOLE_stats() + self.envelope = PolynomialEnvelope(exponent=5) + dist_scaled = graph_dict["edge_distance"] / self.cutoff + edge_envelope = self.envelope(dist_scaled).reshape(-1, 1, 1) + # edge degree embedding with record_function("edge embedding"): edge_distance_embedding = self.distance_expansion( @@ -515,9 +518,10 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: x_message = self.edge_degree_embedding( x_message, x_edge, - graph_dict["edge_distance"], + edge_envelope, graph_dict["edge_index"], wigner_and_M_mapping_inv, + data_dict["atomic_numbers_full"].shape[0], graph_dict["node_offset"], ) @@ -529,10 +533,11 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: x_message = self.blocks[i]( x_message, x_edge, - graph_dict["edge_distance"], + edge_envelope, graph_dict["edge_index"], wigner_and_M_mapping, wigner_and_M_mapping_inv, + data_dict["atomic_numbers_full"].shape[0], sys_node_embedding=sys_node_embedding, node_offset=graph_dict["node_offset"], ) @@ -544,6 +549,7 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: 
"displacement": displacement, "orig_cell": orig_cell, "batch": data_dict["batch"], + "node_offset": graph_dict["node_offset"], } return out @@ -556,14 +562,16 @@ def _init_gp_partitions(self, graph_dict, atomic_numbers_full): edge_distance = graph_dict["edge_distance"] edge_distance_vec_full = graph_dict["edge_distance_vec"] + world_size = gp_utils.get_gp_world_size() + assert ( + atomic_numbers_full.shape[0] >= world_size + ), "Looks like there is no atoms in this graph paralell partition. Cannot proceed" + node_partition = torch.tensor_split( - torch.arange(len(atomic_numbers_full)).to(atomic_numbers_full.device), - gp_utils.get_gp_world_size(), + torch.arange(atomic_numbers_full.shape[0]).to(atomic_numbers_full.device), + world_size, )[gp_utils.get_gp_rank()] - assert ( - node_partition.numel() > 0 - ), "Looks like there is no atoms in this graph paralell partition. Cannot proceed" edge_partition = torch.where( torch.logical_and( edge_index[1] >= node_partition.min(), @@ -682,13 +690,12 @@ def forward( outputs[energy_key] = {"energy": energy} if self.wrap_property else energy - embeddings = emb["node_embedding"].detach() - if gp_utils.initialized(): - embeddings = gp_utils.gather_from_model_parallel_region(embeddings, dim=0) + if not gp_utils.initialized(): + embeddings = emb["node_embedding"].detach() - outputs["embeddings"] = ( - {"embeddings": embeddings} if self.wrap_property else embeddings - ) + outputs["embeddings"] = ( + {"embeddings": embeddings} if self.wrap_property else embeddings + ) if self.regress_stress: grads = torch.autograd.grad( @@ -816,7 +823,10 @@ def forward(self, data_dict: AtomicData, emb: dict[str, torch.Tensor]): forces = forces.narrow(1, 1, 3) forces = forces.view(-1, 3).contiguous() if gp_utils.initialized(): - forces = gp_utils.gather_from_model_parallel_region(forces, dim=0) + size_list = gp_utils.size_list_fn( + data_dict["atomic_numbers_full"].shape[0], gp_utils.get_gp_world_size() + ) + forces = gp_utils.gather_from_model_parallel_region(forces, size_list) return {"forces": forces} diff --git a/src/fairchem/core/models/uma/escn_md_block.py b/src/fairchem/core/models/uma/escn_md_block.py index 406d5826f7..9a0a4655fd 100644 --- a/src/fairchem/core/models/uma/escn_md_block.py +++ b/src/fairchem/core/models/uma/escn_md_block.py @@ -24,7 +24,6 @@ get_normalization_layer, ) from fairchem.core.models.uma.nn.mole import MOLE -from fairchem.core.models.uma.nn.radial import PolynomialEnvelope from fairchem.core.models.uma.nn.so2_layers import SO2_Convolution from fairchem.core.models.uma.nn.so3_layers import SO3_Linear @@ -48,7 +47,6 @@ def __init__( edge_channels_list: list[int], mappingReduced: CoefficientMapping, SO3_grid: SO3_Grid, - cutoff: float, # Enables activation checkpointing of edges in # activation_checkpoint_chunk_size size edge blocks activation_checkpoint_chunk_size: int | None, @@ -109,9 +107,6 @@ def __init__( extra_m0_output_channels=None, ) - self.cutoff = cutoff - self.envelope = PolynomialEnvelope(exponent=5) - self.out_mask = self.SO3_grid["lmax_lmax"].mapping.coefficient_idx( self.lmax, self.mmax ) @@ -120,20 +115,22 @@ def forward( self, x, x_edge, - edge_distance, + edge_envelope, edge_index, wigner_and_M_mapping, wigner_and_M_mapping_inv, + natoms, node_offset: int = 0, ): if self.activation_checkpoint_chunk_size is None: return self.forward_chunk( x, x_edge, - edge_distance, + edge_envelope, edge_index, wigner_and_M_mapping, wigner_and_M_mapping_inv, + natoms, node_offset, ) edge_index_partitions = edge_index.split( @@ -145,7 +142,7 @@ 
def forward(
         wigner_inv_partitions = wigner_and_M_mapping_inv.split(
             self.activation_checkpoint_chunk_size, dim=0
         )
-        edge_distance_parititons = edge_distance.split(
+        edge_envelope_partitions = edge_envelope.split(
             self.activation_checkpoint_chunk_size, dim=0
         )
         x_edge_partitions = x_edge.split(self.activation_checkpoint_chunk_size, dim=0)
@@ -159,10 +156,11 @@ def forward(
                 self.forward_chunk,
                 x,
                 x_edge_partitions[idx],
-                edge_distance_parititons[idx],
+                edge_envelope_partitions[idx],
                 edge_index_partitions[idx],
                 wigner_partitions[idx],
                 wigner_inv_partitions[idx],
+                natoms,
                 node_offset,
                 ac_mole_start_idx,
                 use_reentrant=False,
@@ -178,10 +176,11 @@ def forward_chunk(
         self,
         x,
         x_edge,
-        edge_distance,
+        edge_envelope,
         edge_index,
         wigner_and_M_mapping,
         wigner_and_M_mapping_inv,
+        natoms,
         node_offset: int = 0,
         ac_mole_start_idx: int = 0,
     ):
@@ -190,7 +189,8 @@ def forward_chunk(
             set_mole_ac_start_index(self, ac_mole_start_idx)
 
         if gp_utils.initialized():
-            x_full = gp_utils.gather_from_model_parallel_region_sum_grad(x, dim=0)
+            size_list = gp_utils.size_list_fn(natoms, gp_utils.get_gp_world_size())
+            x_full = gp_utils.gather_from_model_parallel_region_sum_grad(x, size_list)
             x_source = x_full[edge_index[0]]
             x_target = x_full[edge_index[1]]
         else:
@@ -212,9 +212,7 @@ def forward_chunk(
         x_message = self.so2_conv_2(x_message, x_edge)
 
         # envelope
-        dist_scaled = edge_distance / self.cutoff
-        env = self.envelope(dist_scaled)
-        x_message = x_message * env.view(-1, 1, 1)
+        x_message = x_message * edge_envelope
 
         # Rotate back the irreps
         x_message = torch.bmm(wigner_and_M_mapping_inv, x_message)
@@ -319,7 +317,6 @@ def __init__(
         mappingReduced: CoefficientMapping,
         SO3_grid: SO3_Grid,
         edge_channels_list: list[int],
-        cutoff: float,
         norm_type: Literal["layer_norm", "layer_norm_sh", "rms_norm_sh"],
         act_type: Literal["gate", "s2"],
         ff_type: Literal["spectral", "grid"],
@@ -343,7 +340,6 @@ def __init__(
             edge_channels_list=edge_channels_list,
             mappingReduced=mappingReduced,
             SO3_grid=SO3_grid,
-            cutoff=cutoff,
             act_type=act_type,
             activation_checkpoint_chunk_size=activation_checkpoint_chunk_size,
         )
@@ -377,6 +373,7 @@ def forward(
         edge_index,
         wigner_and_M_mapping,
         wigner_and_M_mapping_inv,
+        natoms,
         sys_node_embedding=None,
         node_offset: int = 0,
     ):
@@ -394,6 +391,7 @@ def forward(
             edge_index,
             wigner_and_M_mapping,
             wigner_and_M_mapping_inv,
+            natoms,
             node_offset,
         )
         x = x + x_res
diff --git a/src/fairchem/core/models/uma/nn/embedding_dev.py b/src/fairchem/core/models/uma/nn/embedding_dev.py
index b2750d8679..aa19923e89 100644
--- a/src/fairchem/core/models/uma/nn/embedding_dev.py
+++ b/src/fairchem/core/models/uma/nn/embedding_dev.py
@@ -78,31 +78,68 @@ def forward_chunk(
         self,
         x,
         x_edge,
-        edge_distance,
+        edge_envelope,
         edge_index,
         wigner_and_M_mapping_inv,
+        natoms,
         node_offset=0,
     ):
         x_edge_m_0 = self.rad_func(x_edge)
-        x_edge_m_0 = x_edge_m_0.reshape(
+        x_edge_m_0 = x_edge_m_0.view(
             -1, self.m_0_num_coefficients, self.sphere_channels
         )
-        x_edge_m_pad = torch.zeros(
-            (
-                x_edge_m_0.shape[0],
-                (self.m_all_num_coefficents - self.m_0_num_coefficients),
-                self.sphere_channels,
-            ),
+
+        x_edge_embedding = torch.zeros(
+            x_edge_m_0.shape[0],
+            x_edge_m_0.shape[1]
+            + self.m_all_num_coefficents
+            - self.m_0_num_coefficients,
+            x_edge_m_0.shape[2],
+            device=x_edge_m_0.device,
+            dtype=x_edge_m_0.dtype,
+        )
+        x_edge_embedding[:, : x_edge_m_0.shape[1]] = x_edge_m_0
+
+        x_edge_embedding = torch.bmm(wigner_and_M_mapping_inv, x_edge_embedding)
+
+        x_edge_embedding = x_edge_embedding * edge_envelope
+
+        # TODO is this needed?
+        x_edge_embedding = x_edge_embedding.to(x.dtype)
+
+        return x.index_add(
+            0, edge_index[1] - node_offset, x_edge_embedding / self.rescale_factor
+        )
+
+    def forward_chunk_gp(
+        self,
+        x,
+        x_edge,
+        edge_envelope,
+        edge_index,
+        wigner_and_M_mapping_inv,
+        natoms,
+        node_offset=0,
+    ):
+        x_edge_m_0 = self.rad_func(x_edge)
+        x_edge_m_0 = x_edge_m_0.view(
+            -1, self.m_0_num_coefficients, self.sphere_channels
+        )
+
+        x_edge_embedding = torch.zeros(
+            x_edge_m_0.shape[0],
+            x_edge_m_0.shape[1]
+            + self.m_all_num_coefficents
+            - self.m_0_num_coefficients,
+            x_edge_m_0.shape[2],
             device=x_edge_m_0.device,
             dtype=x_edge_m_0.dtype,
         )
-        x_edge_embedding = torch.cat((x_edge_m_0, x_edge_m_pad), dim=1)
+        x_edge_embedding[:, : x_edge_m_0.shape[1]] = x_edge_m_0
+
         x_edge_embedding = torch.bmm(wigner_and_M_mapping_inv, x_edge_embedding)
 
-        # envelope
-        dist_scaled = edge_distance / self.cutoff
-        env = self.envelope(dist_scaled)
-        x_edge_embedding = x_edge_embedding * env.view(-1, 1, 1)
+        x_edge_embedding = x_edge_embedding * edge_envelope
 
         # TODO is this needed?
         x_edge_embedding = x_edge_embedding.to(x.dtype)
@@ -115,18 +152,40 @@ def forward(
         self,
         x,
         x_edge,
-        edge_distance,
+        edge_envelope,
         edge_index,
         wigner_and_M_mapping_inv,
+        natoms,
         node_offset=0,
     ):
         if self.activation_checkpoint_chunk_size is None:
+            # if True or gp_utils.initialized():
+            #     group = get_gp_group()
+            #     rank = get_gp_rank()
+            #     world_size = dist.get_world_size(group=group)
+            #     size_list=size_list_fn(natoms, world_size)
+
+            #     n_chunks=2
+            #     for chunk_idx in range(n_chunks):
+            #         node_offset = (natoms // n_chunks) * chunk_idx
+
+            #         self.forward_chunk_gp(
+            #             x,
+            #             x_edge,
+            #             edge_envelope,
+            #             edge_index,
+            #             wigner_and_M_mapping_inv,
+            #             natoms,
+            #             node_offset,
+            #         )
+
             return self.forward_chunk(
                 x,
                 x_edge,
-                edge_distance,
+                edge_envelope,
                 edge_index,
                 wigner_and_M_mapping_inv,
+                natoms,
                 node_offset,
             )
 
@@ -136,7 +195,7 @@ def forward(
         wigner_inv_partitions = wigner_and_M_mapping_inv.split(
             self.activation_checkpoint_chunk_size, dim=0
         )
-        edge_distance_parititons = edge_distance.split(
+        edge_envelope_partitions = edge_envelope.split(
             self.activation_checkpoint_chunk_size, dim=0
         )
         x_edge_partitions = x_edge.split(self.activation_checkpoint_chunk_size, dim=0)
@@ -146,9 +205,10 @@ def forward(
                 self.forward_chunk,
                 x,
                 x_edge_partitions[idx],
-                edge_distance_parititons[idx],
+                edge_envelope_partitions[idx],
                 edge_index_partitions[idx],
                 wigner_inv_partitions[idx],
+                natoms,
                 node_offset,
                 use_reentrant=False,
             )
diff --git a/src/fairchem/core/models/uma/nn/so2_layers.py b/src/fairchem/core/models/uma/nn/so2_layers.py
index 677fa6e0c3..a0e88a4c73 100644
--- a/src/fairchem/core/models/uma/nn/so2_layers.py
+++ b/src/fairchem/core/models/uma/nn/so2_layers.py
@@ -65,9 +65,9 @@ def __init__(
 
     def forward(self, x_m: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         x_m = self.fc(x_m)
-        x_r, x_i = x_m.split(self.out_channels_half, dim=2)
-        x_r_0, x_r_1 = x_r.split(1, dim=1)
-        x_i_0, x_i_1 = x_i.split(1, dim=1)
+        x_r_0, x_i_0, x_r_1, x_i_1 = x_m.reshape(
+            x_m.shape[0], -1, self.out_channels_half
+        ).split(1, dim=1)
         x_m_r = x_r_0 - x_i_1  # x_r[:, 0] - x_i[:, 1]
         x_m_i = x_r_1 + x_i_0  # x_r[:, 1] + x_i[:, 0]
         return (
diff --git a/tests/core/common/test_gp_utils.py b/tests/core/common/test_gp_utils.py
index de4d81d1b5..cbe4287c87 100644
--- a/tests/core/common/test_gp_utils.py
+++ b/tests/core/common/test_gp_utils.py
@@ -85,7 +85,9 @@ def test_scatter_tensors(
 def scatter_gather_fn(input: torch.Tensor, dim: int = 0):
     x = 
scatter_to_model_parallel_region(input, dim)
-    return gather_from_model_parallel_region(x, dim)
+    # assumes size_list_fn and get_gp_world_size are imported from gp_utils
+    size_list = size_list_fn(input.shape[dim], get_gp_world_size())
+    return gather_from_model_parallel_region(x, size_list)
 
 
 @pytest.mark.parametrize(