330 changes: 142 additions & 188 deletions src/fairchem/core/common/gp_utils.py
@@ -9,12 +9,11 @@

import contextlib
import logging
from typing import Any
import threading

import numpy as np
import torch
from torch import distributed as dist
from torch.distributed.nn.functional import all_reduce
from torch.distributed.nn.functional import all_reduce, reduce_scatter

"""
Functions to support graph parallel training.
@@ -27,6 +26,32 @@
_GRAPH_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP = None

_tls = threading.local()


def pad_input(input: torch.Tensor, padded_size: int):
    # Zero-pad the leading dimension up to padded_size by concatenating a
    # zero block (equivalent to torch.nn.functional.pad, but explicit).
    if input.shape[0] != padded_size:
        input = torch.cat(
[
input,
torch.zeros(
(padded_size - input.shape[0], *input.shape[1:]),
device=input.device,
dtype=input.dtype,
),
],
dim=0,
)

assert input.shape[0] == padded_size

return input
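

# Usage sketch (added for illustration; not part of the original file):
# pad_input only touches the leading dimension and zero-fills the new rows.
def _example_pad_input() -> None:
    x = torch.ones(3, 2)
    padded = pad_input(x, padded_size=5)
    assert padded.shape == (5, 2)
    assert torch.all(padded[3:] == 0)  # the appended rows are zeros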


def ensure_div(a: int, b: int) -> None:
assert a % b == 0
@@ -149,190 +174,153 @@ def get_gp_world_size() -> int:
########## DIST METHODS ##########


def pad_tensor(
tensor: torch.Tensor, dim: int = -1, target_size: int | None = None
) -> torch.Tensor:
size = tensor.size(dim)
if target_size is None:
world_size = get_gp_world_size()
pad_size = 0 if size % world_size == 0 else world_size - size % world_size
else:
pad_size = target_size - size
if pad_size == 0:
return tensor
pad_shape = list(tensor.shape)
pad_shape[dim] = pad_size
padding = torch.empty(pad_shape, device=tensor.device, dtype=tensor.dtype)
return torch.cat([tensor, padding], dim=dim)


def trim_tensor(tensor: torch.Tensor, sizes: torch.Tensor | None = None, dim: int = 0):
size = tensor.size(dim)
world_size = get_gp_world_size()
if size % world_size == 0:
return tensor, sizes
trim_size = size - size % world_size
if dim == 0:
tensor = tensor[:trim_size]
elif dim == 1:
tensor = tensor[:, :trim_size]
else:
raise ValueError
if sizes is not None:
sizes[-1] = sizes[-1] - size % world_size
return tensor, sizes


def _tensor_to_split_partitions(tensor: torch.Tensor, dim: int = -1):
group = get_gp_group()
num_parts = dist.get_world_size(group=group)
return [len(part) for part in np.array_split(np.zeros(tensor.size(dim)), num_parts)]


def _split_tensor(
tensor: torch.Tensor,
dim: int = -1,
contiguous_chunks: bool = False,
):
tensor_list = torch.split(tensor, _tensor_to_split_partitions(tensor, dim), dim=dim)
if contiguous_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list


def _reduce(ctx: Any, input: torch.Tensor) -> torch.Tensor:
group = get_gp_group()
if ctx:
ctx.mark_dirty(input)
dist.all_reduce(input, group=group)
return input


def _split(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
rank = get_gp_rank()
input_list = _split_tensor(input, dim=dim)
return input_list[rank].clone().contiguous()


def _gather(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
group = get_gp_group()
rank = get_gp_rank()
world_size = dist.get_world_size(group=group)
if world_size == 1:
return input
tensor_list = [torch.empty_like(input) for _ in range(world_size)]
tensor_list[rank] = input
dist.all_gather(tensor_list, input, group=group)
return torch.cat(tensor_list, dim=dim).contiguous()


def _gather_with_padding(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
group = get_gp_group()
rank = get_gp_rank()
world_size = dist.get_world_size(group=group)
if world_size == 1:
return input

# Gather sizes
size_list = [
torch.empty(1, device=input.device, dtype=torch.long) for _ in range(world_size)
]
size = torch.tensor([input.size(dim)], device=input.device, dtype=torch.long)
size_list[rank] = size
dist.all_gather(size_list, size, group=group)

# Gather the inputs
max_size = int(max([size.item() for size in size_list]))
input = pad_tensor(input, dim, max_size)
shape = list(input.shape)
shape[dim] = max_size
tensor_list = [
torch.empty(shape, device=input.device, dtype=input.dtype)
for _ in range(world_size)
]

dist.all_gather(tensor_list, input, group=group)
tensor_list[rank] = input # pop back in our local copy (requires grad)

# Trim and cat
return torch.cat(
[tensor.narrow(dim, 0, size) for tensor, size in zip(tensor_list, size_list)],
dim=dim,
).contiguous()

def size_list_fn(size: int, parts: int) -> list[int]:
return [size // parts + (1 if idx < size % parts else 0) for idx in range(parts)]
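

# Sanity-check sketch (added for illustration): size_list_fn hands the
# remainder of size % parts out one element at a time to the lowest-indexed
# partitions, so no two partition sizes differ by more than one.
def _example_size_list_fn() -> None:
    assert size_list_fn(10, 4) == [3, 3, 2, 2]
    assert sum(size_list_fn(10, 4)) == 10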

class CopyToModelParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
return input

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
return _reduce(None, grad_output)
def reduce_from_model_parallel_region(input: torch.Tensor) -> torch.Tensor:
    assert initialized(), "Cannot use graph parallel without initializing the gp group; call setup_gp from gp_utils.py first!"
return ReduceFromModelParallelRegion.apply(input)


class ReduceFromModelParallelRegion(torch.autograd.Function):
@staticmethod
@torch.compiler.disable
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        # return _reduce(ctx, input)  # this operates in place
        return all_reduce(input, group=get_gp_group())  # this operates out of place

@staticmethod
@torch.compiler.disable
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
return grad_output


def scatter_to_model_parallel_region(input: torch.Tensor) -> torch.Tensor:
    assert initialized(), "Cannot use graph parallel without initializing the gp group; call setup_gp from gp_utils.py first!"
return ScatterToModelParallelRegion.apply(input)


# returns a view of the input's local shard (no copy is made)
class ScatterToModelParallelRegion(torch.autograd.Function):
@staticmethod
@torch.compiler.disable
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.split_sizes = size_list_fn(input.shape[0], get_gp_world_size())
        return input.split(ctx.split_sizes)[get_gp_rank()]

@staticmethod
@torch.compiler.disable
def backward(ctx, grad_output: torch.Tensor):
return gather_from_model_parallel_region_sum_grad(
grad_output, sum(ctx.split_sizes), False
)
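

# Usage sketch (added for illustration): a scatter/gather round trip. Assumes
# torch.distributed is initialized, setup_gp has been called, and every rank
# in the gp group passes the same full tensor x.
def _example_scatter_roundtrip(x: torch.Tensor) -> torch.Tensor:
    x_local = scatter_to_model_parallel_region(x)  # this rank's shard of dim 0
    # Gathering the shards reconstructs x on every rank; gradients flowing
    # back through the gather are summed across the group.
    return gather_from_model_parallel_region_sum_grad(
        x_local, natoms=x.shape[0], gloo_backend=False
    )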


def gather_from_model_parallel_region_sum_grad(
input: torch.Tensor,
natoms: int,
gloo_backend: bool,
) -> torch.Tensor:
    assert initialized(), "Cannot use graph parallel without initializing the gp group; call setup_gp from gp_utils.py first!"
world_size = get_gp_world_size()
size_list = size_list_fn(natoms, world_size)

input = pad_input(
input, natoms // world_size + (1 if natoms % world_size != 0 else 0)
)

if gloo_backend:
tensor_list_w_padding = GatherFromModelParallelRegionSumGradPaddedGLOO.apply(
input, False
)
else:
# tensor_list_w_padding = all_gather(input, group=get_gp_group())
tensor_list_w_padding = GatherFromModelParallelRegionSumGradPadded.apply(
input, False
)

return torch.cat(
[
t.narrow(0, 0, s) if t.shape[0] != s else t
for t, s in zip(tensor_list_w_padding, size_list)
],
dim=0,
)
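

# Shape sketch (added for illustration): with natoms=10 and world_size=4, each
# rank pads its shard to ceil(10/4) = 3 rows, the all_gather yields 4 * 3 = 12
# rows, and the narrow/cat above trims them back down to the true sizes.
def _example_gather_shapes() -> None:
    natoms, world_size = 10, 4
    padded = natoms // world_size + (1 if natoms % world_size != 0 else 0)
    assert padded == 3
    assert size_list_fn(natoms, world_size) == [3, 3, 2, 2]  # 12 rows -> 10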


def gather_from_model_parallel_region_sum_grad_async(
input: torch.Tensor, natoms: int
):
    assert initialized(), "Cannot use graph parallel without initializing the gp group; call setup_gp from gp_utils.py first!"
world_size = get_gp_world_size()

input = pad_input(
input, natoms // world_size + (1 if natoms % world_size != 0 else 0)
)

tensor_list_w_padding = GatherFromModelParallelRegionSumGradPadded.apply(
input, True
)
    handle = _tls.handle  # the async Work handle stashed by the Function's forward
return tensor_list_w_padding, handle
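

# Usage sketch (added for illustration): overlap the all_gather with
# independent compute, then wait on the handle before using the results.
# Assumes an async-capable backend such as NCCL and an initialized gp group.
def _example_async_gather(x_local: torch.Tensor, natoms: int) -> torch.Tensor:
    tensors, handle = gather_from_model_parallel_region_sum_grad_async(x_local, natoms)
    # ... do work that does not depend on the gathered tensors here ...
    handle.wait()
    return torch.cat(tensors, dim=0)  # note: still padded; trim before use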


class GatherFromModelParallelRegionSumGradPadded(torch.autograd.Function):
@staticmethod
@torch.compiler.disable
def forward(ctx, input: torch.Tensor, async_op: bool) -> torch.Tensor:
ctx.rank = get_gp_rank()
ctx.group = get_gp_group()
tensor_list = [torch.empty_like(input) for _ in range(get_gp_world_size())]
if async_op:
_tls.handle = dist.all_gather(
tensor_list, input, group=ctx.group, async_op=async_op
)
else:
dist.all_gather(tensor_list, input, group=ctx.group, async_op=async_op)
return tuple(tensor_list)

@staticmethod
@torch.compiler.disable
def backward(ctx, *grad_outputs):
local_grad_output = grad_outputs[ctx.rank]
output_tensor = torch.empty_like(local_grad_output)
return reduce_scatter(output_tensor, grad_outputs, group=ctx.group), None
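

# Note (added for illustration): reduce_scatter is the exact adjoint of
# all_gather, so each rank receives the sum over ranks of the gradient slice
# that corresponds to its own forward contribution.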


class GatherFromModelParallelRegionSumGradPaddedGLOO(torch.autograd.Function):
@staticmethod
@torch.compiler.disable
def forward(ctx, input: torch.Tensor, async_op: bool) -> torch.Tensor:
ctx.rank = get_gp_rank()
ctx.group = get_gp_group()
ctx.shape = input.shape
tensor_list = [torch.empty_like(input) for _ in range(get_gp_world_size())]
if async_op:
_tls.handle = dist.all_gather(
tensor_list, input, group=ctx.group, async_op=async_op
)
else:
dist.all_gather(tensor_list, input, group=ctx.group, async_op=async_op)
return tuple(tensor_list)

@staticmethod
@torch.compiler.disable
def backward(ctx, *grad_outputs):
grad_output = all_reduce(torch.cat(grad_outputs, dim=0), group=ctx.group)
ctx.padded_size = grad_outputs[0].shape[0]
result = grad_output[
ctx.padded_size * ctx.rank : ctx.padded_size * ctx.rank + ctx.shape[0]
]
return result, None
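

# Note (added for illustration): the GLOO backend does not support
# reduce_scatter, so this variant all_reduces the full concatenated gradient
# and then slices out this rank's padded block -- equivalent, but it moves
# more data than the reduce_scatter path above.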


def scale_backward_grad(input: torch.Tensor) -> torch.Tensor:
    assert initialized(), "Cannot use graph parallel without initializing the gp group; call setup_gp from gp_utils.py first!"
return ScaleBackwardGrad.apply(input)


# Leave the forward untouched but upscale the gradient by a factor of
# gp_group_size. DDP reduces a mean across the loss: with gp_group_size=2 and
# 6 ranks, the DDP mean computes (a_1+a_2+a_3+b_1+b_2+b_3)/6. This gets us the
@@ -344,45 +332,11 @@ def backward(ctx, grad_output: torch.Tensor):
# avoid communication overhead
class ScaleBackwardGrad(torch.autograd.Function):
@staticmethod
@torch.compiler.disable
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
return input

@staticmethod
@torch.compiler.disable
def backward(ctx, grad_output: torch.Tensor):
return dist.get_world_size(get_gp_group()) * grad_output
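

# Worked example (added for illustration), following the reasoning above:
# three samples, each with its loss split into halves a_i and b_i across a gp
# group of size 2, spread over 6 DDP ranks.
def _example_grad_scale() -> None:
    a, b = [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]
    ddp_mean = (sum(a) + sum(b)) / 6  # what the DDP mean computes
    per_sample_mean = sum(ai + bi for ai, bi in zip(a, b)) / 3  # what we want
    assert 2 * ddp_mean == per_sample_mean  # recovered by scaling by gp size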


def copy_to_model_parallel_region(input: torch.Tensor) -> torch.Tensor:
    assert initialized(), "Cannot use graph parallel without initializing the gp group; call setup_gp from gp_utils.py first!"
return CopyToModelParallelRegion.apply(input)


def reduce_from_model_parallel_region(input: torch.Tensor) -> torch.Tensor:
    assert initialized(), "Cannot use graph parallel without initializing the gp group; call setup_gp from gp_utils.py first!"
return ReduceFromModelParallelRegion.apply(input)


def scatter_to_model_parallel_region(
input: torch.Tensor, dim: int = -1
) -> torch.Tensor:
    assert initialized(), "Cannot use graph parallel without initializing the gp group; call setup_gp from gp_utils.py first!"
return ScatterToModelParallelRegion.apply(input, dim)


def gather_from_model_parallel_region(
input: torch.Tensor, dim: int = -1
) -> torch.Tensor:
    assert initialized(), "Cannot use graph parallel without initializing the gp group; call setup_gp from gp_utils.py first!"
return GatherFromModelParallelRegion.apply(input, dim)


def gather_from_model_parallel_region_sum_grad(
input: torch.Tensor, dim: int = -1
) -> torch.Tensor:
    assert initialized(), "Cannot use graph parallel without initializing the gp group; call setup_gp from gp_utils.py first!"
return GatherFromModelParallelRegionSumGrad.apply(input, dim)


def scale_backward_grad(input: torch.Tensor) -> torch.Tensor:
    assert initialized(), "Cannot use graph parallel without initializing the gp group; call setup_gp from gp_utils.py first!"
return ScaleBackwardGrad.apply(input)