configs/uma/benchmark/uma-speed.yaml (4 changes: 2 additions & 2 deletions)
@@ -13,9 +13,9 @@ runner:
   inference_settings:
     _target_: fairchem.core.units.mlip_unit.api.inference.InferenceSettings
     tf32: True
-    activation_checkpointing: True
+    activation_checkpointing: False
     merge_mole: True
     compile: False
-    wigner_cuda: True
+    wigner_cuda: False
     external_graph_gen: True
     internal_graph_gen_version: 2
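
For reference, the two toggles flipped here turn off activation checkpointing and the CUDA Wigner-matrix path for this speed benchmark. Below is a minimal sketch of building the same settings directly in Python instead of through Hydra; it assumes InferenceSettings accepts the YAML fields as keyword arguments, which is what the _target_ entry implies.

# Minimal sketch: construct the benchmark settings without Hydra.
# Assumes InferenceSettings takes the YAML fields as keyword arguments
# (implied by the _target_ entry above).
from fairchem.core.units.mlip_unit.api.inference import InferenceSettings

settings = InferenceSettings(
    tf32=True,
    activation_checkpointing=False,  # flipped off in this change
    merge_mole=True,
    compile=False,
    wigner_cuda=False,  # flipped off in this change
    external_graph_gen=True,
    internal_graph_gen_version=2,
)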
src/fairchem/core/common/gp_utils.py (129 changes: 64 additions & 65 deletions)
@@ -11,7 +11,6 @@
 import logging
 from typing import Any

-import numpy as np
 import torch
 from torch import distributed as dist
 from torch.distributed.nn.functional import all_reduce
@@ -183,23 +182,6 @@ def trim_tensor(tensor: torch.Tensor, sizes: torch.Tensor | None = None, dim: int
     return tensor, sizes


-def _tensor_to_split_partitions(tensor: torch.Tensor, dim: int = -1):
-    group = get_gp_group()
-    num_parts = dist.get_world_size(group=group)
-    return [len(part) for part in np.array_split(np.zeros(tensor.size(dim)), num_parts)]
-
-
-def _split_tensor(
-    tensor: torch.Tensor,
-    dim: int = -1,
-    contiguous_chunks: bool = False,
-):
-    tensor_list = torch.split(tensor, _tensor_to_split_partitions(tensor, dim), dim=dim)
-    if contiguous_chunks:
-        return tuple(chunk.contiguous() for chunk in tensor_list)
-    return tensor_list
-
-
 def _reduce(ctx: Any, input: torch.Tensor) -> torch.Tensor:
     group = get_gp_group()
     if ctx:
@@ -210,55 +192,74 @@ def _reduce(ctx: Any, input: torch.Tensor) -> torch.Tensor:

 def _split(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
     rank = get_gp_rank()
-    input_list = _split_tensor(input, dim=dim)
-    return input_list[rank].clone().contiguous()
+    group = get_gp_group()
+    world_size = dist.get_world_size(group=group)
+    sizes = [
+        input.shape[dim] // world_size
+        + (1 if idx < input.shape[dim] % world_size else 0)
+        for idx in range(world_size)
+    ]
+    return torch.split(input, sizes, dim=dim)[rank]


-def _gather(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
+def size_list_fn(natoms, world_size):
+    return [
+        natoms // world_size + (1 if idx < natoms % world_size else 0)
+        for idx in range(world_size)
+    ]
+
+
+def _gather_with_padding_gloo(
+    input: torch.Tensor,
+    size_list,
+) -> torch.Tensor:
     group = get_gp_group()
     rank = get_gp_rank()
     world_size = dist.get_world_size(group=group)
     if world_size == 1:
         return input
-    tensor_list = [torch.empty_like(input) for _ in range(world_size)]
-    tensor_list[rank] = input
+
+    # gloo does not support all_gather with different sized tensors
+    slice_size = max(size_list)
+    all_atoms = torch.zeros(
+        (slice_size * world_size,) + input.shape[1:],
+        device=input.device,
+        dtype=input.dtype,
+    )
+    tensor_list = list(all_atoms.split(slice_size, dim=0))
+    if input.shape[0] < slice_size:
+        input = pad_tensor(input, 0, slice_size)

     dist.all_gather(tensor_list, input, group=group)
-    return torch.cat(tensor_list, dim=dim).contiguous()
+    tensor_list[rank] = input  # pop back in our local copy (requires grad)
+
+    tensor_list = [
+        tensor.narrow(0, 0, size) for tensor, size in zip(tensor_list, size_list)
+    ]
+    return torch.cat(
+        tensor_list,
+        dim=0,
+    )


-def _gather_with_padding(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
+def _gather_with_padding(input: torch.Tensor, size_list) -> torch.Tensor:
     group = get_gp_group()
     rank = get_gp_rank()
     world_size = dist.get_world_size(group=group)
     if world_size == 1:
         return input

-    # Gather sizes
-    size_list = [
-        torch.empty(1, device=input.device, dtype=torch.long) for _ in range(world_size)
-    ]
-    size = torch.tensor([input.size(dim)], device=input.device, dtype=torch.long)
-    size_list[rank] = size
-    dist.all_gather(size_list, size, group=group)
-
-    # Gather the inputs
-    max_size = int(max([size.item() for size in size_list]))
-    input = pad_tensor(input, dim, max_size)
-    shape = list(input.shape)
-    shape[dim] = max_size
-    tensor_list = [
-        torch.empty(shape, device=input.device, dtype=input.dtype)
-        for _ in range(world_size)
-    ]
+    all_atoms = torch.zeros(
+        (sum(size_list),) + input.shape[1:], device=input.device, dtype=input.dtype
+    )
+    tensor_list = list(all_atoms.split(size_list, dim=0))

     dist.all_gather(tensor_list, input, group=group)
     tensor_list[rank] = input  # pop back in our local copy (requires grad)

-    # Trim and cat
-    return torch.cat(
-        [tensor.narrow(dim, 0, size) for tensor, size in zip(tensor_list, size_list)],
-        dim=dim,
-    ).contiguous()
+    node_offset = sum(size_list[:rank])
+    all_atoms[node_offset : node_offset + input.shape[0]] = input
+    return all_atoms


 class CopyToModelParallelRegion(torch.autograd.Function):
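
A quick worked example of the partition arithmetic above (pure Python, no process group needed): size_list_fn gives the remainder atoms to the lowest ranks, sum(size_list) recovers the full dimension that _split partitions, and max(size_list) is the padded slice size used on the gloo path.

# Worked example of the split/gather partition arithmetic (no distributed init).
def size_list_fn(natoms, world_size):
    return [
        natoms // world_size + (1 if idx < natoms % world_size else 0)
        for idx in range(world_size)
    ]

sizes = size_list_fn(10, 4)
print(sizes)       # [3, 3, 2, 2]: the remainder goes to the first two ranks
print(sum(sizes))  # 10, so torch.split(input, sizes) covers the whole tensor
print(max(sizes))  # 3, the common slice size every rank pads to under gloo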
@@ -298,26 +299,26 @@ def backward(ctx, grad_output: torch.Tensor):

 class GatherFromModelParallelRegion(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, input: torch.Tensor, dim: int = -1) -> torch.Tensor:
-        ctx.save_for_backward(torch.tensor(dim))
-        return _gather_with_padding(input, dim)
+    def forward(ctx, input: torch.Tensor, size_list) -> torch.Tensor:
+        if dist.get_backend() == "gloo":
+            return _gather_with_padding_gloo(input, size_list)
+        return _gather_with_padding(input, size_list)

     @staticmethod
     def backward(ctx, grad_output: torch.Tensor):
-        (dim,) = ctx.saved_tensors
-        result = _split(grad_output, dim.item())
-        return result, None
+        result = _split(grad_output, 0)
+        return result, None, None


 class GatherFromModelParallelRegionSumGrad(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, input: torch.Tensor, dim: int = -1) -> torch.Tensor:
-        ctx.save_for_backward(torch.tensor(dim))
-        return _gather_with_padding(input, dim)
+    def forward(ctx, input: torch.Tensor, size_list) -> torch.Tensor:
+        if dist.get_backend() == "gloo":
+            return _gather_with_padding_gloo(input, size_list)
+        return _gather_with_padding(input, size_list)

     @staticmethod
     def backward(ctx, grad_output: torch.Tensor):
-        (dim,) = ctx.saved_tensors
         group = get_gp_group()
         # use dist internal # does not work
         # reduced_grad_output = grad_output.clone()
@@ -329,8 +330,8 @@ def backward(ctx, grad_output: torch.Tensor):
         # use functional version instead
         grad_output = all_reduce(grad_output, group=group)

-        result = _split(grad_output, dim.item())
-        return result, None
+        result = _split(grad_output, 0)
+        return result, None, None


 # Leave forward untouched but upscale the gradient by a factor of gp_group_size
@@ -369,18 +370,16 @@ def scatter_to_model_parallel_region(
     return ScatterToModelParallelRegion.apply(input, dim)


-def gather_from_model_parallel_region(
-    input: torch.Tensor, dim: int = -1
-) -> torch.Tensor:
+def gather_from_model_parallel_region(input: torch.Tensor, size_list) -> torch.Tensor:
     assert initialized(), "Cannot use graph parallel without initializing gp group, must call setup_gp from gp_utils.py!"
-    return GatherFromModelParallelRegion.apply(input, dim)
+    return GatherFromModelParallelRegion.apply(input, size_list)


 def gather_from_model_parallel_region_sum_grad(
-    input: torch.Tensor, dim: int = -1
+    input: torch.Tensor, size_list
 ) -> torch.Tensor:
     assert initialized(), "Cannot use graph parallel without initializing gp group, must call setup_gp from gp_utils.py!"
-    return GatherFromModelParallelRegionSumGrad.apply(input, dim)
+    return GatherFromModelParallelRegionSumGrad.apply(input, size_list)
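
With this change, callers no longer pass a dim: they precompute the per-rank size list from the global atom count and hand it to every gather. A hypothetical call site follows (a sketch only; it assumes the gp group was already initialized via setup_gp and that features are split along dim 0, as the backward's _split(grad_output, 0) requires).

# Hypothetical call site: gather per-atom features across the gp group.
# Assumes setup_gp(...) has already initialized the group elsewhere.
import torch
from torch import distributed as dist
from fairchem.core.common.gp_utils import (
    get_gp_group,
    size_list_fn,
    gather_from_model_parallel_region,
)

def gather_node_features(local_feats: torch.Tensor, total_natoms: int) -> torch.Tensor:
    world_size = dist.get_world_size(group=get_gp_group())
    size_list = size_list_fn(total_natoms, world_size)  # must match the original split
    return gather_from_model_parallel_region(local_feats, size_list)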


 def scale_backward_grad(input: torch.Tensor) -> torch.Tensor:
src/fairchem/core/models/uma/common/rotation.py (50 changes: 30 additions & 20 deletions)
@@ -7,37 +7,47 @@

 from __future__ import annotations

-import logging
-
 import torch


-def init_edge_rot_euler_angles(edge_distance_vec):
-    edge_vec_0 = edge_distance_vec
-    edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0**2, dim=1))
-
-    # Make sure the atoms are far enough apart
-    # assert torch.min(edge_vec_0_distance) < 0.0001
-    if len(edge_vec_0_distance) > 0 and torch.min(edge_vec_0_distance) < 0.0001:
-        logging.error(f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}")
-
-    # make unit vectors
-    xyz = edge_vec_0 / (edge_vec_0_distance.view(-1, 1))
-
-    # are we standing at the north pole
-    mask = xyz[:, 1].abs().isclose(xyz.new_ones(1))
-
-    # compute alpha and beta
+# TODO: this gives wrong forces in special cases!
+class Safeacos(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        ctx.save_for_backward(x)
+        return torch.acos(x)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (x,) = ctx.saved_tensors
+        norms = x.pow(2)
+        grad_input = -grad_output / torch.sqrt(1 - norms)
+        return torch.where(grad_input.isfinite(), grad_input, 0.0)
+
+
+# TODO: this gives wrong forces in special cases!
+class Safeatan2(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, y, x):
+        ctx.save_for_backward(y, x)
+        return torch.atan2(y, x)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        y, x = ctx.saved_tensors
+        norms = x.pow(2) + y.pow(2)
+        safe_norms = torch.where(norms == 0.0, 1, norms)
+        return (x / safe_norms) * grad_output, -(y / safe_norms) * grad_output
+
+
+def init_edge_rot_euler_angles(edge_distance_vec):
+    xyz = torch.nn.functional.normalize(edge_distance_vec)

     # latitude (beta)
-    beta = xyz.new_zeros(xyz.shape[0])
-    beta[~mask] = torch.acos(xyz[~mask, 1])
-    beta[mask] = torch.acos(xyz[mask, 1]).detach()
+    beta = Safeacos.apply(xyz[:, 1])

     # longitude (alpha)
-    alpha = torch.zeros_like(beta)
-    alpha[~mask] = torch.atan2(xyz[~mask, 0], xyz[~mask, 2])
-    alpha[mask] = torch.atan2(xyz[mask, 0], xyz[mask, 2]).detach()
+    alpha = Safeatan2.apply(xyz[:, 0], xyz[:, 2])

     # random gamma (roll)
     gamma = torch.rand_like(alpha) * 2 * torch.pi
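
The reason for the two custom autograd Functions: the analytic derivatives d/dx acos(x) = -1/sqrt(1 - x^2) and datan2(y, x) = (x dy - y dx)/(x^2 + y^2) diverge when an edge points at a pole (x = +/-1) or the planar norm vanishes, exactly the cases the old mask-and-detach code special-cased. The safe versions return a zero gradient there instead of NaN/inf. A minimal sanity check at the pole (assuming the Safeacos class above is in scope):

# Minimal check of the safe gradient at the north pole (y component = 1).
import torch

y = torch.tensor([0.0, 1.0], requires_grad=True)
beta = Safeacos.apply(y)
beta.sum().backward()
# d/dy acos(y) = -1/sqrt(1 - y^2): -1.0 at y=0, non-finite at y=1 so masked to 0
print(y.grad)  # tensor([-1., 0.])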