Domino #279

Closed · wants to merge 17 commits
Changes from 1 commit
2 changes: 1 addition & 1 deletion src/nanotron/optim/gradient_accumulator.py
@@ -202,7 +202,7 @@ def build_grad_buffers(
         return fp32_grad_buffers, contiguous_buffer_f32_gradients

     def backward(self, loss: torch.Tensor):
-        if isinstance(loss, tuple):
+        if not isinstance(loss, torch.Tensor):
             assert 1 == 1
             raise NotImplementedError("Not implemented yet")
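The guard now rejects any non-`torch.Tensor` loss instead of only tuples. A standalone illustration of that behaviour (`_check_loss` is a hypothetical stand-in for the check inside `backward`, not nanotron code); the rejected tuple mirrors the `(tensor, work_handle)` pair the async all-reduce path used to return:

```python
import torch


def _check_loss(loss):
    # Same shape of check as the hunk above: anything that is not a Tensor is rejected.
    if not isinstance(loss, torch.Tensor):
        raise NotImplementedError("Not implemented yet")


_check_loss(torch.tensor(1.0))  # a plain scalar loss passes

try:
    _check_loss((torch.tensor(1.0), None))  # the old tuple shape now raises
except NotImplementedError as exc:
    print("rejected:", exc)
```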

26 changes: 26 additions & 0 deletions src/nanotron/parallel/comm.py
@@ -0,0 +1,26 @@
+from typing import Dict
+
+
+class AsyncCommBucket:
+    """
+
+    Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
+    RuntimeError: expected Variable or None (got tuple)
+    Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
+    RuntimeError: expected Variable or None (got tuple)
+    """
+
+    _async_op: Dict[int, "dist.Work"] = {}
+
+    @staticmethod
+    def add(tensor_id: int, work: "dist.Work"):
+        AsyncCommBucket._async_op[tensor_id] = work
+
+    @staticmethod
+    def get(tensor_id: int):
+        return AsyncCommBucket._async_op.get(tensor_id)
+
+    @staticmethod
+    def wait(tensor_id: int):
+        work = AsyncCommBucket._async_op.pop(tensor_id)
+        work.wait()
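A minimal usage sketch of the new bucket (illustrative, not part of the PR): the producer launches a non-blocking collective and registers the returned `dist.Work` handle under the tensor's `id()`; the consumer later waits on that handle before reading the tensor. It assumes a process group launched via `torchrun`:

```python
# Run under torchrun, e.g. `torchrun --nproc_per_node=2 bucket_demo.py`.
import torch
import torch.distributed as dist

from nanotron.parallel.comm import AsyncCommBucket

dist.init_process_group(backend="gloo")

x = torch.ones(4) * (dist.get_rank() + 1)

# Producer side: kick off the collective without blocking and park the handle,
# keyed by the Python id of the tensor that was passed in.
handle = dist.all_reduce(x, op=dist.ReduceOp.SUM, async_op=True)
AsyncCommBucket.add(id(x), handle)

# ... independent compute can overlap with the communication here ...

# Consumer side: pop the handle and block; only after this is x safe to read.
AsyncCommBucket.wait(id(x))
print(x)  # every rank now holds the element-wise sum across ranks

dist.destroy_process_group()
```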
8 changes: 6 additions & 2 deletions src/nanotron/parallel/pipeline_parallel/engine.py
@@ -2,6 +2,9 @@
 from typing import Dict, Iterable, Optional, Union

 import torch
+from torch import nn as torch_nn
+from torch.nn.parallel import DistributedDataParallel
+
 from nanotron import distributed as dist
 from nanotron import logging
 from nanotron.distributed import ProcessGroup
@@ -12,8 +15,6 @@
 from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState
 from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
 from nanotron.utils import ContextManagers
-from torch import nn as torch_nn
-from torch.nn.parallel import DistributedDataParallel

 logger = logging.get_logger(__name__)

@@ -83,6 +84,9 @@ def backward(
         if grad_accumulator is None:
             sum(activations).backward()
         else:
+            # if not isinstance(activations, torch.Tensor):
+            #     raise NotImplementedError("Only support sum of tensors for now")
+
             grad_accumulator.backward(sum(activations))

         # TODO @nouamane: this fixes interleaved afab but makes 1f1b hang
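One reason both branches can pass `sum(activations)` straight through: Python's `sum()` over a list of scalar tensors builds a single autograd graph, so one `backward()` call (whether autograd's own or the accumulator's) reaches every activation. A toy, self-contained example:

```python
import torch

a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)
activations = [a * 3, b * 4]  # two "stage losses"

sum(activations).backward()   # one call back-propagates through both terms
print(a.grad, b.grad)         # tensor(3.) tensor(4.)
```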
@@ -19,6 +19,7 @@

 from nanotron import distributed as dist
 from nanotron.distributed import ProcessGroup
+from nanotron.parallel.comm import AsyncCommBucket


 class DifferentiableIdentity(torch.autograd.Function):
@@ -42,14 +43,29 @@ class DifferentiableAllReduceSum(torch.autograd.Function):
    def forward(
        ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool
    ) -> Tuple[torch.Tensor, Optional["dist.Work"]]:
        # ctx.mark_non_differentiable(async_all_reduce)
        ctx.async_all_reduce = async_all_reduce

        if group.size() == 1:
            return tensor

        orig_id = id(tensor)
        handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_all_reduce)
        # if async_all_reduce:
        #     handle.wait()
        new_id = id(tensor)
        assert 1 == 1
        assert orig_id == new_id
        # if async_all_reduce:
        #     return tensor, handle
        # else:
        #     return tensor, None
        if async_all_reduce:
            return tensor, handle
        else:
            return tensor, None
        # AsyncCommBucket.add(tensor, handle)
        # AsyncCommBucket.add(id(tensor), handle)
        AsyncCommBucket.add(orig_id, handle)

        return tensor

    @staticmethod
    def backward(ctx, grad_output):
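As I read this hunk together with `comm.py` and the callers below, the forward pass launches the (optionally asynchronous) in-place all-reduce, records the `dist.Work` handle in `AsyncCommBucket` under the `id()` captured before the collective, and is moving away from returning a `(tensor, handle)` tuple, which is exactly what autograd rejects with the `RuntimeError: expected Variable or None (got tuple)` quoted in the bucket's docstring. A condensed sketch of that reading (not the exact nanotron code; the backward here simply passes the gradient through, mirroring how nanotron pairs this op with `DifferentiableIdentity`):

```python
from typing import Optional

import torch
import torch.distributed as dist

from nanotron.parallel.comm import AsyncCommBucket


class AllReduceSumSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, group: Optional[dist.ProcessGroup], async_all_reduce: bool):
        ctx.async_all_reduce = async_all_reduce
        if group.size() == 1:
            return tensor

        # id() is taken before the collective/apply() so the caller can look the handle up later.
        orig_id = id(tensor)
        handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_all_reduce)
        if async_all_reduce:
            AsyncCommBucket.add(orig_id, handle)
        return tensor  # a single tensor, never a tuple

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the gradient through; one None per non-tensor forward argument.
        return grad_output, None, None
```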
10 changes: 9 additions & 1 deletion src/nanotron/parallel/tensor_parallel/functional.py
@@ -597,7 +597,15 @@ def row_linear(
     out = F.linear(input, weight, bias)

     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
-        out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce)
+        # out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce)
+        orig_out_id = id(out)
+        # NOTE: why the id(out) doesn't match the id(out) before the all_reduce?
+        out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce)
+        if async_all_reduce:
+            from nanotron.parallel.comm import AsyncCommBucket
+
+            work = AsyncCommBucket.get(orig_out_id)
+            assert 1 == 1
     elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
         assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode."
         out = differentiable_reduce_scatter_sum(out, group=group)
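On the `NOTE` above: the handle has to be fetched with the `id()` captured before the call, most likely because a custom `autograd.Function` returns a new Python tensor object that merely aliases the input's storage, so `id(out)` changes across the call even though the data does not move. A tiny self-contained check of that behaviour (`Passthrough` is illustrative, not nanotron code); the retrieved `work` would then be waited on (`work.wait()`) before `out` is consumed:

```python
import torch


class Passthrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor):
        return tensor  # returned as-is, like the all-reduce wrapper above

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


x = torch.randn(3, requires_grad=True)
id_before = id(x)
out = Passthrough.apply(x)

print(id(out) == id_before)            # False: apply() hands back a new tensor object
print(out.data_ptr() == x.data_ptr())  # True: it aliases the same storage
```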
2 changes: 1 addition & 1 deletion src/nanotron/parallel/tensor_parallel/nn.py
@@ -293,7 +293,7 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         out = out * (~input_mask[..., None])

         if self.mode is TensorParallelLinearMode.ALL_REDUCE:
-            out, _ = differentiable_all_reduce_sum(out, group=self.pg, async_all_reduce=False)
+            out = differentiable_all_reduce_sum(out, group=self.pg, async_all_reduce=False)
         elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER:
             out = differentiable_reduce_scatter_sum(out, group=self.pg)
         else: