[Feature] Hide 75% of the communication in tensor parallelism using DoMiNo #292

Open · wants to merge 38 commits into base: main

Commits (38)
d7bf8be  first draft of domino forward pass (xrsrke, Jan 29, 2025)
3803b19  support the backward pass (xrsrke, Jan 30, 2025)
d765fd5  the first draft for bwd overlapping (xrsrke, Jan 31, 2025)
9924608  add backward pass overlapping (xrsrke, Feb 3, 2025)
d6bc8da  fix some ops dont execute in the bwd pass (xrsrke, Feb 4, 2025)
93b2f10  fix can't find an ops in fwd (xrsrke, Feb 5, 2025)
31db05d  partially overlapping bwd pass (xrsrke, Feb 5, 2025)
23f2108  fix stream not sync (xrsrke, Feb 10, 2025)
3e3ae8c  exp2a1c7c2_like_exp2a1c1_domini_llama3_3b_with_tp8_and_seqlen4096_and… (xrsrke, Feb 21, 2025)
c261488  refactor (xrsrke, Feb 21, 2025)
841c7d6  add tests and more refactoring (xrsrke, Feb 21, 2025)
8a0f993  add domino config, fix breaks in _RowLinearAsyncCommunication (xrsrke, Feb 24, 2025)
a61d2df  add bwd.layer_mlp_x_batch_1 as async op (xrsrke, Feb 25, 2025)
06e17bc  - add cuda stream sync after attn_output0[work] (xrsrke, Feb 25, 2025)
8d44942  wait default_stream instead of current_stream (xrsrke, Feb 25, 2025)
aa77e6c  put torch.cuda.synchronize() everywhere (xrsrke, Feb 25, 2025)
76b5f9a  only bwd.layer_attn_{}_batch_0 as non async (xrsrke, Feb 25, 2025)
fe7ee7e  exp7a7_like_exp7a6_but_remove_fwd_pass_cuda_syncronization (xrsrke, Feb 25, 2025)
e0a9bd0  remove torch.cuda.synchronize in WaitComm.backward (xrsrke, Feb 26, 2025)
a772ff0  add back torch.cuda.synchronize in WaitComm.backward and small refactors (xrsrke, Feb 26, 2025)
543ef56  add ctx.comm_stream.wait_stream(torch.cuda.default_stream()) to WaitC… (xrsrke, Feb 27, 2025)
36c9980  exp7a10_like_exp7a6_but_remove_fwd_pass_cuda_syncronization_and_remov… (xrsrke, Feb 27, 2025)
613eb16  remove comments and add typing (xrsrke, Feb 28, 2025)
600f01a  remove explicite async_op arg (xrsrke, Feb 28, 2025)
320e55d  Merge remote-tracking branch 'origin/main' into domino_revert_from_fi… (xrsrke, Mar 5, 2025)
29a8914  pass stream amanger to llama's modules (xrsrke, Mar 7, 2025)
75abb32  move domino's assert args to config (xrsrke, Mar 7, 2025)
da4220c  add retrieving async distributed handle from comm bucket instead of r… (xrsrke, Mar 7, 2025)
d7a636f  small refactor (xrsrke, Mar 7, 2025)
d3d8c10  add CudaStreamManager.init_default_comm_stream and fix domino test (xrsrke, Mar 7, 2025)
74d415c  removing op_name in the forward pass by adding OpNameContext (xrsrke, Mar 7, 2025)
08a4472  add CudaStreamManager as context (xrsrke, Mar 8, 2025)
684b1b9  small refactor (xrsrke, Mar 8, 2025)
f8e8b1f  Reverting repository to commit 74d415c1c02b9463214fb46db060c0efbfa5a0e4 (xrsrke, Mar 10, 2025)
61ff007  add todos (xrsrke, Mar 10, 2025)
9039ce2  add todo (xrsrke, Mar 10, 2025)
62fb3b2  add todos (xrsrke, Mar 10, 2025)
7c7b6f7  add todo and undo torch_nn (xrsrke, Mar 13, 2025)
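For orientation before the file diffs: DoMiNo hides tensor-parallel all-reduce latency by splitting each layer's input into two micro-batches and overlapping one micro-batch's communication with the other micro-batch's compute, which is what the llama.py changes below implement with per-device CUDA streams and tagged async handles. The following is a minimal sketch of that overlap pattern, not the PR's actual code; attn, mlp, and tp_group are stand-in names, and it assumes torch.distributed is already initialized with the NCCL backend and that the tensors live on the GPU.

import torch
import torch.distributed as dist

def domino_layer_forward(attn, mlp, hidden_states, tp_group):
    # Split the batch dimension into two halves so one half's all-reduce can
    # be overlapped with the other half's compute.
    half0, half1 = hidden_states.chunk(2, dim=1)

    out0 = attn(half0)
    work0 = dist.all_reduce(out0, group=tp_group, async_op=True)  # comm for half 0 ...

    out1 = attn(half1)                                            # ... overlaps with compute for half 1
    work1 = dist.all_reduce(out1, group=tp_group, async_op=True)

    work0.wait()        # blocks the current CUDA stream until half 0's reduce is done
    out0 = mlp(out0)    # half 0's MLP overlaps with half 1's in-flight all-reduce

    work1.wait()
    out1 = mlp(out1)
    return torch.cat([out0, out1], dim=1)

The PR goes further than this sketch: it routes the waits through a dedicated communication stream (constants.CUDA_STREAMS) and adds a WaitComm autograd function plus an AsyncCommBucket so the same overlap also happens for the gradient all-reduces in the backward pass.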
3 changes: 3 additions & 0 deletions src/nanotron/constants.py
@@ -10,3 +10,6 @@

CHECKPOINT_FILE_NAME = "checkpoint_metadata.json"
MODEL_CONFIG_FILE_NAME = "model_config.json"


CUDA_STREAMS = {}
2 changes: 2 additions & 0 deletions src/nanotron/helpers.py
@@ -482,7 +482,9 @@ def get_profiler(config: Config):
on_trace_ready=on_trace_ready,
# record_shapes=True,
# profile_memory=True,
with_flops=True,
with_stack=True,
with_modules=True,
)
else:
prof = contextlib.nullcontext()
57 changes: 48 additions & 9 deletions src/nanotron/models/llama.py
@@ -30,6 +30,7 @@
from nanotron.nn.activations import ACT2FN
from nanotron.nn.layer_norm import TritonRMSNorm
from nanotron.parallel import ParallelContext
from nanotron.parallel.comm import WaitComm
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
from nanotron.parallel.pipeline_parallel.p2p import P2P
@@ -46,6 +47,8 @@

logger = logging.get_logger(__name__)

DOMINO_COMM_STREAM = "domino_comm_stream_{}"


class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, end: int, theta: float = 10000.0):
@@ -241,8 +244,8 @@ def __init__(
)
self.split_silu_mul = GLUActivation(config.hidden_act)

def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim]
merged_states = self.gate_up_proj(hidden_states)
def forward(self, hidden_states, handle_idx=None): # [seq_length, batch_size, hidden_dim]
merged_states = self.gate_up_proj(hidden_states, handle_idx)
hidden_states, work = self.down_proj(self.split_silu_mul(merged_states))
return {"hidden_states": hidden_states, "work": work}

@@ -437,6 +440,7 @@ def forward(
self,
hidden_states, # [seq_length, batch_size, hidden_size]
sequence_mask, # [batch_size, seq_length]
handle_idx=None,
):
from flash_attn import bert_padding
from flash_attn.flash_attn_interface import (
@@ -445,7 +449,7 @@ def forward(
)

qkv_states = self.qkv_proj(
hidden_states
hidden_states, handle_idx=handle_idx
) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk]
q_length, batch_size, _ = qkv_states.shape

@@ -720,6 +724,18 @@ def __init__(
self.recompute_layer = parallel_config.recompute_layer
self.parallel_config = parallel_config

# if parallel_config.domino is not None and parallel_config.domino.num_input_batches > 1:
# from nanotron.parallel.comm import CudaStreamManager
# # NOTE: we use different cuda streams for different gpus, so it can overlaps the communication
# CudaStreamManager.create(DOMINO_COMM_STREAM.format(torch.cuda.current_device()))
num_gpus = torch.cuda.device_count()
for i in range(num_gpus):
from nanotron import constants

constants.CUDA_STREAMS[i] = torch.cuda.Stream(device=torch.device(f"cuda:{i}"))

self.layer_idx = layer_idx

def _core_forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
@@ -747,29 +763,52 @@ def _core_forward(
hidden_states0 = self.input_layernorm(hidden_states0)
hidden_states1 = self.input_layernorm(hidden_states1)

attn_output0 = self.attn(hidden_states=hidden_states0, sequence_mask=sequence_mask0)
attn_output0 = self.attn(
hidden_states=hidden_states0, sequence_mask=sequence_mask0, handle_idx=f"layer_{self.layer_idx}_batch_0"
)
attn_output0_work = attn_output0["work"]

attn_output1 = self.attn(hidden_states=hidden_states1, sequence_mask=sequence_mask1)
attn_output1 = self.attn(
hidden_states=hidden_states1, sequence_mask=sequence_mask1, handle_idx=f"layer_{self.layer_idx}_batch_1"
)
attn_output1_work = attn_output1["work"]

attn_output0_work.wait()
from nanotron import constants

comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()]
# comm_stream = CudaStreamManager.get(DOMINO_COMM_STREAM.format(torch.cuda.current_device()))
with torch.cuda.stream(comm_stream):
attn_output0_work.wait()
# attn_output0_work.wait()

hidden_states0 = attn_output0["hidden_states"]
hidden_states0 = hidden_states0 + residual0
residual0 = hidden_states0
hidden_states0 = self.post_attention_layernorm(hidden_states0)
hidden_states0 = WaitComm.apply(hidden_states0, f"layer_{self.layer_idx}_batch_0")

# mlp_output0 = self.mlp(hidden_states=hidden_states0, handle_idx=f"layer_{self.layer_idx}_batch_0")
mlp_output0 = self.mlp(hidden_states=hidden_states0)

attn_output1_work.wait()
with torch.cuda.stream(comm_stream):
attn_output1_work.wait()
# attn_output1_work.wait()

hidden_states1 = attn_output1["hidden_states"]
hidden_states1 = hidden_states1 + residual1
residual1 = hidden_states1
hidden_states1 = self.post_attention_layernorm(hidden_states1)
hidden_states1 = WaitComm.apply(hidden_states1, f"layer_{self.layer_idx}_batch_1")

# mlp_output1 = self.mlp(hidden_states=hidden_states1, handle_idx=f"layer_{self.layer_idx}_batch_1")
mlp_output1 = self.mlp(hidden_states=hidden_states1)
mlp_output0["work"].wait()
mlp_output1["work"].wait()

with torch.cuda.stream(comm_stream):
mlp_output0["work"].wait()
mlp_output1["work"].wait()

# mlp_output0["work"].wait()
# mlp_output1["work"].wait()

hidden_states0 = mlp_output0["hidden_states"]
hidden_states1 = mlp_output1["hidden_states"]
46 changes: 46 additions & 0 deletions src/nanotron/parallel/comm.py
@@ -1,5 +1,27 @@
from contextlib import contextmanager
from typing import Dict

import torch


class CudaStreamManager:
_streams: Dict[str, "torch.cuda.Stream"] = {}

@staticmethod
def create(name: str):
assert name not in CudaStreamManager._streams
CudaStreamManager._streams[name] = torch.cuda.Stream()

@staticmethod
def get(name: str):
return CudaStreamManager._streams.get(name)

@contextmanager
def run_on_stream(name: str):
stream = CudaStreamManager.get(name)
with torch.cuda.stream(stream):
yield stream


class AsyncCommBucket:
"""
@@ -14,13 +36,37 @@ class AsyncCommBucket:

@staticmethod
def add(tensor_id: int, work: "dist.Work"):
assert (
tensor_id not in AsyncCommBucket._async_op
), f"tensor_id: {tensor_id}, keys: {AsyncCommBucket._async_op.keys()}"
AsyncCommBucket._async_op[tensor_id] = work

@staticmethod
def get(tensor_id: int):
return AsyncCommBucket._async_op.get(tensor_id)

@staticmethod
def pop(tensor_id: int):
return AsyncCommBucket._async_op.pop(tensor_id)

@staticmethod
def wait(tensor_id: int):
work = AsyncCommBucket._async_op.pop(tensor_id)
work.wait()


class WaitComm(torch.autograd.Function):
@staticmethod
def forward(ctx, input, wait_handle_idx):
ctx.wait_handle_idx = wait_handle_idx
return input

@staticmethod
def backward(ctx, grad_output):
import pydevd

pydevd.settrace(suspend=False, trace_only_current_thread=True)
if ctx.wait_handle_idx != "layer_1_batch_1":
handle = AsyncCommBucket.pop(ctx.wait_handle_idx)
handle.wait()
return grad_output, None
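Reading the two pieces above together: an asynchronous all-reduce is launched with async_op=True and its dist.Work handle is parked in AsyncCommBucket under a string tag, while WaitComm, an identity in the forward pass, pops that tag in its backward and waits before letting the gradient flow on. The following is a condensed, self-contained sketch of the same deferred-wait pattern with simplified names; it is not the PR's exact classes and omits the PR's per-layer bookkeeping.

from typing import Dict
import torch
import torch.distributed as dist

_pending: Dict[str, "dist.Work"] = {}   # tag -> async Work, analogous to AsyncCommBucket._async_op

class DeferredWait(torch.autograd.Function):
    """Identity in forward; in backward, waits for whatever async collective
    was registered under `tag` before letting the gradient continue."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, tag: str) -> torch.Tensor:
        ctx.tag = tag
        return x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        work = _pending.pop(ctx.tag, None)
        if work is not None:
            work.wait()
        return grad_output, None

In the diff the tag is the handle_idx string such as layer_{i}_batch_{j}, so each micro-batch of each layer waits only for its own collective.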
src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -26,22 +26,31 @@ class DifferentiableIdentity(torch.autograd.Function):
"""All-reduce gradients in a differentiable fashion"""

@staticmethod
def forward(ctx, tensor, group: Optional[ProcessGroup]):
def forward(ctx, tensor, group: Optional[ProcessGroup], handle_idx=None):
# assert handle_idx is not None
ctx.handle_idx = handle_idx
ctx.group = group
return tensor

@staticmethod
def backward(ctx, grad_output):
# import pydevd
# pydevd.settrace(suspend=False, trace_only_current_thread=True)
# NOTE: lm_head is TensorParallelColumnLinear, and it doesn't do async
# assert ctx.handle_idx is not None
group = ctx.group
return DifferentiableAllReduceSum.apply(grad_output, group, False), None
if ctx.handle_idx is not None:
assert 1 == 1

return DifferentiableAllReduceSum.apply(grad_output, group, True, ctx.handle_idx), None, None


class DifferentiableAllReduceSum(torch.autograd.Function):
"""All-reduce in a differentiable fashion"""

@staticmethod
def forward(
ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool
ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx: Optional[int] = None
) -> Tuple[torch.Tensor, Optional["dist.Work"]]:
# ctx.mark_non_differentiable(async_all_reduce)
ctx.async_all_reduce = async_all_reduce
@@ -63,13 +72,17 @@ def forward(
if async_all_reduce:
# AsyncCommBucket.add(tensor, handle)
# AsyncCommBucket.add(id(tensor), handle)
AsyncCommBucket.add(orig_id, handle)
# try:
# AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle)
# except Exception as e:
# assert 1 == 1
AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle)

return tensor

@staticmethod
def backward(ctx, grad_output):
return grad_output, None, None
return grad_output, None, None, None


class DifferentiableAllGather(torch.autograd.Function):
@@ -151,8 +164,8 @@ def backward(ctx, grad_output):
# -----------------


def differentiable_identity(tensor, group: Optional[ProcessGroup] = None):
return DifferentiableIdentity.apply(tensor, group)
def differentiable_identity(tensor, group: Optional[ProcessGroup] = None, handle_idx=None):
return DifferentiableIdentity.apply(tensor, group, handle_idx)


def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False):
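This file is the registration side of that pattern: DifferentiableIdentity.forward stashes handle_idx, and its backward hands it to DifferentiableAllReduceSum, which launches the all-reduce with async_op=True and stores the handle in AsyncCommBucket keyed by handle_idx (falling back to a tensor-derived orig_id when no tag is given). Continuing the _pending dict from the sketch above, the equivalent is roughly the following; the helper name is invented for illustration.

def all_reduce_and_register(grad: torch.Tensor, group, tag=None) -> torch.Tensor:
    # Launch the gradient all-reduce without blocking the backward pass, and
    # park the handle so a later DeferredWait/WaitComm can wait on it.
    work = dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group, async_op=True)
    _pending[tag if tag is not None else id(grad)] = work
    return grad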
6 changes: 4 additions & 2 deletions src/nanotron/parallel/tensor_parallel/functional.py
@@ -436,12 +436,13 @@ def column_linear(
tp_mode: TensorParallelLinearMode,
async_communication: bool,
tp_recompute_allgather: bool = True,
handle_idx: Optional[int] = None,
):
if async_communication:
return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather)

if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
input = differentiable_identity(input, group=group)
input = differentiable_identity(input, group=group, handle_idx=handle_idx)
return F.linear(input, weight, bias)
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(
@@ -604,7 +605,8 @@ def row_linear(
if async_all_reduce:
from nanotron.parallel.comm import AsyncCommBucket

work = AsyncCommBucket.get(orig_out_id)
# work = AsyncCommBucket.get(orig_out_id)
work = AsyncCommBucket.pop(orig_out_id)
assert 1 == 1
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode."
5 changes: 4 additions & 1 deletion src/nanotron/parallel/tensor_parallel/nn.py
@@ -52,6 +52,7 @@ def __init__(
async_communication: bool = False,
contiguous_chunks: Optional[Tuple[int, ...]] = None,
tp_recompute_allgather: bool = True,
# handle_idx: Optional[int] = None,
):
self.pg = pg
self.world_size = pg.size()
@@ -72,6 +73,7 @@ def __init__(

self.mode = mode
self.async_communication = async_communication
# self.handle_idx = handle_idx

if contiguous_chunks is not None:
assert (
@@ -85,7 +87,7 @@ def __init__(
split_config=split_config,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor:
return column_linear(
input=x,
weight=self.weight,
@@ -94,6 +96,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
tp_mode=self.mode,
async_communication=self.async_communication,
tp_recompute_allgather=self.tp_recompute_allgather,
handle_idx=handle_idx,
)

def extra_repr(self) -> str:
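For callers, the only visible change from this file is the optional tag on forward. A hypothetical call site, assuming proj is an already-constructed TensorParallelColumnLinear in ALL_REDUCE mode:

out = proj(x)                                # untagged: the async handle is keyed by the tensor-derived orig_id
out = proj(x, handle_idx="layer_0_batch_0")  # tagged: the handle is keyed by this string for WaitComm to pop later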