From d7bf8be86a6c9ae1b20ece6f90fcccac57e4f438 Mon Sep 17 00:00:00 2001
From: Phuc Nguyen <b3f0cus@icloud.com>
Date: Wed, 29 Jan 2025 12:47:04 +0000
Subject: [PATCH 01/17] first draft of domino forward pass

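This draft wires a domino block into ParallelismArgs, makes row-parallel
linears hand back their all-reduce handle, and splits each decoder layer's
input into two micro-batches so batch 1's compute can overlap batch 0's
all-reduce. A rough sketch (not the exact nanotron code) of the new
(out, work) contract that row_linear returns in ALL_REDUCE mode, with
autograd and the REDUCE_SCATTER branch left out and group assumed to be an
initialized TP process group:

    import torch.distributed as dist
    from torch.nn import functional as F

    def row_linear_sketch(x, weight, group, async_all_reduce: bool):
        out = F.linear(x, weight)
        # async_op=True returns a Work handle immediately; the caller must
        # call work.wait() before reading the reduced output.
        work = dist.all_reduce(out, op=dist.ReduceOp.SUM, group=group,
                               async_op=async_all_reduce)
        return out, (work if async_all_reduce else None)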
---
 examples/config_tiny_llama_domino.yaml        | 113 ++++++++++++++++++
 src/nanotron/config/parallelism_config.py     |  17 +++
 src/nanotron/models/llama.py                  |  90 +++++++++++---
 src/nanotron/optim/gradient_accumulator.py    |   4 +
 .../distributed_differentiable_primitives.py  |  21 ++--
 .../parallel/tensor_parallel/functional.py    |  12 +-
 src/nanotron/parallel/tensor_parallel/nn.py   |   5 +-
 7 files changed, 235 insertions(+), 27 deletions(-)
 create mode 100644 examples/config_tiny_llama_domino.yaml

diff --git a/examples/config_tiny_llama_domino.yaml b/examples/config_tiny_llama_domino.yaml
new file mode 100644
index 00000000..66e22dbd
--- /dev/null
+++ b/examples/config_tiny_llama_domino.yaml
@@ -0,0 +1,113 @@
+checkpoints:
+  checkpoint_interval: 10
+  checkpoints_path: checkpoints
+  checkpoints_path_is_shared_file_system: false
+  resume_checkpoint_path: null
+  save_initial_state: false
+data_stages:
+- data:
+    dataset:
+      dataset_overwrite_cache: false
+      dataset_processing_num_proc_per_process: 1
+      hf_dataset_config_name: null
+      hf_dataset_or_datasets: stas/openwebtext-10k
+      hf_dataset_splits: train
+      text_column_name: text
+    num_loading_workers: 1
+    seed: 42
+  name: Stable Training Stage
+  start_training_step: 1
+- data:
+    dataset:
+      dataset_overwrite_cache: false
+      dataset_processing_num_proc_per_process: 1
+      hf_dataset_config_name: null
+      hf_dataset_or_datasets: stas/openwebtext-10k
+      hf_dataset_splits: train
+      text_column_name: text
+    num_loading_workers: 1
+    seed: 42
+  name: Annealing Phase
+  start_training_step: 10
+general:
+  benchmark_csv_path: null
+  consumed_train_samples: null
+  ignore_sanity_checks: true
+  project: debug
+  run: tiny_llama_%date_%jobid
+  seed: 42
+  step: null
+lighteval: null
+logging:
+  iteration_step_info_interval: 1
+  log_level: info
+  log_level_replica: info
+model:
+  ddp_bucket_cap_mb: 25
+  dtype: bfloat16
+  init_method:
+    std: 0.025
+  make_vocab_size_divisible_by: 1
+  model_config:
+    bos_token_id: 1
+    eos_token_id: 2
+    hidden_act: silu
+    hidden_size: 16
+    initializer_range: 0.02
+    intermediate_size: 64
+    is_llama_config: true
+    max_position_embeddings: 256
+    num_attention_heads: 4
+    num_hidden_layers: 2
+    num_key_value_heads: 4
+    pad_token_id: null
+    pretraining_tp: 1
+    rms_norm_eps: 1.0e-05
+    rope_scaling: null
+    tie_word_embeddings: true
+    use_cache: true
+    vocab_size: 256
+optimizer:
+  accumulate_grad_in_fp32: true
+  clip_grad: 1.0
+  learning_rate_scheduler:
+    learning_rate: 0.0003
+    lr_decay_starting_step: null
+    lr_decay_steps: 13
+    lr_decay_style: cosine
+    lr_warmup_steps: 2
+    lr_warmup_style: linear
+    min_decay_lr: 1.0e-05
+  optimizer_factory:
+    adam_beta1: 0.9
+    adam_beta2: 0.95
+    adam_eps: 1.0e-08
+    name: adamW
+    torch_adam_is_fused: true
+  weight_decay: 0.01
+  zero_stage: 0
+parallelism:
+  # dp: 2
+  # pp: 2
+  dp: 1
+  pp: 1
+  tp: 2
+  expert_parallel_size: 1
+  pp_engine: 1f1b
+  tp_linear_async_communication: false
+  tp_mode: ALL_REDUCE
+  domino:
+    num_input_batches: 2
+profiler: null
+tokenizer:
+  tokenizer_max_length: null
+  tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel
+  tokenizer_revision: null
+tokens:
+  batch_accumulation_per_replica: 1
+  limit_test_batches: 0
+  limit_val_batches: 0
+  micro_batch_size: 2
+  sequence_length: 256
+  train_steps: 15
+  val_check_interval: -1
diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py
index 7f20ad99..2701bf9c 100644
--- a/src/nanotron/config/parallelism_config.py
+++ b/src/nanotron/config/parallelism_config.py
@@ -11,6 +11,22 @@
 from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
 
 
+@dataclass
+class DominoArgs:
+    """
+    Domino: Eliminating Communication in LLM Training via Generic Tensor Slicing and Overlapping
+    https://arxiv.org/abs/2409.15241
+    """
+
+    # NOTE: if the number of input batches is 1,
+    # it's equivalent to non-domino mode
+    # so if you want to enable domino mode, set this to > 1
+    num_input_batches: int
+
+    def __post_init__(self):
+        assert self.num_input_batches > 1, "In order to enable domino mode, set num_input_batches > 1"
+
+
 @dataclass
 class ParallelismArgs:
     """Arguments related to TP/PP/DP
@@ -37,6 +53,7 @@ class ParallelismArgs:
     tp_recompute_allgather: bool = True
 
     expert_parallel_size: int = 1
+    domino: Optional[DominoArgs] = None
 
     def __post_init__(self):
         # Conservative defaults
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 88fb6bcb..bc495624 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -237,13 +237,14 @@ def __init__(
             mode=tp_mode,
             bias=False,
             async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
+            async_all_reduce=parallel_config.domino.num_input_batches > 1,
         )
         self.split_silu_mul = GLUActivation(config.hidden_act)
 
     def forward(self, hidden_states):  # [seq_length, batch_size, hidden_dim]
         merged_states = self.gate_up_proj(hidden_states)
-        hidden_states = self.down_proj(self.split_silu_mul(merged_states))
-        return {"hidden_states": hidden_states}
+        hidden_states, work = self.down_proj(self.split_silu_mul(merged_states))
+        return {"hidden_states": hidden_states, "work": work}
 
 
 class CoreAttention(nn.Module):
@@ -335,6 +336,7 @@ def __init__(
         parallel_config: Optional[ParallelismArgs],
         tp_pg: dist.ProcessGroup,
         layer_idx: int,
+        async_all_reduce: bool = False,
     ):
         from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
 
@@ -418,6 +420,7 @@ def __init__(
             mode=tp_mode,
             bias=False,
             async_communication=tp_linear_async_communication,
+            async_all_reduce=async_all_reduce,
         )
 
         self.attention = CoreAttention(
@@ -687,9 +690,9 @@ def forward(
         attention_output = (
             attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1)
         )
-        output = self.o_proj(attention_output)
+        output, work = self.o_proj(attention_output)
 
-        return {"hidden_states": output, "sequence_mask": sequence_mask}
+        return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask}
 
 
 class LlamaDecoderLayer(nn.Module):
@@ -702,36 +705,80 @@ def __init__(
     ):
         super().__init__()
         self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
         self.attn = CausalSelfAttention(
             config=config,
             parallel_config=parallel_config,
             tp_pg=tp_pg,
             layer_idx=layer_idx,
+            async_all_reduce=parallel_config.domino.num_input_batches > 1,
         )
 
         self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
 
         self.recompute_layer = parallel_config.recompute_layer
+        self.parallel_config = parallel_config
 
     def _core_forward(
         self,
         hidden_states: Union[torch.Tensor, TensorPointer],
         sequence_mask: Union[torch.Tensor, TensorPointer],
     ) -> List[Union[torch.Tensor, TensorPointer]]:
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
 
-        output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
-        hidden_states = output["hidden_states"]
-        hidden_states = hidden_states + residual
+        num_input_batches = self.parallel_config.domino.num_input_batches
+        assert num_input_batches == 2
+        hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1)
+        orig_sequence_mask = sequence_mask
+        sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0)
+
+        hidden_states0, hidden_states1 = hidden_states
+        sequence_mask0, sequence_mask1 = sequence_mask
+
+        # # Combine the chunks into a list of dictionaries
+        # hidden_encoder_states_list = [
+        #     {"hidden_states": hidden_encoder_states["hidden_states"][i], "sequence_mask": hidden_encoder_states["sequence_mask"][i]}
+        #     for i in range(num_input_batches)
+        # ]
+
+        residual0 = hidden_states0
+        residual1 = hidden_states1
+
+        hidden_states0 = self.input_layernorm(hidden_states0)
+        hidden_states1 = self.input_layernorm(hidden_states1)
+
+        attn_output0 = self.attn(hidden_states=hidden_states0, sequence_mask=sequence_mask0)
+        attn_output0_work = attn_output0["work"]
+
+        attn_output1 = self.attn(hidden_states=hidden_states1, sequence_mask=sequence_mask1)
+        attn_output1_work = attn_output1["work"]
+
+        attn_output0_work.wait()
+        hidden_states0 = attn_output0["hidden_states"]
+        hidden_states0 = hidden_states0 + residual0
+        residual0 = hidden_states0
+        hidden_states0 = self.post_attention_layernorm(hidden_states0)
 
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
-        hidden_states = hidden_states + residual
+        mlp_output0 = self.mlp(hidden_states=hidden_states0)
 
-        return hidden_states, output["sequence_mask"]
+        attn_output1_work.wait()
+        hidden_states1 = attn_output1["hidden_states"]
+        hidden_states1 = hidden_states1 + residual1
+        residual1 = hidden_states1
+        hidden_states1 = self.post_attention_layernorm(hidden_states1)
+
+        mlp_output1 = self.mlp(hidden_states=hidden_states1)
+        mlp_output0["work"].wait()
+        mlp_output1["work"].wait()
+
+        hidden_states0 = mlp_output0["hidden_states"]
+        hidden_states1 = mlp_output1["hidden_states"]
+
+        hidden_states0 = hidden_states0 + residual0
+        hidden_states1 = hidden_states1 + residual1
+
+        hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1)
+        return hidden_states, orig_sequence_mask
 
     def _checkpointed_forward(
         self,
@@ -899,9 +946,24 @@ def forward_with_hidden_states(
             "hidden_states": output["input_embeds"],
             "sequence_mask": input_mask,
         }
+
+        # assert 1 == 1
+        # num_input_batches = self.parallel_config.domino.num_input_batches
+        # hidden_encoder_states["hidden_states"] = torch.chunk(hidden_encoder_states["hidden_states"], chunks=num_input_batches, dim=1)
+        # hidden_encoder_states["sequence_mask"] = torch.chunk(hidden_encoder_states["sequence_mask"], chunks=num_input_batches, dim=0)
+
+        # # Combine the chunks into a list of dictionaries
+        # hidden_encoder_states_list = [
+        #     {"hidden_states": hidden_encoder_states["hidden_states"][i], "sequence_mask": hidden_encoder_states["sequence_mask"][i]}
+        #     for i in range(num_input_batches)
+        # ]
+
         for encoder_block in self.decoder:
             hidden_encoder_states = encoder_block(**hidden_encoder_states)
 
+            # for hidden_encoder_states in hidden_encoder_states_list:
+            #     hidden_encoder_states = encoder_block(**hidden_encoder_states)
+
         hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]
 
         sharded_logits = self.lm_head(x=hidden_states)["logits"]
diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py
index 2e940744..ba0d5dd0 100644
--- a/src/nanotron/optim/gradient_accumulator.py
+++ b/src/nanotron/optim/gradient_accumulator.py
@@ -202,6 +202,10 @@ def build_grad_buffers(
         return fp32_grad_buffers, contiguous_buffer_f32_gradients
 
     def backward(self, loss: torch.Tensor):
+        if isinstance(loss, tuple):
+            assert 1 == 1
+            raise NotImplementedError("Not implemented yet")
+
         result = loss.backward()
 
         for name, elt in self.fp32_grad_buffers.items():
diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index bd41347a..0ae9a4de 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 from torch import distributed as torch_dist
@@ -32,23 +32,28 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
     @staticmethod
     def backward(ctx, grad_output):
         group = ctx.group
-        return DifferentiableAllReduceSum.apply(grad_output, group), None
+        return DifferentiableAllReduceSum.apply(grad_output, group, False), None
 
 
 class DifferentiableAllReduceSum(torch.autograd.Function):
     """All-reduce in a differentiable fashion"""
 
     @staticmethod
-    def forward(ctx, tensor, group: Optional[ProcessGroup]):
+    def forward(
+        ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool
+    ) -> Tuple[torch.Tensor, Optional["dist.Work"]]:
         if group.size() == 1:
             return tensor
 
-        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
-        return tensor
+        handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_all_reduce)
+        if async_all_reduce:
+            return tensor, handle
+        else:
+            return tensor, None
 
     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output, None
+        return grad_output, None, None
 
 
 class DifferentiableAllGather(torch.autograd.Function):
@@ -134,8 +139,8 @@ def differentiable_identity(tensor, group: Optional[ProcessGroup] = None):
     return DifferentiableIdentity.apply(tensor, group)
 
 
-def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None):
-    return DifferentiableAllReduceSum.apply(tensor, group)
+def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False):
+    return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce)
 
 
 def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None):
diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index e2ee3a29..ffb3e3f9 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 from torch.nn import functional as F
@@ -587,18 +587,22 @@ def row_linear(
     bias: Optional[torch.Tensor],
     group: dist.ProcessGroup,
     tp_mode: TensorParallelLinearMode,
+    # TODO(xrsrke): use less confusing names for these arguments
     async_communication: bool,
-):
+    async_all_reduce: bool,
+) -> Tuple[torch.Tensor, Optional[torch.Future]]:
     if async_communication:
         return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)
 
     out = F.linear(input, weight, bias)
 
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
-        out = differentiable_all_reduce_sum(out, group=group)
+        out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce)
     elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
+        assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode."
         out = differentiable_reduce_scatter_sum(out, group=group)
+        work = None
     else:
         raise ValueError(f"Got unexpected mode: {tp_mode}.")
 
-    return out
+    return out, work
diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py
index 4c7325cd..19cbdf88 100644
--- a/src/nanotron/parallel/tensor_parallel/nn.py
+++ b/src/nanotron/parallel/tensor_parallel/nn.py
@@ -111,6 +111,7 @@ def __init__(
         device=None,
         dtype=None,
         async_communication: bool = False,
+        async_all_reduce: bool = False,
         contiguous_chunks: Optional[Tuple[int, ...]] = None,
     ):
         self.pg = pg
@@ -133,6 +134,7 @@ def __init__(
         )
         self.mode = mode
         self.async_communication = async_communication
+        self.async_all_reduce = async_all_reduce
         if self.mode is TensorParallelLinearMode.ALL_REDUCE and self.async_communication:
             raise ValueError("async_communication is not supported for ALL_REDUCE mode")
 
@@ -166,6 +168,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             group=self.pg,
             tp_mode=self.mode,
             async_communication=self.async_communication,
+            async_all_reduce=self.async_all_reduce,
         )
 
     def extra_repr(self) -> str:
@@ -290,7 +293,7 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
             out = out * (~input_mask[..., None])
 
         if self.mode is TensorParallelLinearMode.ALL_REDUCE:
-            out = differentiable_all_reduce_sum(out, group=self.pg)
+            out, _ = differentiable_all_reduce_sum(out, group=self.pg, async_all_reduce=False)
         elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER:
             out = differentiable_reduce_scatter_sum(out, group=self.pg)
         else:

From 3803b1927e4063c2b329857dfbd0497778a235b1 Mon Sep 17 00:00:00 2001
From: Phuc Nguyen <b3f0cus@icloud.com>
Date: Thu, 30 Jan 2025 14:03:36 +0000
Subject: [PATCH 02/17] support the backward pass

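Returning (tensor, handle) straight from a torch.autograd.Function makes the
engine raise "expected Variable or None (got tuple)", so this patch parks the
async handle in a process-global AsyncCommBucket keyed by the tensor's id and
lets the caller fetch it back. Illustrative usage, following what row_linear
now does (apply() hands back a different Python object, hence capturing the
id before the call):

    orig_id = id(out)                     # id before apply(); apply() returns a new Python object
    out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=True)
    work = AsyncCommBucket.get(orig_id)   # handle registered inside forward()
    # ... overlap other compute here ...
    work.wait()                           # block only once the reduced tensor is needed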
---
 src/nanotron/optim/gradient_accumulator.py    |  2 +-
 src/nanotron/parallel/comm.py                 | 26 +++++++++++++++++++
 .../parallel/pipeline_parallel/engine.py      |  8 ++++--
 .../distributed_differentiable_primitives.py  | 22 +++++++++++++---
 .../parallel/tensor_parallel/functional.py    | 10 ++++++-
 src/nanotron/parallel/tensor_parallel/nn.py   |  2 +-
 6 files changed, 62 insertions(+), 8 deletions(-)
 create mode 100644 src/nanotron/parallel/comm.py

diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py
index ba0d5dd0..b5ef7d89 100644
--- a/src/nanotron/optim/gradient_accumulator.py
+++ b/src/nanotron/optim/gradient_accumulator.py
@@ -202,7 +202,7 @@ def build_grad_buffers(
         return fp32_grad_buffers, contiguous_buffer_f32_gradients
 
     def backward(self, loss: torch.Tensor):
-        if isinstance(loss, tuple):
+        if not isinstance(loss, torch.Tensor):
             assert 1 == 1
             raise NotImplementedError("Not implemented yet")
 
diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py
new file mode 100644
index 00000000..76e33e21
--- /dev/null
+++ b/src/nanotron/parallel/comm.py
@@ -0,0 +1,26 @@
+from typing import Dict
+
+
+class AsyncCommBucket:
+    """
+    Global registry of in-flight async all-reduce handles.
+
+    torch.autograd.Function can't return (tensor, work) tuples ("RuntimeError:
+    expected Variable or None (got tuple)"), so handles are stashed here keyed
+    by the tensor's id and waited on explicitly by whoever consumes the tensor.
+    """
+
+    _async_op: Dict[int, "dist.Work"] = {}
+
+    @staticmethod
+    def add(tensor_id: int, work: "dist.Work"):
+        AsyncCommBucket._async_op[tensor_id] = work
+
+    @staticmethod
+    def get(tensor_id: int):
+        return AsyncCommBucket._async_op.get(tensor_id)
+
+    @staticmethod
+    def wait(tensor_id: int):
+        work = AsyncCommBucket._async_op.pop(tensor_id)
+        work.wait()
diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py
index ca9df312..8160f302 100644
--- a/src/nanotron/parallel/pipeline_parallel/engine.py
+++ b/src/nanotron/parallel/pipeline_parallel/engine.py
@@ -2,6 +2,9 @@
 from typing import Dict, Iterable, Optional, Union
 
 import torch
+from torch import nn as torch_nn
+from torch.nn.parallel import DistributedDataParallel
+
 from nanotron import distributed as dist
 from nanotron import logging
 from nanotron.distributed import ProcessGroup
@@ -12,8 +15,6 @@
 from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState
 from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
 from nanotron.utils import ContextManagers
-from torch import nn as torch_nn
-from torch.nn.parallel import DistributedDataParallel
 
 logger = logging.get_logger(__name__)
 
@@ -83,6 +84,9 @@ def backward(
             if grad_accumulator is None:
                 sum(activations).backward()
             else:
+                # if not isinstance(activations, torch.Tensor):
+                #     raise NotImplementedError("Only support sum of tensors for now")
+
                 grad_accumulator.backward(sum(activations))
 
         # TODO @nouamane: this fixes interleaved afab but makes 1f1b hang
diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index 0ae9a4de..05ade53d 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -19,6 +19,7 @@
 
 from nanotron import distributed as dist
 from nanotron.distributed import ProcessGroup
+from nanotron.parallel.comm import AsyncCommBucket
 
 
 class DifferentiableIdentity(torch.autograd.Function):
@@ -42,14 +43,29 @@ class DifferentiableAllReduceSum(torch.autograd.Function):
     def forward(
         ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool
     ) -> Tuple[torch.Tensor, Optional["dist.Work"]]:
+        # ctx.mark_non_differentiable(async_all_reduce)
+        ctx.async_all_reduce = async_all_reduce
+
         if group.size() == 1:
             return tensor
 
+        orig_id = id(tensor)
         handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_all_reduce)
+        # if async_all_reduce:
+        #     handle.wait()
+        new_id = id(tensor)
+        assert 1 == 1
+        assert orig_id == new_id
+        # if async_all_reduce:
+        #     return tensor, handle
+        # else:
+        #     return tensor, None
         if async_all_reduce:
-            return tensor, handle
-        else:
-            return tensor, None
+            # AsyncCommBucket.add(tensor, handle)
+            # AsyncCommBucket.add(id(tensor), handle)
+            AsyncCommBucket.add(orig_id, handle)
+
+        return tensor
 
     @staticmethod
     def backward(ctx, grad_output):
diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index ffb3e3f9..b05a50bf 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -597,7 +597,15 @@ def row_linear(
     out = F.linear(input, weight, bias)
 
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
-        out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce)
+        # out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce)
+        orig_out_id = id(out)
+        # NOTE: why the id(out) doesn't match the id(out) before the all_reduce?
+        out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce)
+        if async_all_reduce:
+            from nanotron.parallel.comm import AsyncCommBucket
+
+            work = AsyncCommBucket.get(orig_out_id)
+            assert 1 == 1
     elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
         assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode."
         out = differentiable_reduce_scatter_sum(out, group=group)
diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py
index 19cbdf88..14a7486a 100644
--- a/src/nanotron/parallel/tensor_parallel/nn.py
+++ b/src/nanotron/parallel/tensor_parallel/nn.py
@@ -293,7 +293,7 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
             out = out * (~input_mask[..., None])
 
         if self.mode is TensorParallelLinearMode.ALL_REDUCE:
-            out, _ = differentiable_all_reduce_sum(out, group=self.pg, async_all_reduce=False)
+            out = differentiable_all_reduce_sum(out, group=self.pg, async_all_reduce=False)
         elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER:
             out = differentiable_reduce_scatter_sum(out, group=self.pg)
         else:

From d765fd57e29a30b0a083012bed6e443a2df72b7f Mon Sep 17 00:00:00 2001
From: Phuc Nguyen <b3f0cus@icloud.com>
Date: Fri, 31 Jan 2025 14:20:32 +0000
Subject: [PATCH 03/17] first draft of bwd overlapping

---
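Column-parallel linears now take a handle_idx, WaitComm.apply marks where the
backward pass has to wait for the matching async all-reduce, and the waits are
issued on a per-device CUDA stream so the default (compute) stream keeps
running. The forward half of the overlap looks roughly like this, using the
names from the decoder layer in llama.py:

    comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()]

    attn_output0 = self.attn(hidden_states=hidden_states0, sequence_mask=sequence_mask0)
    # batch 1 compute runs while batch 0's async all-reduce is still in flight
    attn_output1 = self.attn(hidden_states=hidden_states1, sequence_mask=sequence_mask1)

    with torch.cuda.stream(comm_stream):
        attn_output0["work"].wait()        # wait off the default stream
    hidden_states0 = attn_output0["hidden_states"] + residual0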
 src/nanotron/constants.py                     |  3 +
 src/nanotron/helpers.py                       |  2 +
 src/nanotron/models/llama.py                  | 57 ++++++++++++++++---
 src/nanotron/parallel/comm.py                 | 46 +++++++++++++++
 .../distributed_differentiable_primitives.py  | 27 ++++++---
 .../parallel/tensor_parallel/functional.py    |  6 +-
 src/nanotron/parallel/tensor_parallel/nn.py   |  5 +-
 7 files changed, 127 insertions(+), 19 deletions(-)

diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py
index 580bd99d..78fd0bb9 100644
--- a/src/nanotron/constants.py
+++ b/src/nanotron/constants.py
@@ -10,3 +10,6 @@
 
 CHECKPOINT_FILE_NAME = "checkpoint_metadata.json"
 MODEL_CONFIG_FILE_NAME = "model_config.json"
+
+
+CUDA_STREAMS = {}
diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py
index 73ca3484..7f31d812 100644
--- a/src/nanotron/helpers.py
+++ b/src/nanotron/helpers.py
@@ -482,7 +482,9 @@ def get_profiler(config: Config):
             on_trace_ready=on_trace_ready,
             # record_shapes=True,
             # profile_memory=True,
+            with_flops=True,
             with_stack=True,
+            with_modules=True,
         )
     else:
         prof = contextlib.nullcontext()
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index bc495624..fb112f8e 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -30,6 +30,7 @@
 from nanotron.nn.activations import ACT2FN
 from nanotron.nn.layer_norm import TritonRMSNorm
 from nanotron.parallel import ParallelContext
+from nanotron.parallel.comm import WaitComm
 from nanotron.parallel.parameters import NanotronParameter
 from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
 from nanotron.parallel.pipeline_parallel.p2p import P2P
@@ -46,6 +47,8 @@
 
 logger = logging.get_logger(__name__)
 
+DOMINO_COMM_STREAM = "domino_comm_stream_{}"
+
 
 class RotaryEmbedding(nn.Module):
     def __init__(self, dim: int, end: int, theta: float = 10000.0):
@@ -241,8 +244,8 @@ def __init__(
         )
         self.split_silu_mul = GLUActivation(config.hidden_act)
 
-    def forward(self, hidden_states):  # [seq_length, batch_size, hidden_dim]
-        merged_states = self.gate_up_proj(hidden_states)
+    def forward(self, hidden_states, handle_idx=None):  # [seq_length, batch_size, hidden_dim]
+        merged_states = self.gate_up_proj(hidden_states, handle_idx)
         hidden_states, work = self.down_proj(self.split_silu_mul(merged_states))
         return {"hidden_states": hidden_states, "work": work}
 
@@ -437,6 +440,7 @@ def forward(
         self,
         hidden_states,  # [seq_length, batch_size, hidden_size]
         sequence_mask,  # [batch_size, seq_length]
+        handle_idx=None,
     ):
         from flash_attn import bert_padding
         from flash_attn.flash_attn_interface import (
@@ -445,7 +449,7 @@ def forward(
         )
 
         qkv_states = self.qkv_proj(
-            hidden_states
+            hidden_states, handle_idx=handle_idx
         )  # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk]
         q_length, batch_size, _ = qkv_states.shape
 
@@ -720,6 +724,18 @@ def __init__(
         self.recompute_layer = parallel_config.recompute_layer
         self.parallel_config = parallel_config
 
+        # if parallel_config.domino is not None and parallel_config.domino.num_input_batches > 1:
+        #     from nanotron.parallel.comm import CudaStreamManager
+        #     # NOTE: we use different cuda streams for different gpus, so it can overlaps the communication
+        #     CudaStreamManager.create(DOMINO_COMM_STREAM.format(torch.cuda.current_device()))
+        num_gpus = torch.cuda.device_count()
+        for i in range(num_gpus):
+            from nanotron import constants
+
+            constants.CUDA_STREAMS[i] = torch.cuda.Stream(device=torch.device(f"cuda:{i}"))
+
+        self.layer_idx = layer_idx
+
     def _core_forward(
         self,
         hidden_states: Union[torch.Tensor, TensorPointer],
@@ -747,29 +763,52 @@ def _core_forward(
         hidden_states0 = self.input_layernorm(hidden_states0)
         hidden_states1 = self.input_layernorm(hidden_states1)
 
-        attn_output0 = self.attn(hidden_states=hidden_states0, sequence_mask=sequence_mask0)
+        attn_output0 = self.attn(
+            hidden_states=hidden_states0, sequence_mask=sequence_mask0, handle_idx=f"layer_{self.layer_idx}_batch_0"
+        )
         attn_output0_work = attn_output0["work"]
 
-        attn_output1 = self.attn(hidden_states=hidden_states1, sequence_mask=sequence_mask1)
+        attn_output1 = self.attn(
+            hidden_states=hidden_states1, sequence_mask=sequence_mask1, handle_idx=f"layer_{self.layer_idx}_batch_1"
+        )
         attn_output1_work = attn_output1["work"]
 
-        attn_output0_work.wait()
+        from nanotron import constants
+
+        comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()]
+        # comm_stream = CudaStreamManager.get(DOMINO_COMM_STREAM.format(torch.cuda.current_device()))
+        with torch.cuda.stream(comm_stream):
+            attn_output0_work.wait()
+        # attn_output0_work.wait()
+
         hidden_states0 = attn_output0["hidden_states"]
         hidden_states0 = hidden_states0 + residual0
         residual0 = hidden_states0
         hidden_states0 = self.post_attention_layernorm(hidden_states0)
+        hidden_states0 = WaitComm.apply(hidden_states0, f"layer_{self.layer_idx}_batch_0")
 
+        # mlp_output0 = self.mlp(hidden_states=hidden_states0, handle_idx=f"layer_{self.layer_idx}_batch_0")
         mlp_output0 = self.mlp(hidden_states=hidden_states0)
 
-        attn_output1_work.wait()
+        with torch.cuda.stream(comm_stream):
+            attn_output1_work.wait()
+        # attn_output1_work.wait()
+
         hidden_states1 = attn_output1["hidden_states"]
         hidden_states1 = hidden_states1 + residual1
         residual1 = hidden_states1
         hidden_states1 = self.post_attention_layernorm(hidden_states1)
+        hidden_states1 = WaitComm.apply(hidden_states1, f"layer_{self.layer_idx}_batch_1")
 
+        # mlp_output1 = self.mlp(hidden_states=hidden_states1, handle_idx=f"layer_{self.layer_idx}_batch_1")
         mlp_output1 = self.mlp(hidden_states=hidden_states1)
-        mlp_output0["work"].wait()
-        mlp_output1["work"].wait()
+
+        with torch.cuda.stream(comm_stream):
+            mlp_output0["work"].wait()
+            mlp_output1["work"].wait()
+
+        # mlp_output0["work"].wait()
+        # mlp_output1["work"].wait()
 
         hidden_states0 = mlp_output0["hidden_states"]
         hidden_states1 = mlp_output1["hidden_states"]
diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py
index 76e33e21..e26e2814 100644
--- a/src/nanotron/parallel/comm.py
+++ b/src/nanotron/parallel/comm.py
@@ -1,5 +1,27 @@
+from contextlib import contextmanager
 from typing import Dict
 
+import torch
+
+
+class CudaStreamManager:
+    _streams: Dict[str, "torch.cuda.Stream"] = {}
+
+    @staticmethod
+    def create(name: str):
+        assert name not in CudaStreamManager._streams
+        CudaStreamManager._streams[name] = torch.cuda.Stream()
+
+    @staticmethod
+    def get(name: str):
+        return CudaStreamManager._streams.get(name)
+
+    @contextmanager
+    def run_on_stream(name: str):
+        stream = CudaStreamManager.get(name)
+        with torch.cuda.stream(stream):
+            yield stream
+
 
 class AsyncCommBucket:
     """
@@ -14,13 +36,37 @@ class AsyncCommBucket:
 
     @staticmethod
     def add(tensor_id: int, work: "dist.Work"):
+        assert (
+            tensor_id not in AsyncCommBucket._async_op
+        ), f"tensor_id: {tensor_id}, keys: {AsyncCommBucket._async_op.keys()}"
         AsyncCommBucket._async_op[tensor_id] = work
 
     @staticmethod
     def get(tensor_id: int):
         return AsyncCommBucket._async_op.get(tensor_id)
 
+    @staticmethod
+    def pop(tensor_id: int):
+        return AsyncCommBucket._async_op.pop(tensor_id)
+
     @staticmethod
     def wait(tensor_id: int):
         work = AsyncCommBucket._async_op.pop(tensor_id)
         work.wait()
+
+
+class WaitComm(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, wait_handle_idx):
+        ctx.wait_handle_idx = wait_handle_idx
+        return input
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        import pydevd
+
+        pydevd.settrace(suspend=False, trace_only_current_thread=True)
+        if ctx.wait_handle_idx != "layer_1_batch_1":
+            handle = AsyncCommBucket.pop(ctx.wait_handle_idx)
+            handle.wait()
+        return grad_output, None
diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index 05ade53d..38c6bafd 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -26,14 +26,23 @@ class DifferentiableIdentity(torch.autograd.Function):
     """All-reduce gradients in a differentiable fashion"""
 
     @staticmethod
-    def forward(ctx, tensor, group: Optional[ProcessGroup]):
+    def forward(ctx, tensor, group: Optional[ProcessGroup], handle_idx=None):
+        # assert handle_idx is not None
+        ctx.handle_idx = handle_idx
         ctx.group = group
         return tensor
 
     @staticmethod
     def backward(ctx, grad_output):
+        # import pydevd
+        # pydevd.settrace(suspend=False, trace_only_current_thread=True)
+        # NOTE: lm_head is TensorParallelColumnLinear, and it doesn't do async
+        # assert ctx.handle_idx is not None
         group = ctx.group
-        return DifferentiableAllReduceSum.apply(grad_output, group, False), None
+        if ctx.handle_idx is not None:
+            assert 1 == 1
+
+        return DifferentiableAllReduceSum.apply(grad_output, group, True, ctx.handle_idx), None, None
 
 
 class DifferentiableAllReduceSum(torch.autograd.Function):
@@ -41,7 +50,7 @@ class DifferentiableAllReduceSum(torch.autograd.Function):
 
     @staticmethod
     def forward(
-        ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool
+        ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx: Optional[int] = None
     ) -> Tuple[torch.Tensor, Optional["dist.Work"]]:
         # ctx.mark_non_differentiable(async_all_reduce)
         ctx.async_all_reduce = async_all_reduce
@@ -63,13 +72,17 @@ def forward(
         if async_all_reduce:
             # AsyncCommBucket.add(tensor, handle)
             # AsyncCommBucket.add(id(tensor), handle)
-            AsyncCommBucket.add(orig_id, handle)
+            # try:
+            #     AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle)
+            # except Exception as e:
+            #     assert 1 == 1
+            AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle)
 
         return tensor
 
     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output, None, None
+        return grad_output, None, None, None
 
 
 class DifferentiableAllGather(torch.autograd.Function):
@@ -151,8 +164,8 @@ def backward(ctx, grad_output):
 # -----------------
 
 
-def differentiable_identity(tensor, group: Optional[ProcessGroup] = None):
-    return DifferentiableIdentity.apply(tensor, group)
+def differentiable_identity(tensor, group: Optional[ProcessGroup] = None, handle_idx=None):
+    return DifferentiableIdentity.apply(tensor, group, handle_idx)
 
 
 def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False):
diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index b05a50bf..ff43c98b 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -436,12 +436,13 @@ def column_linear(
     tp_mode: TensorParallelLinearMode,
     async_communication: bool,
     tp_recompute_allgather: bool = True,
+    handle_idx: Optional[int] = None,
 ):
     if async_communication:
         return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather)
 
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
-        input = differentiable_identity(input, group=group)
+        input = differentiable_identity(input, group=group, handle_idx=handle_idx)
         return F.linear(input, weight, bias)
     if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
         return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(
@@ -604,7 +605,8 @@ def row_linear(
         if async_all_reduce:
             from nanotron.parallel.comm import AsyncCommBucket
 
-            work = AsyncCommBucket.get(orig_out_id)
+            # work = AsyncCommBucket.get(orig_out_id)
+            work = AsyncCommBucket.pop(orig_out_id)
             assert 1 == 1
     elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
         assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode."
diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py
index 14a7486a..f4ceff63 100644
--- a/src/nanotron/parallel/tensor_parallel/nn.py
+++ b/src/nanotron/parallel/tensor_parallel/nn.py
@@ -52,6 +52,7 @@ def __init__(
         async_communication: bool = False,
         contiguous_chunks: Optional[Tuple[int, ...]] = None,
         tp_recompute_allgather: bool = True,
+        # handle_idx: Optional[int] = None,
     ):
         self.pg = pg
         self.world_size = pg.size()
@@ -72,6 +73,7 @@ def __init__(
 
         self.mode = mode
         self.async_communication = async_communication
+        # self.handle_idx = handle_idx
 
         if contiguous_chunks is not None:
             assert (
@@ -85,7 +87,7 @@ def __init__(
             split_config=split_config,
         )
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor:
         return column_linear(
             input=x,
             weight=self.weight,
@@ -94,6 +96,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             tp_mode=self.mode,
             async_communication=self.async_communication,
             tp_recompute_allgather=self.tp_recompute_allgather,
+            handle_idx=handle_idx,
         )
 
     def extra_repr(self) -> str:

From 9924608fccfddcf6bb87548610653407bc581ef5 Mon Sep 17 00:00:00 2001
From: Phuc Nguyen <b3f0cus@icloud.com>
Date: Mon, 3 Feb 2025 09:29:03 +0000
Subject: [PATCH 04/17] add backward pass overlapping

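Handle keys are now split per layer, per sub-module (attn vs mlp) and per
micro-batch, WaitComm is re-pointed at the mlp keys, and the trainer clears
whatever is left in the bucket at the end of each training step. Sketch of the
tagging, following the strings used in llama.py and trainer.py:

    attn_output0 = self.attn(hidden_states=hidden_states0,
                             sequence_mask=sequence_mask0,
                             handle_idx=f"layer_attn_{self.layer_idx}_batch_0")
    hidden_states0 = WaitComm.apply(hidden_states0, f"layer_mlp_{self.layer_idx}_batch_1")
    mlp_output0 = self.mlp(hidden_states=hidden_states0,
                           handle_idx=f"layer_mlp_{self.layer_idx}_batch_0")

    # in training_step(), once the step is done:
    AsyncCommBucket.clear_all()            # drop any handles that were never popped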
---
 src/nanotron/models/llama.py  | 21 +++++++++++++--------
 src/nanotron/parallel/comm.py | 12 +++++++++---
 src/nanotron/trainer.py       |  4 ++++
 3 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index fb112f8e..8dbcc661 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -764,12 +764,16 @@ def _core_forward(
         hidden_states1 = self.input_layernorm(hidden_states1)
 
         attn_output0 = self.attn(
-            hidden_states=hidden_states0, sequence_mask=sequence_mask0, handle_idx=f"layer_{self.layer_idx}_batch_0"
+            hidden_states=hidden_states0,
+            sequence_mask=sequence_mask0,
+            handle_idx=f"layer_attn_{self.layer_idx}_batch_0",
         )
         attn_output0_work = attn_output0["work"]
 
         attn_output1 = self.attn(
-            hidden_states=hidden_states1, sequence_mask=sequence_mask1, handle_idx=f"layer_{self.layer_idx}_batch_1"
+            hidden_states=hidden_states1,
+            sequence_mask=sequence_mask1,
+            handle_idx=f"layer_attn_{self.layer_idx}_batch_1",
         )
         attn_output1_work = attn_output1["work"]
 
@@ -785,10 +789,11 @@ def _core_forward(
         hidden_states0 = hidden_states0 + residual0
         residual0 = hidden_states0
         hidden_states0 = self.post_attention_layernorm(hidden_states0)
-        hidden_states0 = WaitComm.apply(hidden_states0, f"layer_{self.layer_idx}_batch_0")
+        hidden_states0 = WaitComm.apply(hidden_states0, f"layer_mlp_{self.layer_idx}_batch_1")
 
-        # mlp_output0 = self.mlp(hidden_states=hidden_states0, handle_idx=f"layer_{self.layer_idx}_batch_0")
-        mlp_output0 = self.mlp(hidden_states=hidden_states0)
+        mlp_output0 = self.mlp(hidden_states=hidden_states0, handle_idx=f"layer_mlp_{self.layer_idx}_batch_0")
+        mlp_output0 = WaitComm.apply(mlp_output0, f"layer_mlp_{self.layer_idx}_batch_1")
+        # mlp_output0 = self.mlp(hidden_states=hidden_states0)
 
         with torch.cuda.stream(comm_stream):
             attn_output1_work.wait()
@@ -798,10 +803,10 @@ def _core_forward(
         hidden_states1 = hidden_states1 + residual1
         residual1 = hidden_states1
         hidden_states1 = self.post_attention_layernorm(hidden_states1)
-        hidden_states1 = WaitComm.apply(hidden_states1, f"layer_{self.layer_idx}_batch_1")
+        # hidden_states1 = WaitComm.apply(hidden_states1, f"layer_{self.layer_idx}_batch_1")
 
-        # mlp_output1 = self.mlp(hidden_states=hidden_states1, handle_idx=f"layer_{self.layer_idx}_batch_1")
-        mlp_output1 = self.mlp(hidden_states=hidden_states1)
+        mlp_output1 = self.mlp(hidden_states=hidden_states1, handle_idx=f"layer_mlp_{self.layer_idx}_batch_1")
+        # mlp_output1 = self.mlp(hidden_states=hidden_states1)
 
         with torch.cuda.stream(comm_stream):
             mlp_output0["work"].wait()
diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py
index e26e2814..dd99f20a 100644
--- a/src/nanotron/parallel/comm.py
+++ b/src/nanotron/parallel/comm.py
@@ -54,6 +54,10 @@ def wait(tensor_id: int):
         work = AsyncCommBucket._async_op.pop(tensor_id)
         work.wait()
 
+    @staticmethod
+    def clear_all():
+        AsyncCommBucket._async_op.clear()
+
 
 class WaitComm(torch.autograd.Function):
     @staticmethod
@@ -63,10 +67,12 @@ def forward(ctx, input, wait_handle_idx):
 
     @staticmethod
     def backward(ctx, grad_output):
-        import pydevd
+        # import pydevd
 
-        pydevd.settrace(suspend=False, trace_only_current_thread=True)
-        if ctx.wait_handle_idx != "layer_1_batch_1":
+        # pydevd.settrace(suspend=False, trace_only_current_thread=True)
+        # if ctx.wait_handle_idx != "layer_1_batch_1":
+        if ctx.wait_handle_idx != "layer_30_batch_1":
             handle = AsyncCommBucket.pop(ctx.wait_handle_idx)
             handle.wait()
+            # assert 1 == 1
         return grad_output, None
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 94b03c6e..52d5df48 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -578,6 +578,10 @@ def training_step(
 
         self.post_train_step()
 
+        from nanotron.parallel.comm import AsyncCommBucket
+
+        AsyncCommBucket.clear_all()
+
         return outputs, loss_avg
 
     def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]:

From d6bc8da4a5df9f35b3345148639a4775b90de568 Mon Sep 17 00:00:00 2001
From: Phuc Nguyen <b3f0cus@icloud.com>
Date: Tue, 4 Feb 2025 14:32:38 +0000
Subject: [PATCH 05/17] fix some ops that don't execute in the bwd pass

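Handle keys now carry a fwd./bwd. prefix: DifferentiableIdentity.backward
rewrites fwd. keys to bwd. before registering the gradient's all-reduce, and
is_async_comm() in parallel/comm.py decides which backward reduces stay
asynchronous. The last backward handle of each phase is reduced synchronously,
presumably because there is nothing left to overlap it with. What the
classifier is meant to return for the new key scheme:

    is_async_comm("bwd.layer_mlp_3_batch_0")    # True  -> register handle, wait for it later
    is_async_comm("bwd.layer_mlp_3_batch_1")    # False -> plain synchronous all-reduce
    is_async_comm("bwd.layer_attn_3_batch_0")   # False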
---
 examples/config_tiny_llama_domino.yaml        | 113 ------------------
 src/nanotron/models/llama.py                  |  49 ++++++--
 src/nanotron/parallel/comm.py                 |  23 +++-
 .../distributed_differentiable_primitives.py  |  84 ++++++++-----
 .../parallel/tensor_parallel/functional.py    |  14 ++-
 src/nanotron/parallel/tensor_parallel/nn.py   |   6 +-
 6 files changed, 129 insertions(+), 160 deletions(-)
 delete mode 100644 examples/config_tiny_llama_domino.yaml

diff --git a/examples/config_tiny_llama_domino.yaml b/examples/config_tiny_llama_domino.yaml
deleted file mode 100644
index 66e22dbd..00000000
--- a/examples/config_tiny_llama_domino.yaml
+++ /dev/null
@@ -1,113 +0,0 @@
-checkpoints:
-  checkpoint_interval: 10
-  checkpoints_path: checkpoints
-  checkpoints_path_is_shared_file_system: false
-  resume_checkpoint_path: null
-  save_initial_state: false
-data_stages:
-- data:
-    dataset:
-      dataset_overwrite_cache: false
-      dataset_processing_num_proc_per_process: 1
-      hf_dataset_config_name: null
-      hf_dataset_or_datasets: stas/openwebtext-10k
-      hf_dataset_splits: train
-      text_column_name: text
-    num_loading_workers: 1
-    seed: 42
-  name: Stable Training Stage
-  start_training_step: 1
-- data:
-    dataset:
-      dataset_overwrite_cache: false
-      dataset_processing_num_proc_per_process: 1
-      hf_dataset_config_name: null
-      hf_dataset_or_datasets: stas/openwebtext-10k
-      hf_dataset_splits: train
-      text_column_name: text
-    num_loading_workers: 1
-    seed: 42
-  name: Annealing Phase
-  start_training_step: 10
-general:
-  benchmark_csv_path: null
-  consumed_train_samples: null
-  ignore_sanity_checks: true
-  project: debug
-  run: tiny_llama_%date_%jobid
-  seed: 42
-  step: null
-lighteval: null
-logging:
-  iteration_step_info_interval: 1
-  log_level: info
-  log_level_replica: info
-model:
-  ddp_bucket_cap_mb: 25
-  dtype: bfloat16
-  init_method:
-    std: 0.025
-  make_vocab_size_divisible_by: 1
-  model_config:
-    bos_token_id: 1
-    eos_token_id: 2
-    hidden_act: silu
-    hidden_size: 16
-    initializer_range: 0.02
-    intermediate_size: 64
-    is_llama_config: true
-    max_position_embeddings: 256
-    num_attention_heads: 4
-    num_hidden_layers: 2
-    num_key_value_heads: 4
-    pad_token_id: null
-    pretraining_tp: 1
-    rms_norm_eps: 1.0e-05
-    rope_scaling: null
-    tie_word_embeddings: true
-    use_cache: true
-    vocab_size: 256
-optimizer:
-  accumulate_grad_in_fp32: true
-  clip_grad: 1.0
-  learning_rate_scheduler:
-    learning_rate: 0.0003
-    lr_decay_starting_step: null
-    lr_decay_steps: 13
-    lr_decay_style: cosine
-    lr_warmup_steps: 2
-    lr_warmup_style: linear
-    min_decay_lr: 1.0e-05
-  optimizer_factory:
-    adam_beta1: 0.9
-    adam_beta2: 0.95
-    adam_eps: 1.0e-08
-    name: adamW
-    torch_adam_is_fused: true
-  weight_decay: 0.01
-  zero_stage: 0
-parallelism:
-  # dp: 2
-  # pp: 2
-  dp: 1
-  pp: 1
-  tp: 2
-  expert_parallel_size: 1
-  pp_engine: 1f1b
-  tp_linear_async_communication: false
-  tp_mode: ALL_REDUCE
-  domino:
-    num_input_batches: 2
-profiler: null
-tokenizer:
-  tokenizer_max_length: null
-  tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel
-  tokenizer_revision: null
-tokens:
-  batch_accumulation_per_replica: 1
-  limit_test_batches: 0
-  limit_val_batches: 0
-  micro_batch_size: 2
-  sequence_length: 256
-  train_steps: 15
-  val_check_interval: -1
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 8dbcc661..47394240 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -49,6 +49,11 @@
 
 DOMINO_COMM_STREAM = "domino_comm_stream_{}"
 
+FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}"
+BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}"
+FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}"
+BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}"
+
 
 class RotaryEmbedding(nn.Module):
     def __init__(self, dim: int, end: int, theta: float = 10000.0):
@@ -245,8 +250,8 @@ def __init__(
         self.split_silu_mul = GLUActivation(config.hidden_act)
 
     def forward(self, hidden_states, handle_idx=None):  # [seq_length, batch_size, hidden_dim]
-        merged_states = self.gate_up_proj(hidden_states, handle_idx)
-        hidden_states, work = self.down_proj(self.split_silu_mul(merged_states))
+        merged_states = self.gate_up_proj(hidden_states, async_all_reduce=True, handle_idx=handle_idx)
+        hidden_states, work = self.down_proj(self.split_silu_mul(merged_states), handle_idx)
         return {"hidden_states": hidden_states, "work": work}
 
 
@@ -449,7 +454,7 @@ def forward(
         )
 
         qkv_states = self.qkv_proj(
-            hidden_states, handle_idx=handle_idx
+            hidden_states, async_all_reduce=True, handle_idx=handle_idx
         )  # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk]
         q_length, batch_size, _ = qkv_states.shape
 
@@ -694,7 +699,7 @@ def forward(
         attention_output = (
             attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1)
         )
-        output, work = self.o_proj(attention_output)
+        output, work = self.o_proj(attention_output, handle_idx=handle_idx)
 
         return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask}
 
@@ -766,14 +771,26 @@ def _core_forward(
         attn_output0 = self.attn(
             hidden_states=hidden_states0,
             sequence_mask=sequence_mask0,
-            handle_idx=f"layer_attn_{self.layer_idx}_batch_0",
+            # handle_idx=f"fwd.layer_attn_{self.layer_idx}_batch_0",
+            handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0),
+        )
+        attn_output0["hidden_states"] = WaitComm.apply(
+            attn_output0["hidden_states"],
+            # f"bwd.layer_attn_{self.layer_idx}_batch_1"
+            BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1),
         )
         attn_output0_work = attn_output0["work"]
 
         attn_output1 = self.attn(
             hidden_states=hidden_states1,
             sequence_mask=sequence_mask1,
-            handle_idx=f"layer_attn_{self.layer_idx}_batch_1",
+            # handle_idx=f"fwd.layer_attn_{self.layer_idx}_batch_1",
+            handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1),
+        )
+        attn_output1["hidden_states"] = WaitComm.apply(
+            attn_output1["hidden_states"],
+            # f"bwd.layer_mlp_{self.layer_idx}_batch_0"
+            BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
         )
         attn_output1_work = attn_output1["work"]
 
@@ -789,10 +806,18 @@ def _core_forward(
         hidden_states0 = hidden_states0 + residual0
         residual0 = hidden_states0
         hidden_states0 = self.post_attention_layernorm(hidden_states0)
-        hidden_states0 = WaitComm.apply(hidden_states0, f"layer_mlp_{self.layer_idx}_batch_1")
+        # hidden_states0 = WaitComm.apply(hidden_states0, f"bwd.layer_mlp_{self.layer_idx}_batch_0")
 
-        mlp_output0 = self.mlp(hidden_states=hidden_states0, handle_idx=f"layer_mlp_{self.layer_idx}_batch_0")
-        mlp_output0 = WaitComm.apply(mlp_output0, f"layer_mlp_{self.layer_idx}_batch_1")
+        mlp_output0 = self.mlp(
+            hidden_states=hidden_states0,
+            # handle_idx=f"fwd.layer_mlp_{self.layer_idx}_batch_0"
+            handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
+        )
+        mlp_output0["hidden_states"] = WaitComm.apply(
+            mlp_output0["hidden_states"],
+            # f"bwd.layer_mlp_{self.layer_idx}_batch_1"
+            BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1),
+        )
         # mlp_output0 = self.mlp(hidden_states=hidden_states0)
 
         with torch.cuda.stream(comm_stream):
@@ -805,7 +830,11 @@ def _core_forward(
         hidden_states1 = self.post_attention_layernorm(hidden_states1)
         # hidden_states1 = WaitComm.apply(hidden_states1, f"layer_{self.layer_idx}_batch_1")
 
-        mlp_output1 = self.mlp(hidden_states=hidden_states1, handle_idx=f"layer_mlp_{self.layer_idx}_batch_1")
+        mlp_output1 = self.mlp(
+            hidden_states=hidden_states1,
+            # handle_idx=f"fwd.layer_mlp_{self.layer_idx}_batch_1"
+            handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1),
+        )
         # mlp_output1 = self.mlp(hidden_states=hidden_states1)
 
         with torch.cuda.stream(comm_stream):
diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py
index dd99f20a..c2b18e05 100644
--- a/src/nanotron/parallel/comm.py
+++ b/src/nanotron/parallel/comm.py
@@ -47,6 +47,7 @@ def get(tensor_id: int):
 
     @staticmethod
     def pop(tensor_id: int):
+        assert tensor_id in AsyncCommBucket._async_op, f"tensor_id: {tensor_id}"
         return AsyncCommBucket._async_op.pop(tensor_id)
 
     @staticmethod
@@ -59,6 +60,17 @@ def clear_all():
         AsyncCommBucket._async_op.clear()
 
 
+def is_async_comm(x):
+    import re
+
+    NON_ASYNC_HANDLE_IDX = ["bwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0"]
+
+    patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX]  # Replace {} with regex for numbers
+    regex = re.compile("^(" + "|".join(patterns) + ")$")  # Combine patterns into a single regex
+    not_async = bool(regex.match(x))
+    return not not_async
+
+
 class WaitComm(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, wait_handle_idx):
@@ -68,11 +80,16 @@ def forward(ctx, input, wait_handle_idx):
     @staticmethod
     def backward(ctx, grad_output):
         # import pydevd
-
         # pydevd.settrace(suspend=False, trace_only_current_thread=True)
-        # if ctx.wait_handle_idx != "layer_1_batch_1":
-        if ctx.wait_handle_idx != "layer_30_batch_1":
+
+        if "bwd.layer_mlp_1_batch_0" == ctx.wait_handle_idx:
+            assert 1 == 1
+
+        # if ctx.wait_handle_idx != "bwd.layer_mlp_1_batch_1":
+        # if ctx.wait_handle_idx != "layer_30_batch_1":
+        if is_async_comm(ctx.wait_handle_idx):
             handle = AsyncCommBucket.pop(ctx.wait_handle_idx)
+            assert handle is not None
             handle.wait()
             # assert 1 == 1
         return grad_output, None
diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index 38c6bafd..3badfb34 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -26,8 +26,9 @@ class DifferentiableIdentity(torch.autograd.Function):
     """All-reduce gradients in a differentiable fashion"""
 
     @staticmethod
-    def forward(ctx, tensor, group: Optional[ProcessGroup], handle_idx=None):
+    def forward(ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx=None):
         # assert handle_idx is not None
+        ctx.async_all_reduce = async_all_reduce
         ctx.handle_idx = handle_idx
         ctx.group = group
         return tensor
@@ -39,10 +40,31 @@ def backward(ctx, grad_output):
         # NOTE: lm_head is TensorParallelColumnLinear, and it doesn't do async
         # assert ctx.handle_idx is not None
         group = ctx.group
-        if ctx.handle_idx is not None:
-            assert 1 == 1
 
-        return DifferentiableAllReduceSum.apply(grad_output, group, True, ctx.handle_idx), None, None
+        if ctx.handle_idx is not None and "fwd." in ctx.handle_idx:
+            handle_idx = ctx.handle_idx.replace("fwd.", "bwd.")
+            # if "bwd.layer_mlp_1_batch_1" == handle_idx:
+            #     from nanotron.parallel.comm import is_async_comm
+            #     async_all_reduce = is_async_comm(handle_idx)
+            # else:
+            #     async_all_reduce = ctx.async_all_reduce
+            from nanotron.parallel.comm import is_async_comm
+
+            async_all_reduce = is_async_comm(handle_idx)
+        else:
+            handle_idx = ctx.handle_idx
+            async_all_reduce = ctx.async_all_reduce
+
+        return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx), None, None, None
+
+
+def is_last_batch_of_attn(x):
+    import re
+
+    pattern = r"layer_attn_\d+_batch_0"
+    if re.match(pattern, x):
+        return True
+    return False
 
 
 class DifferentiableAllReduceSum(torch.autograd.Function):
@@ -52,31 +74,33 @@ class DifferentiableAllReduceSum(torch.autograd.Function):
     def forward(
         ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx: Optional[int] = None
     ) -> Tuple[torch.Tensor, Optional["dist.Work"]]:
-        # ctx.mark_non_differentiable(async_all_reduce)
         ctx.async_all_reduce = async_all_reduce
 
         if group.size() == 1:
             return tensor
 
-        orig_id = id(tensor)
-        handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_all_reduce)
-        # if async_all_reduce:
-        #     handle.wait()
-        new_id = id(tensor)
-        assert 1 == 1
-        assert orig_id == new_id
-        # if async_all_reduce:
-        #     return tensor, handle
-        # else:
-        #     return tensor, None
-        if async_all_reduce:
-            # AsyncCommBucket.add(tensor, handle)
-            # AsyncCommBucket.add(id(tensor), handle)
-            # try:
-            #     AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle)
-            # except Exception as e:
-            #     assert 1 == 1
-            AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle)
+        if handle_idx == "bwd.layer_mlp_1_batch_0":
+            assert 1 == 1
+
+        id(tensor)
+        if async_all_reduce is True:
+            if isinstance(handle_idx, str):
+                do_async = is_last_batch_of_attn(handle_idx) is False
+            else:
+                do_async = async_all_reduce
+
+            handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async)
+            if do_async:
+                # # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx
+                # if handle_idx is not None and "bwd." in handle_idx:
+                #     AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle)
+                # else:
+                #     AsyncCommBucket.add(orig_id, handle)
+                # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx
+                assert handle_idx is not None
+                AsyncCommBucket.add(handle_idx, handle)
+        else:
+            dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
 
         return tensor
 
@@ -164,12 +188,16 @@ def backward(ctx, grad_output):
 # -----------------
 
 
-def differentiable_identity(tensor, group: Optional[ProcessGroup] = None, handle_idx=None):
-    return DifferentiableIdentity.apply(tensor, group, handle_idx)
+def differentiable_identity(
+    tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, handle_idx=None
+):
+    return DifferentiableIdentity.apply(tensor, group, async_all_reduce, handle_idx)
 
 
-def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False):
-    return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce)
+def differentiable_all_reduce_sum(
+    tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, handle_idx=None
+):
+    return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce, handle_idx)
 
 
 def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None):
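
The handle naming scheme the backward pass relies on is plain string surgery: a collective tagged fwd.layer_<op>_<idx>_batch_<i> in the forward produces a gradient all-reduce tagged bwd.layer_<op>_<idx>_batch_<i>, and only that derived name decides whether the gradient all-reduce may run asynchronously. A small self-contained sketch, with the blocklist mirroring is_async_comm in comm.py as of this patch (bwd_name is an illustrative helper, not nanotron code):

import re

# Backward ops that must stay synchronous, mirroring NON_ASYNC_HANDLE_IDX in comm.py
# at this point in the series; everything else is treated as overlappable.
_NON_ASYNC = [r"bwd\.layer_mlp_\d+_batch_1", r"bwd\.layer_attn_\d+_batch_0"]
_NON_ASYNC_RE = re.compile("^(" + "|".join(_NON_ASYNC) + ")$")

def is_async_comm(name: str) -> bool:
    return not bool(_NON_ASYNC_RE.match(name))

def bwd_name(fwd_name: str) -> str:
    # DifferentiableIdentity.backward derives the backward key from the forward tag.
    return fwd_name.replace("fwd.", "bwd.")

for fwd in ("fwd.layer_attn_3_batch_0", "fwd.layer_attn_3_batch_1",
            "fwd.layer_mlp_3_batch_0", "fwd.layer_mlp_3_batch_1"):
    tag = bwd_name(fwd)
    print(f"{fwd} -> {tag}: async={is_async_comm(tag)}")
# Only bwd.layer_attn_3_batch_0 and bwd.layer_mlp_3_batch_1 come out synchronous,
# matching the two blocked patterns in NON_ASYNC_HANDLE_IDX.
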
diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index ff43c98b..a3f1248e 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -436,13 +436,14 @@ def column_linear(
     tp_mode: TensorParallelLinearMode,
     async_communication: bool,
     tp_recompute_allgather: bool = True,
+    async_all_reduce: bool = False,
     handle_idx: Optional[int] = None,
 ):
     if async_communication:
         return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather)
 
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
-        input = differentiable_identity(input, group=group, handle_idx=handle_idx)
+        input = differentiable_identity(input, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx)
         return F.linear(input, weight, bias)
     if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
         return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(
@@ -591,6 +592,7 @@ def row_linear(
     # TODO(xrsrke): use less confusing names for these arguments
     async_communication: bool,
     async_all_reduce: bool,
+    handle_idx=None,
 ) -> Tuple[torch.Tensor, Optional[torch.Future]]:
     if async_communication:
         return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)
@@ -599,14 +601,18 @@ def row_linear(
 
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
         # out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce)
-        orig_out_id = id(out)
+        id(out)
         # NOTE: why the id(out) doesn't match the id(out) before the all_reduce?
-        out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce)
+        out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx)
         if async_all_reduce:
             from nanotron.parallel.comm import AsyncCommBucket
 
             # work = AsyncCommBucket.get(orig_out_id)
-            work = AsyncCommBucket.pop(orig_out_id)
+            # work = AsyncCommBucket.pop(orig_out_id)
+            if handle_idx == "fwd.layer_mlp_1_batch_0":
+                assert 1 == 1
+
+            work = AsyncCommBucket.pop(handle_idx)
             assert 1 == 1
     elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
         assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode."
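
For context, the tail of row_linear now both launches the all-reduce and hands the async handle back to the caller so other compute can run before the wait. A hedged sketch of that control flow, assuming an already-initialized torch.distributed process group (row_parallel_tail is an illustrative name, not a nanotron API):

import torch
import torch.distributed as dist

def row_parallel_tail(partial_out: torch.Tensor, group, op_name: str, async_all_reduce: bool):
    # Reduce the partial outputs across tensor-parallel ranks. In the patch the handle is
    # stashed in AsyncCommBucket under op_name and popped right away inside row_linear;
    # returning it directly is equivalent for this sketch.
    if async_all_reduce:
        work = dist.all_reduce(partial_out, op=dist.ReduceOp.SUM, group=group, async_op=True)
        return partial_out, work
    dist.all_reduce(partial_out, op=dist.ReduceOp.SUM, group=group)
    return partial_out, None

# Usage (under torchrun, with an initialized process group `tp_group`):
#   out, work = row_parallel_tail(out, tp_group, "fwd.layer_mlp_0_batch_0", async_all_reduce=True)
#   ... launch the other micro-batch's compute here ...
#   work.wait()
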
diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py
index f4ceff63..847454fd 100644
--- a/src/nanotron/parallel/tensor_parallel/nn.py
+++ b/src/nanotron/parallel/tensor_parallel/nn.py
@@ -87,7 +87,7 @@ def __init__(
             split_config=split_config,
         )
 
-    def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, async_all_reduce=None, handle_idx=None) -> torch.Tensor:
         return column_linear(
             input=x,
             weight=self.weight,
@@ -96,6 +96,7 @@ def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor:
             tp_mode=self.mode,
             async_communication=self.async_communication,
             tp_recompute_allgather=self.tp_recompute_allgather,
+            async_all_reduce=async_all_reduce,
             handle_idx=handle_idx,
         )
 
@@ -163,7 +164,7 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig):
                 )
             setattr(self, name, new_param)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor:
         return row_linear(
             input=x,
             weight=self.weight,
@@ -172,6 +173,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             tp_mode=self.mode,
             async_communication=self.async_communication,
             async_all_reduce=self.async_all_reduce,
+            handle_idx=handle_idx,
         )
 
     def extra_repr(self) -> str:

From 93b2f106bb91e4f9328546d77dfc169348eb12b9 Mon Sep 17 00:00:00 2001
From: Phuc Nguyen <b3f0cus@icloud.com>
Date: Wed, 5 Feb 2025 10:49:03 +0000
Subject: [PATCH 06/17] fix: can't find an op in the fwd pass


---
 src/nanotron/parallel/comm.py                         |  8 +++++++-
 .../distributed_differentiable_primitives.py          | 11 +++++++----
 src/nanotron/parallel/tensor_parallel/functional.py   |  6 +++++-
 3 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py
index c2b18e05..459ea23d 100644
--- a/src/nanotron/parallel/comm.py
+++ b/src/nanotron/parallel/comm.py
@@ -63,7 +63,13 @@ def clear_all():
 def is_async_comm(x):
     import re
 
-    NON_ASYNC_HANDLE_IDX = ["bwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0"]
+    NON_ASYNC_HANDLE_IDX = [
+        # "fwd.layer_attn_{}_batch_0",
+        # "fwd.layer_mlp_{}_batch_0",
+        # "fwd.layer_mlp_{}_batch_1",
+        "bwd.layer_mlp_{}_batch_1",
+        "bwd.layer_attn_{}_batch_0",
+    ]
 
     patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX]  # Replace {} with regex for numbers
     regex = re.compile("^(" + "|".join(patterns) + ")$")  # Combine patterns into a single regex
diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index 3badfb34..9d65878b 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -84,10 +84,13 @@ def forward(
 
         id(tensor)
         if async_all_reduce is True:
-            if isinstance(handle_idx, str):
-                do_async = is_last_batch_of_attn(handle_idx) is False
-            else:
-                do_async = async_all_reduce
+            # if isinstance(handle_idx, str):
+            #     do_async = is_last_batch_of_attn(handle_idx) is False
+            # else:
+            #     do_async = async_all_reduce
+            from nanotron.parallel.comm import is_async_comm
+
+            do_async = is_async_comm(handle_idx)
 
             handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async)
             if do_async:
diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index a3f1248e..f0ca3a0d 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -603,13 +603,17 @@ def row_linear(
         # out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce)
         id(out)
         # NOTE: why the id(out) doesn't match the id(out) before the all_reduce?
+        if handle_idx == "fwd.layer_attn_0_batch_0":
+            assert 1 == 1
+
         out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx)
         if async_all_reduce:
             from nanotron.parallel.comm import AsyncCommBucket
 
             # work = AsyncCommBucket.get(orig_out_id)
             # work = AsyncCommBucket.pop(orig_out_id)
-            if handle_idx == "fwd.layer_mlp_1_batch_0":
+            # if handle_idx == "fwd.layer_mlp_1_batch_0":
+            if handle_idx == "fwd.layer_attn_0_batch_0":
                 assert 1 == 1
 
             work = AsyncCommBucket.pop(handle_idx)

From 31db05dafaf3190c1e3c33b4eb0b87cb9e5f0f04 Mon Sep 17 00:00:00 2001
From: Phuc Nguyen <b3f0cus@icloud.com>
Date: Wed, 5 Feb 2025 16:49:44 +0000
Subject: [PATCH 07/17] partially overlapping bwd pass

---
 src/nanotron/models/llama.py                  | 166 ++++++++++++------
 src/nanotron/parallel/comm.py                 |   5 +-
 .../distributed_differentiable_primitives.py  |   6 +
 src/nanotron/trainer.py                       |   3 +-
 4 files changed, 123 insertions(+), 57 deletions(-)

diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 47394240..acbece96 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -50,9 +50,9 @@
 DOMINO_COMM_STREAM = "domino_comm_stream_{}"
 
 FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}"
-BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}"
 FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}"
 BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}"
+BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}"
 
 
 class RotaryEmbedding(nn.Module):
@@ -741,116 +741,176 @@ def __init__(
 
         self.layer_idx = layer_idx
 
+    # def _core_forward(
+    #     self,
+    #     hidden_states: Union[torch.Tensor, TensorPointer],
+    #     sequence_mask: Union[torch.Tensor, TensorPointer],
+    # ) -> List[Union[torch.Tensor, TensorPointer]]:
+    #     from nanotron import constants
+
+    #     num_input_batches = self.parallel_config.domino.num_input_batches
+    #     orig_sequence_mask = sequence_mask
+
+    #     assert num_input_batches == 2
+    #     hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1)
+    #     sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0)
+
+    #     hidden_states0, hidden_states1 = hidden_states
+    #     sequence_mask0, sequence_mask1 = sequence_mask
+
+    #     residual0 = hidden_states0
+    #     residual1 = hidden_states1
+
+    #     hidden_states0 = self.input_layernorm(hidden_states0)
+    #     hidden_states1 = self.input_layernorm(hidden_states1)
+
+    #     attn_output0 = self.attn(
+    #         hidden_states=hidden_states0,
+    #         sequence_mask=sequence_mask0,
+    #         handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0),
+    #     )
+    #     # attn_output0["hidden_states"] = WaitComm.apply(
+    #     #     attn_output0["hidden_states"],
+    #     #     BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1),
+    #     # )
+
+    #     attn_output1 = self.attn(
+    #         hidden_states=hidden_states1,
+    #         sequence_mask=sequence_mask1,
+    #         handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1),
+    #     )
+    #     # attn_output1["hidden_states"] = WaitComm.apply(
+    #     #     attn_output1["hidden_states"],
+    #     #     BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
+    #     # )
+
+    #     comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()]
+    #     with torch.cuda.stream(comm_stream):
+    #         attn_output0["work"].wait()
+
+    #     hidden_states0 = attn_output0["hidden_states"] + residual0
+    #     residual0 = hidden_states0
+    #     hidden_states0 = self.post_attention_layernorm(hidden_states0)
+    #     hidden_states0 = WaitComm.apply(
+    #         hidden_states0,
+    #         BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1),
+    #     ) # new
+
+    #     mlp_output0 = self.mlp(
+    #         hidden_states=hidden_states0,
+    #         handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
+    #     )
+    #     # mlp_output0["hidden_states"] = WaitComm.apply(
+    #     #     mlp_output0["hidden_states"],
+    #     #     BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1),
+    #     # )
+
+    #     with torch.cuda.stream(comm_stream):
+    #         attn_output1["work"].wait()
+
+    #     hidden_states1 = attn_output1["hidden_states"] + residual1
+    #     residual1 = hidden_states1
+    #     hidden_states1 = self.post_attention_layernorm(hidden_states1)
+    #     hidden_states1 = WaitComm.apply(
+    #         hidden_states1,
+    #         BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
+    #     )
+
+    #     mlp_output1 = self.mlp(
+    #         hidden_states=hidden_states1,
+    #         handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1),
+    #     )
+
+    #     with torch.cuda.stream(comm_stream):
+    #         mlp_output0["work"].wait()
+    #         mlp_output1["work"].wait()
+
+    #     hidden_states0 = mlp_output0["hidden_states"] + residual0
+    #     hidden_states1 = mlp_output1["hidden_states"] + residual1
+
+    #     hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1)
+    #     return hidden_states, orig_sequence_mask
+
     def _core_forward(
         self,
         hidden_states: Union[torch.Tensor, TensorPointer],
         sequence_mask: Union[torch.Tensor, TensorPointer],
     ) -> List[Union[torch.Tensor, TensorPointer]]:
+        from nanotron import constants
 
         num_input_batches = self.parallel_config.domino.num_input_batches
+        orig_sequence_mask = sequence_mask
+
         assert num_input_batches == 2
         hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1)
-        orig_sequence_mask = sequence_mask
         sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0)
 
         hidden_states0, hidden_states1 = hidden_states
         sequence_mask0, sequence_mask1 = sequence_mask
 
-        # # Combine the chunks into a list of dictionaries
-        # hidden_encoder_states_list = [
-        #     {"hidden_states": hidden_encoder_states["hidden_states"][i], "sequence_mask": hidden_encoder_states["sequence_mask"][i]}
-        #     for i in range(num_input_batches)
-        # ]
-
         residual0 = hidden_states0
         residual1 = hidden_states1
 
         hidden_states0 = self.input_layernorm(hidden_states0)
         hidden_states1 = self.input_layernorm(hidden_states1)
+        hidden_states0 = WaitComm.apply(
+            hidden_states0,
+            BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1),
+        )
+        hidden_states1 = WaitComm.apply(
+            hidden_states1,
+            BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
+        )
 
         attn_output0 = self.attn(
             hidden_states=hidden_states0,
             sequence_mask=sequence_mask0,
-            # handle_idx=f"fwd.layer_attn_{self.layer_idx}_batch_0",
             handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0),
         )
-        attn_output0["hidden_states"] = WaitComm.apply(
-            attn_output0["hidden_states"],
-            # f"bwd.layer_attn_{self.layer_idx}_batch_1"
-            BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1),
-        )
-        attn_output0_work = attn_output0["work"]
-
         attn_output1 = self.attn(
             hidden_states=hidden_states1,
             sequence_mask=sequence_mask1,
-            # handle_idx=f"fwd.layer_attn_{self.layer_idx}_batch_1",
             handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1),
         )
-        attn_output1["hidden_states"] = WaitComm.apply(
-            attn_output1["hidden_states"],
-            # f"bwd.layer_mlp_{self.layer_idx}_batch_0"
-            BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
-        )
-        attn_output1_work = attn_output1["work"]
-
-        from nanotron import constants
 
         comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()]
-        # comm_stream = CudaStreamManager.get(DOMINO_COMM_STREAM.format(torch.cuda.current_device()))
         with torch.cuda.stream(comm_stream):
-            attn_output0_work.wait()
-        # attn_output0_work.wait()
+            attn_output0["work"].wait()
 
-        hidden_states0 = attn_output0["hidden_states"]
-        hidden_states0 = hidden_states0 + residual0
+        hidden_states0 = attn_output0["hidden_states"] + residual0
         residual0 = hidden_states0
         hidden_states0 = self.post_attention_layernorm(hidden_states0)
-        # hidden_states0 = WaitComm.apply(hidden_states0, f"bwd.layer_mlp_{self.layer_idx}_batch_0")
+        hidden_states0 = WaitComm.apply(
+            hidden_states0,
+            BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1),
+        )  # new
 
         mlp_output0 = self.mlp(
             hidden_states=hidden_states0,
-            # handle_idx=f"fwd.layer_mlp_{self.layer_idx}_batch_0"
             handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
         )
-        mlp_output0["hidden_states"] = WaitComm.apply(
-            mlp_output0["hidden_states"],
-            # f"bwd.layer_mlp_{self.layer_idx}_batch_1"
-            BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1),
-        )
-        # mlp_output0 = self.mlp(hidden_states=hidden_states0)
 
         with torch.cuda.stream(comm_stream):
-            attn_output1_work.wait()
-        # attn_output1_work.wait()
+            attn_output1["work"].wait()
 
-        hidden_states1 = attn_output1["hidden_states"]
-        hidden_states1 = hidden_states1 + residual1
+        hidden_states1 = attn_output1["hidden_states"] + residual1
         residual1 = hidden_states1
         hidden_states1 = self.post_attention_layernorm(hidden_states1)
-        # hidden_states1 = WaitComm.apply(hidden_states1, f"layer_{self.layer_idx}_batch_1")
 
         mlp_output1 = self.mlp(
             hidden_states=hidden_states1,
-            # handle_idx=f"fwd.layer_mlp_{self.layer_idx}_batch_1"
             handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1),
         )
-        # mlp_output1 = self.mlp(hidden_states=hidden_states1)
 
         with torch.cuda.stream(comm_stream):
             mlp_output0["work"].wait()
             mlp_output1["work"].wait()
 
-        # mlp_output0["work"].wait()
-        # mlp_output1["work"].wait()
-
-        hidden_states0 = mlp_output0["hidden_states"]
-        hidden_states1 = mlp_output1["hidden_states"]
-
-        hidden_states0 = hidden_states0 + residual0
-        hidden_states1 = hidden_states1 + residual1
+        hidden_states0 = mlp_output0["hidden_states"] + residual0
+        hidden_states1 = mlp_output1["hidden_states"] + residual1
 
         hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1)
+        assert 1 == 1
         return hidden_states, orig_sequence_mask
 
     def _checkpointed_forward(
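
To make the schedule in the rewritten _core_forward easier to follow, here is a stripped-down, single-process sketch of the same interleaving with no tensor parallelism, no CUDA streams, and a single norm standing in for both layernorms (domino_schedule is purely illustrative): batch 0's MLP is issued before batch 1's attention output is consumed, which is exactly the window where batch 1's attention all-reduce is meant to hide.

import torch
import torch.nn as nn

def domino_schedule(x, attn: nn.Module, mlp: nn.Module, norm: nn.Module):
    # Split along the batch dimension (dim=1 in nanotron's [seq, batch, hidden] layout).
    x0, x1 = torch.chunk(x, 2, dim=1)
    a0 = attn(norm(x0))   # TP version: also kicks off batch 0's async all-reduce
    a1 = attn(norm(x1))   # ...and batch 1's, before either result is consumed
    h0 = a0 + x0          # real code first waits on batch 0's handle (on the comm stream)
    m0 = mlp(norm(h0))    # this matmul overlaps with batch 1's attention all-reduce
    h1 = a1 + x1          # only now is batch 1's attention output consumed
    m1 = mlp(norm(h1))
    return torch.cat([m0 + h0, m1 + h1], dim=1)

x = torch.randn(8, 4, 16)  # [seq, batch, hidden]
out = domino_schedule(x, nn.Linear(16, 16), nn.Linear(16, 16), nn.LayerNorm(16))
assert out.shape == x.shape
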
diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py
index 459ea23d..b00f6e9e 100644
--- a/src/nanotron/parallel/comm.py
+++ b/src/nanotron/parallel/comm.py
@@ -91,11 +91,10 @@ def backward(ctx, grad_output):
         if "bwd.layer_mlp_1_batch_0" == ctx.wait_handle_idx:
             assert 1 == 1
 
-        # if ctx.wait_handle_idx != "bwd.layer_mlp_1_batch_1":
-        # if ctx.wait_handle_idx != "layer_30_batch_1":
         if is_async_comm(ctx.wait_handle_idx):
             handle = AsyncCommBucket.pop(ctx.wait_handle_idx)
             assert handle is not None
             handle.wait()
-            # assert 1 == 1
+            # assert handle.is_completed() is True
+
         return grad_output, None
diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index 9d65878b..c4f69c05 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -55,6 +55,9 @@ def backward(ctx, grad_output):
             handle_idx = ctx.handle_idx
             async_all_reduce = ctx.async_all_reduce
 
+        if handle_idx is not None and "bwd." in handle_idx and async_all_reduce is True:
+            assert 1 == 1
+
         return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx), None, None, None
 
 
@@ -94,6 +97,9 @@ def forward(
 
             handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async)
             if do_async:
+                if "bwd" in handle_idx:
+                    assert 1 == 1
+
                 # # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx
                 # if handle_idx is not None and "bwd." in handle_idx:
                 #     AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle)
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 52d5df48..96af6b12 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -580,7 +580,8 @@ def training_step(
 
         from nanotron.parallel.comm import AsyncCommBucket
 
-        AsyncCommBucket.clear_all()
+        assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}"
+        # AsyncCommBucket.clear_all()
 
         return outputs, loss_avg
 

From 23f210815e06147cf95fff3fede4641cc1c43101 Mon Sep 17 00:00:00 2001
From: Phuc Nguyen <b3f0cus@icloud.com>
Date: Mon, 10 Feb 2025 16:33:23 +0000
Subject: [PATCH 08/17] fix: stream not synchronized

---
 src/nanotron/constants.py                     |   4 +
 src/nanotron/models/llama.py                  |  46 ++++----
 src/nanotron/parallel/comm.py                 |  15 ++-
 src/nanotron/parallel/dependency.py           | 102 ++++++++++++++++++
 .../distributed_differentiable_primitives.py  |   4 +
 src/nanotron/parallel/tensor_parallel/nn.py   |   8 +-
 src/nanotron/trainer.py                       |  17 ++-
 7 files changed, 172 insertions(+), 24 deletions(-)
 create mode 100644 src/nanotron/parallel/dependency.py

diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py
index 78fd0bb9..3fe440a8 100644
--- a/src/nanotron/constants.py
+++ b/src/nanotron/constants.py
@@ -13,3 +13,7 @@
 
 
 CUDA_STREAMS = {}
+
+CLOCK = 0
+_AUTOGRAD_RUNS = []
+_NOT_BWD_ASYNC_OPS = []
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index acbece96..72ebf478 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -245,13 +245,15 @@ def __init__(
             mode=tp_mode,
             bias=False,
             async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
-            async_all_reduce=parallel_config.domino.num_input_batches > 1,
+            # async_all_reduce=parallel_config.domino.num_input_batches > 1,
         )
         self.split_silu_mul = GLUActivation(config.hidden_act)
 
     def forward(self, hidden_states, handle_idx=None):  # [seq_length, batch_size, hidden_dim]
         merged_states = self.gate_up_proj(hidden_states, async_all_reduce=True, handle_idx=handle_idx)
-        hidden_states, work = self.down_proj(self.split_silu_mul(merged_states), handle_idx)
+        hidden_states, work = self.down_proj(
+            self.split_silu_mul(merged_states), async_all_reduce=True, handle_idx=handle_idx
+        )
         return {"hidden_states": hidden_states, "work": work}
 
 
@@ -428,7 +430,7 @@ def __init__(
             mode=tp_mode,
             bias=False,
             async_communication=tp_linear_async_communication,
-            async_all_reduce=async_all_reduce,
+            # async_all_reduce=async_all_reduce,
         )
 
         self.attention = CoreAttention(
@@ -699,7 +701,7 @@ def forward(
         attention_output = (
             attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1)
         )
-        output, work = self.o_proj(attention_output, handle_idx=handle_idx)
+        output, work = self.o_proj(attention_output, async_all_reduce=True, handle_idx=handle_idx)
 
         return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask}
 
@@ -876,6 +878,7 @@ def _core_forward(
         comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()]
         with torch.cuda.stream(comm_stream):
             attn_output0["work"].wait()
+            attn_output0["work"].is_completed()
 
         hidden_states0 = attn_output0["hidden_states"] + residual0
         residual0 = hidden_states0
@@ -890,8 +893,16 @@ def _core_forward(
             handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
         )
 
+        # attn_output1["hidden_states"], mlp_output0["hidden_states"] = depend(
+        #     run_after=attn_output1["hidden_states"],
+        #     run_before=mlp_output0["hidden_states"]
+        # )
+
         with torch.cuda.stream(comm_stream):
             attn_output1["work"].wait()
+            attn_output1["work"].is_completed()
+
+        torch.cuda.current_stream().wait_stream(comm_stream)
 
         hidden_states1 = attn_output1["hidden_states"] + residual1
         residual1 = hidden_states1
@@ -906,11 +917,24 @@ def _core_forward(
             mlp_output0["work"].wait()
             mlp_output1["work"].wait()
 
+            mlp_output0["work"].is_completed()
+            mlp_output1["work"].is_completed()
+
+        torch.cuda.current_stream().wait_stream(comm_stream)
+
         hidden_states0 = mlp_output0["hidden_states"] + residual0
         hidden_states1 = mlp_output1["hidden_states"] + residual1
 
+        # hidden_states0, hidden_states1 = depend(run_after=hidden_states0, run_before=hidden_states1)
+
         hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1)
         assert 1 == 1
+
+        # assert attn_output0["work"].is_completed()
+        # assert attn_output1["work"].is_completed()
+        # assert mlp_output0["work"].is_completed()
+        # assert mlp_output1["work"].is_completed()
+
         return hidden_states, orig_sequence_mask
 
     def _checkpointed_forward(
@@ -1080,23 +1104,9 @@ def forward_with_hidden_states(
             "sequence_mask": input_mask,
         }
 
-        # assert 1 == 1
-        # num_input_batches = self.parallel_config.domino.num_input_batches
-        # hidden_encoder_states["hidden_states"] = torch.chunk(hidden_encoder_states["hidden_states"], chunks=num_input_batches, dim=1)
-        # hidden_encoder_states["sequence_mask"] = torch.chunk(hidden_encoder_states["sequence_mask"], chunks=num_input_batches, dim=0)
-
-        # # Combine the chunks into a list of dictionaries
-        # hidden_encoder_states_list = [
-        #     {"hidden_states": hidden_encoder_states["hidden_states"][i], "sequence_mask": hidden_encoder_states["sequence_mask"][i]}
-        #     for i in range(num_input_batches)
-        # ]
-
         for encoder_block in self.decoder:
             hidden_encoder_states = encoder_block(**hidden_encoder_states)
 
-            # for hidden_encoder_states in hidden_encoder_states_list:
-            #     hidden_encoder_states = encoder_block(**hidden_encoder_states)
-
         hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]
 
         sharded_logits = self.lm_head(x=hidden_states)["logits"]
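
The torch.cuda.current_stream().wait_stream(comm_stream) calls added in this hunk are the standard way to order side-stream work before the compute stream without blocking the host. A minimal sketch of the pattern, with a simple in-place op standing in for the all-reduce wait so it runs on any single GPU (and is a no-op without one):

import torch

if torch.cuda.is_available():
    comm_stream = torch.cuda.Stream()
    x = torch.ones(1 << 20, device="cuda")

    # Side-stream "communication": in the patch this is handle.wait() for the async
    # all-reduce; an in-place multiply stands in for it here.
    comm_stream.wait_stream(torch.cuda.current_stream())  # comm must see x's producer
    with torch.cuda.stream(comm_stream):
        x.mul_(2)
    x.record_stream(comm_stream)  # keep the allocator from reusing x while comm_stream uses it

    # The compute stream must not consume x before the side-stream work is ordered first.
    torch.cuda.current_stream().wait_stream(comm_stream)
    y = x + 1  # safe: ordered after the side-stream op
    torch.cuda.synchronize()
    assert torch.allclose(y, torch.full_like(y, 3.0))
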
diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py
index b00f6e9e..789416c3 100644
--- a/src/nanotron/parallel/comm.py
+++ b/src/nanotron/parallel/comm.py
@@ -33,6 +33,7 @@ class AsyncCommBucket:
     """
 
     _async_op: Dict[int, "dist.Work"] = {}
+    _copy_async_op: Dict[int, "dist.Work"] = {}
 
     @staticmethod
     def add(tensor_id: int, work: "dist.Work"):
@@ -40,6 +41,7 @@ def add(tensor_id: int, work: "dist.Work"):
             tensor_id not in AsyncCommBucket._async_op
         ), f"tensor_id: {tensor_id}, keys: {AsyncCommBucket._async_op.keys()}"
         AsyncCommBucket._async_op[tensor_id] = work
+        AsyncCommBucket._copy_async_op[tensor_id] = work
 
     @staticmethod
     def get(tensor_id: int):
@@ -58,6 +60,7 @@ def wait(tensor_id: int):
     @staticmethod
     def clear_all():
         AsyncCommBucket._async_op.clear()
+        AsyncCommBucket._copy_async_op.clear()
 
 
 def is_async_comm(x):
@@ -92,9 +95,19 @@ def backward(ctx, grad_output):
             assert 1 == 1
 
         if is_async_comm(ctx.wait_handle_idx):
+            from nanotron.constants import _AUTOGRAD_RUNS
+
+            _AUTOGRAD_RUNS.append(f"wait_{ctx.wait_handle_idx}")
             handle = AsyncCommBucket.pop(ctx.wait_handle_idx)
             assert handle is not None
             handle.wait()
-            # assert handle.is_completed() is True
+            # assert handle.is_completed() is True, f"ctx.wait_handle_idx: {ctx.wait_handle_idx}"
+        else:
+
+            from nanotron import constants
+
+            # if dist.get_rank() == 0:
+            #     constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx)
+            constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx)
 
         return grad_output, None
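
_AUTOGRAD_RUNS, appended to in WaitComm.backward and DifferentiableIdentity.backward, is just a global trace of the order in which backward-side ops actually ran. A standalone sketch of the same trick (Tagged is an illustrative stand-in, not nanotron code):

import torch

_AUTOGRAD_RUNS = []  # mirrors the debug list added to nanotron.constants

class Tagged(torch.autograd.Function):
    # Identity op that records its tag when its backward node runs.
    @staticmethod
    def forward(ctx, x, tag):
        ctx.tag = tag
        return x

    @staticmethod
    def backward(ctx, grad):
        _AUTOGRAD_RUNS.append(ctx.tag)
        return grad, None

a = torch.ones(2, requires_grad=True)
h = Tagged.apply(a, "bwd.layer_attn_0_batch_0") * 2
y = Tagged.apply(h, "bwd.layer_mlp_0_batch_0")
y.sum().backward()
# Backward visits the tags in reverse forward order:
assert _AUTOGRAD_RUNS == ["bwd.layer_mlp_0_batch_0", "bwd.layer_attn_0_batch_0"]
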
diff --git a/src/nanotron/parallel/dependency.py b/src/nanotron/parallel/dependency.py
new file mode 100644
index 00000000..6a633d8a
--- /dev/null
+++ b/src/nanotron/parallel/dependency.py
@@ -0,0 +1,102 @@
+from typing import Dict, Tuple
+
+import torch
+from torch import Tensor
+
+_phonies: Dict[Tuple[torch.device, bool], Tensor] = {}
+
+
+def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor:
+    """Gets a phony. Phony is tensor without space. It is useful to make
+    arbitrary dependency in a autograd graph because it doesn't require any
+    gradient accumulation.
+
+    .. note::
+
+        Phonies for each device are cached. If an autograd function gets a phony
+        internally, the phony must be detached to be returned. Otherwise, the
+        autograd engine will mutate the cached phony in-place::
+
+            class Phonify(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, input):
+                    phony = get_phony(input.device, requires_grad=False)
+                    return phony.detach()  # detach() is necessary.
+
+    """
+    key = (device, requires_grad)
+
+    try:
+        phony = _phonies[key]
+    except KeyError:
+        with torch.cuda.stream(torch.cuda.default_stream(device)):
+            phony = torch.empty(0, device=device, requires_grad=requires_grad)
+
+        _phonies[key] = phony
+
+    return phony
+
+
+def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
+    """Branches out from an autograd lane of the given tensor."""
+    if torch.is_grad_enabled() and input.requires_grad:
+        input, phony = Fork.apply(input)
+    else:
+        phony = get_phony(input.device, requires_grad=False)
+
+    return input, phony
+
+
+class Fork(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
+        phony = get_phony(input.device, requires_grad=False)
+        return input, phony.detach()
+
+    @staticmethod
+    def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor:  # type: ignore
+        # import pydevd
+        # pydevd.settrace(suspend=False, trace_only_current_thread=True)
+        return grad_input
+
+
+def join(input: Tensor, phony: Tensor) -> Tensor:
+    """Merges two autograd lanes."""
+    if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad):
+        input = Join.apply(input, phony)
+
+    return input
+
+
+class Join(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor:  # type: ignore
+        return input
+
+    @staticmethod
+    def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]:  # type: ignore
+        # import pydevd
+        # pydevd.settrace(suspend=False, trace_only_current_thread=True)
+        return grad_input, None
+
+
+# def depend(fork_from, join_to) -> None:
+#     # Ensure that batches[i-1] is executed after batches[i] in
+#     # # backpropagation by an explicit dependency.
+#     # if i != 0:
+#     #     depend(batches[i-1], batches[i])
+#     # depend(run_after, run_before)
+#     fork_from, phony = fork(fork_from)
+#     join_to = join(join_to, phony)
+#     return fork_from, join_to
+
+
+def depend(run_after, run_before) -> None:
+    # Ensure that batches[i-1] is executed after batches[i] in
+    # # backpropagation by an explicit dependency.
+    # if i != 0:
+    #     depend(batches[i-1], batches[i])
+    # depend(run_after, run_before)
+    run_after, phony = fork(run_after)
+    run_before = join(run_before, phony)
+    return run_after, run_before
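
dependency.py follows the fork/join pattern used by pipeline libraries such as torchgpipe: a zero-element "phony" tensor threads an artificial edge through the autograd graph so that one branch's backward cannot start before another branch's backward has reached the Join. The patch only leaves depend() commented out in llama.py, but a self-contained sketch of the mechanism (uncached phony, illustrative record/hook helpers) looks like this:

import torch

class Fork(torch.autograd.Function):
    # Returns the input plus a zero-element "phony" tensor whose only job is to carry
    # an extra edge through the autograd graph (mirrors dependency.py above, uncached).
    @staticmethod
    def forward(ctx, x):
        return x, torch.empty(0, device=x.device)

    @staticmethod
    def backward(ctx, grad_x, grad_phony):
        return grad_x

class Join(torch.autograd.Function):
    # Consumes the phony, so this branch's backward must run before Fork's backward.
    @staticmethod
    def forward(ctx, x, phony):
        return x

    @staticmethod
    def backward(ctx, grad_x):
        return grad_x, None

def depend(run_after, run_before):
    run_after, phony = Fork.apply(run_after)
    run_before = Join.apply(run_before, phony)
    return run_after, run_before

order = []
def record(name):
    def hook(grad):
        order.append(name)
        return grad
    return hook

a = torch.ones(2, requires_grad=True)
b = torch.ones(2, requires_grad=True)
ya, yb = a * 2, b * 3
ya.register_hook(record("a-branch"))
yb.register_hook(record("b-branch"))
ya, yb = depend(run_after=ya, run_before=yb)  # a-branch's backward now waits on b-branch's
(ya.sum() + yb.sum()).backward()
assert order == ["b-branch", "a-branch"], order
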
diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index c4f69c05..58275368 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -58,6 +58,10 @@ def backward(ctx, grad_output):
         if handle_idx is not None and "bwd." in handle_idx and async_all_reduce is True:
             assert 1 == 1
 
+        from nanotron.constants import _AUTOGRAD_RUNS
+
+        _AUTOGRAD_RUNS.append(handle_idx)
+
         return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx), None, None, None
 
 
diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py
index 847454fd..4fea1838 100644
--- a/src/nanotron/parallel/tensor_parallel/nn.py
+++ b/src/nanotron/parallel/tensor_parallel/nn.py
@@ -115,7 +115,7 @@ def __init__(
         device=None,
         dtype=None,
         async_communication: bool = False,
-        async_all_reduce: bool = False,
+        # async_all_reduce: bool = False,
         contiguous_chunks: Optional[Tuple[int, ...]] = None,
     ):
         self.pg = pg
@@ -138,7 +138,7 @@ def __init__(
         )
         self.mode = mode
         self.async_communication = async_communication
-        self.async_all_reduce = async_all_reduce
+        # self.async_all_reduce = async_all_reduce
         if self.mode is TensorParallelLinearMode.ALL_REDUCE and self.async_communication:
             raise ValueError("async_communication is not supported for ALL_REDUCE mode")
 
@@ -164,7 +164,7 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig):
                 )
             setattr(self, name, new_param)
 
-    def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, async_all_reduce, handle_idx=None) -> torch.Tensor:
         return row_linear(
             input=x,
             weight=self.weight,
@@ -172,7 +172,7 @@ def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor:
             group=self.pg,
             tp_mode=self.mode,
             async_communication=self.async_communication,
-            async_all_reduce=self.async_all_reduce,
+            async_all_reduce=async_all_reduce,
             handle_idx=handle_idx,
         )
 
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 96af6b12..e58af9f5 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -564,6 +564,9 @@ def training_step(
             self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer
         )
 
+        if dist.get_rank() == 0:
+            assert 1 == 1
+
         # Apply gradient
         self.optimizer.step()
         self.optimizer.zero_grad()
@@ -580,8 +583,20 @@ def training_step(
 
         from nanotron.parallel.comm import AsyncCommBucket
 
+        # import torch.distributed as dist
+
+        not_finished = []
+        for k, v in AsyncCommBucket._copy_async_op.items():
+            # assert v.is_completed(), f"AsyncCommBucket._copy_async_op: {AsyncCommBucket._copy_async_op}"
+            if v.is_completed() is not True:
+                not_finished.append((k, v))
+
+        # if dist.get_rank() == 0 and constants._NOT_BWD_ASYNC_OPS:
+        #     assert 1 == 1
+
+        assert len(not_finished) == 0, f"AsyncCommBucket._copy_async_op: {not_finished}"
         assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}"
-        # AsyncCommBucket.clear_all()
+        AsyncCommBucket.clear_all()
 
         return outputs, loss_avg
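
The training_step change above effectively turns AsyncCommBucket into an end-of-step audit: every async op recorded during the step must have been popped, and the retained copies must all be completed, before clear_all(). A standalone sketch of that audit with a stand-in work object (FakeWork and Bucket are illustrative, not nanotron classes):

class FakeWork:
    # Stand-in for dist.Work so the audit can be shown without a process group.
    def __init__(self, done=True):
        self._done = done
    def is_completed(self):
        return self._done

class Bucket:
    _async_op = {}       # popped by whoever waits on the op (WaitComm / row_linear)
    _copy_async_op = {}  # retained for the end-of-step audit
    @classmethod
    def add(cls, name, work):
        cls._async_op[name] = work
        cls._copy_async_op[name] = work
    @classmethod
    def pop(cls, name):
        return cls._async_op.pop(name)
    @classmethod
    def clear_all(cls):
        cls._async_op.clear()
        cls._copy_async_op.clear()

def end_of_step_audit():
    not_finished = [(k, v) for k, v in Bucket._copy_async_op.items() if not v.is_completed()]
    assert len(not_finished) == 0, f"unfinished async ops: {not_finished}"
    assert len(Bucket._async_op) == 0, f"async ops never waited on: {Bucket._async_op}"
    Bucket.clear_all()

Bucket.add("bwd.layer_attn_2_batch_1", FakeWork(done=True))
Bucket.pop("bwd.layer_attn_2_batch_1")  # WaitComm's backward would do this
end_of_step_audit()                     # passes: everything popped and completed
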
 

From eac4ac59360b57015c5c5cf43b7851024c5bc7f5 Mon Sep 17 00:00:00 2001
From: Phuc Nguyen <b3f0cus@icloud.com>
Date: Tue, 11 Feb 2025 14:47:21 +0000
Subject: [PATCH 09/17] add CUDA stream synchronization for the bwd pass

---
 src/nanotron/models/llama.py  | 26 +++++++-------------------
 src/nanotron/parallel/comm.py |  6 ++++--
 src/nanotron/trainer.py       |  2 +-
 3 files changed, 12 insertions(+), 22 deletions(-)

diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 72ebf478..bb10021d 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -842,6 +842,7 @@ def _core_forward(
 
         num_input_batches = self.parallel_config.domino.num_input_batches
         orig_sequence_mask = sequence_mask
+        comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()]
 
         assert num_input_batches == 2
         hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1)
@@ -855,14 +856,8 @@ def _core_forward(
 
         hidden_states0 = self.input_layernorm(hidden_states0)
         hidden_states1 = self.input_layernorm(hidden_states1)
-        hidden_states0 = WaitComm.apply(
-            hidden_states0,
-            BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1),
-        )
-        hidden_states1 = WaitComm.apply(
-            hidden_states1,
-            BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
-        )
+        hidden_states0 = WaitComm.apply(hidden_states0, BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), comm_stream)
+        hidden_states1 = WaitComm.apply(hidden_states1, BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), comm_stream)
 
         attn_output0 = self.attn(
             hidden_states=hidden_states0,
@@ -875,7 +870,6 @@ def _core_forward(
             handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1),
         )
 
-        comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()]
         with torch.cuda.stream(comm_stream):
             attn_output0["work"].wait()
             attn_output0["work"].is_completed()
@@ -884,8 +878,7 @@ def _core_forward(
         residual0 = hidden_states0
         hidden_states0 = self.post_attention_layernorm(hidden_states0)
         hidden_states0 = WaitComm.apply(
-            hidden_states0,
-            BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1),
+            hidden_states0, BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), comm_stream
         )  # new
 
         mlp_output0 = self.mlp(
@@ -928,12 +921,6 @@ def _core_forward(
         # hidden_states0, hidden_states1 = depend(run_after=hidden_states0, run_before=hidden_states1)
 
         hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1)
-        assert 1 == 1
-
-        # assert attn_output0["work"].is_completed()
-        # assert attn_output1["work"].is_completed()
-        # assert mlp_output0["work"].is_completed()
-        # assert mlp_output1["work"].is_completed()
 
         return hidden_states, orig_sequence_mask
 
@@ -1104,8 +1091,9 @@ def forward_with_hidden_states(
             "sequence_mask": input_mask,
         }
 
-        for encoder_block in self.decoder:
-            hidden_encoder_states = encoder_block(**hidden_encoder_states)
+        for layer_idx, encoder_block in enumerate(self.decoder):
+            with torch.profiler.record_function(f"layer_{layer_idx}"):
+                hidden_encoder_states = encoder_block(**hidden_encoder_states)
 
         hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]
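
Wrapping each decoder block in torch.profiler.record_function, as done above, gives every layer its own named range so the trace shows which layer's compute overlaps with which communication. A minimal CPU-only sketch of how those labels surface (the model here is an arbitrary stack of Linear layers, purely illustrative):

import torch
from torch.profiler import ProfilerActivity, profile, record_function

layers = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(3)])
x = torch.randn(4, 16)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    h = x
    for layer_idx, layer in enumerate(layers):
        with record_function(f"layer_{layer_idx}"):  # same labeling as forward_with_hidden_states
            h = layer(h)

# Each layer now shows up as its own named range in the trace / key_averages table.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
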
 
diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py
index 789416c3..38a3c134 100644
--- a/src/nanotron/parallel/comm.py
+++ b/src/nanotron/parallel/comm.py
@@ -82,8 +82,9 @@ def is_async_comm(x):
 
 class WaitComm(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, input, wait_handle_idx):
+    def forward(ctx, input, wait_handle_idx, comm_stream):
         ctx.wait_handle_idx = wait_handle_idx
+        ctx.comm_stream = comm_stream
         return input
 
     @staticmethod
@@ -101,6 +102,7 @@ def backward(ctx, grad_output):
             handle = AsyncCommBucket.pop(ctx.wait_handle_idx)
             assert handle is not None
             handle.wait()
+            torch.cuda.default_stream().wait_stream(ctx.comm_stream)
             # assert handle.is_completed() is True, f"ctx.wait_handle_idx: {ctx.wait_handle_idx}"
         else:
 
@@ -110,4 +112,4 @@ def backward(ctx, grad_output):
             #     constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx)
             constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx)
 
-        return grad_output, None
+        return grad_output, None, None
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index e58af9f5..dfec12c0 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -594,7 +594,7 @@ def training_step(
         # if dist.get_rank() == 0 and constants._NOT_BWD_ASYNC_OPS:
         #     assert 1 == 1
 
-        assert len(not_finished) == 0, f"AsyncCommBucket._copy_async_op: {not_finished}"
+        assert len(not_finished) == 0, f"len={len(not_finished)}, AsyncCommBucket._copy_async_op: {not_finished}"
         assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}"
         AsyncCommBucket.clear_all()
 

From 3a438ff238dd8da16e25a3bf83bd1856ee0d727b Mon Sep 17 00:00:00 2001
From: Phuc Nguyen <b3f0cus@icloud.com>
Date: Tue, 11 Feb 2025 16:47:23 +0000
Subject: [PATCH 10/17] domino but
 non_async_last_batch_mlp_and_non_async_first_batch_attn

---
 src/nanotron/models/llama.py                  |  30 ++---
 src/nanotron/parallel/comm.py                 |  52 --------
 .../distributed_differentiable_primitives.py  | 111 ++++++++++--------
 .../parallel/tensor_parallel/functional.py    |  24 ++--
 src/nanotron/parallel/tensor_parallel/nn.py   |  18 ++-
 5 files changed, 101 insertions(+), 134 deletions(-)

diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index bb10021d..a6554bd8 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -30,10 +30,12 @@
 from nanotron.nn.activations import ACT2FN
 from nanotron.nn.layer_norm import TritonRMSNorm
 from nanotron.parallel import ParallelContext
-from nanotron.parallel.comm import WaitComm
+
+# from nanotron.parallel.comm import WaitComm
 from nanotron.parallel.parameters import NanotronParameter
 from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
 from nanotron.parallel.pipeline_parallel.p2p import P2P
+from nanotron.parallel.tensor_parallel.domino import WaitComm
 from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
 from nanotron.parallel.tensor_parallel.nn import (
     TensorParallelColumnLinear,
@@ -249,11 +251,9 @@ def __init__(
         )
         self.split_silu_mul = GLUActivation(config.hidden_act)
 
-    def forward(self, hidden_states, handle_idx=None):  # [seq_length, batch_size, hidden_dim]
-        merged_states = self.gate_up_proj(hidden_states, async_all_reduce=True, handle_idx=handle_idx)
-        hidden_states, work = self.down_proj(
-            self.split_silu_mul(merged_states), async_all_reduce=True, handle_idx=handle_idx
-        )
+    def forward(self, hidden_states, op_name):  # [seq_length, batch_size, hidden_dim]
+        merged_states = self.gate_up_proj(hidden_states, op_name=op_name)
+        hidden_states, work = self.down_proj(self.split_silu_mul(merged_states), op_name=op_name)
         return {"hidden_states": hidden_states, "work": work}
 
 
@@ -447,7 +447,7 @@ def forward(
         self,
         hidden_states,  # [seq_length, batch_size, hidden_size]
         sequence_mask,  # [batch_size, seq_length]
-        handle_idx=None,
+        op_name,
     ):
         from flash_attn import bert_padding
         from flash_attn.flash_attn_interface import (
@@ -456,7 +456,7 @@ def forward(
         )
 
         qkv_states = self.qkv_proj(
-            hidden_states, async_all_reduce=True, handle_idx=handle_idx
+            hidden_states, op_name=op_name
         )  # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk]
         q_length, batch_size, _ = qkv_states.shape
 
@@ -701,7 +701,7 @@ def forward(
         attention_output = (
             attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1)
         )
-        output, work = self.o_proj(attention_output, async_all_reduce=True, handle_idx=handle_idx)
+        output, work = self.o_proj(attention_output, op_name=op_name)
 
         return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask}
 
@@ -862,12 +862,12 @@ def _core_forward(
         attn_output0 = self.attn(
             hidden_states=hidden_states0,
             sequence_mask=sequence_mask0,
-            handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0),
+            op_name=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0),
         )
         attn_output1 = self.attn(
             hidden_states=hidden_states1,
             sequence_mask=sequence_mask1,
-            handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1),
+            op_name=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1),
         )
 
         with torch.cuda.stream(comm_stream):
@@ -883,7 +883,7 @@ def _core_forward(
 
         mlp_output0 = self.mlp(
             hidden_states=hidden_states0,
-            handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
+            op_name=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
         )
 
         # attn_output1["hidden_states"], mlp_output0["hidden_states"] = depend(
@@ -903,15 +903,15 @@ def _core_forward(
 
         mlp_output1 = self.mlp(
             hidden_states=hidden_states1,
-            handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1),
+            op_name=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1),
         )
 
         with torch.cuda.stream(comm_stream):
             mlp_output0["work"].wait()
-            mlp_output1["work"].wait()
+            # mlp_output1["work"].wait()
 
             mlp_output0["work"].is_completed()
-            mlp_output1["work"].is_completed()
+            # mlp_output1["work"].is_completed()
 
         torch.cuda.current_stream().wait_stream(comm_stream)
 
diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py
index 38a3c134..6dbb041f 100644
--- a/src/nanotron/parallel/comm.py
+++ b/src/nanotron/parallel/comm.py
@@ -61,55 +61,3 @@ def wait(tensor_id: int):
     def clear_all():
         AsyncCommBucket._async_op.clear()
         AsyncCommBucket._copy_async_op.clear()
-
-
-def is_async_comm(x):
-    import re
-
-    NON_ASYNC_HANDLE_IDX = [
-        # "fwd.layer_attn_{}_batch_0",
-        # "fwd.layer_mlp_{}_batch_0",
-        # "fwd.layer_mlp_{}_batch_1",
-        "bwd.layer_mlp_{}_batch_1",
-        "bwd.layer_attn_{}_batch_0",
-    ]
-
-    patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX]  # Replace {} with regex for numbers
-    regex = re.compile("^(" + "|".join(patterns) + ")$")  # Combine patterns into a single regex
-    not_async = bool(regex.match(x))
-    return not not_async
-
-
-class WaitComm(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input, wait_handle_idx, comm_stream):
-        ctx.wait_handle_idx = wait_handle_idx
-        ctx.comm_stream = comm_stream
-        return input
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        # import pydevd
-        # pydevd.settrace(suspend=False, trace_only_current_thread=True)
-
-        if "bwd.layer_mlp_1_batch_0" == ctx.wait_handle_idx:
-            assert 1 == 1
-
-        if is_async_comm(ctx.wait_handle_idx):
-            from nanotron.constants import _AUTOGRAD_RUNS
-
-            _AUTOGRAD_RUNS.append(f"wait_{ctx.wait_handle_idx}")
-            handle = AsyncCommBucket.pop(ctx.wait_handle_idx)
-            assert handle is not None
-            handle.wait()
-            torch.cuda.default_stream().wait_stream(ctx.comm_stream)
-            # assert handle.is_completed() is True, f"ctx.wait_handle_idx: {ctx.wait_handle_idx}"
-        else:
-
-            from nanotron import constants
-
-            # if dist.get_rank() == 0:
-            #     constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx)
-            constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx)
-
-        return grad_output, None, None
diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index 58275368..5ac3bedf 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -20,16 +20,16 @@
 from nanotron import distributed as dist
 from nanotron.distributed import ProcessGroup
 from nanotron.parallel.comm import AsyncCommBucket
+from nanotron.parallel.tensor_parallel.domino import is_async_comm
 
 
 class DifferentiableIdentity(torch.autograd.Function):
     """All-reduce gradients in a differentiable fashion"""
 
     @staticmethod
-    def forward(ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx=None):
-        # assert handle_idx is not None
+    def forward(ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, op_name: str = None):
         ctx.async_all_reduce = async_all_reduce
-        ctx.handle_idx = handle_idx
+        ctx.op_name = op_name
         ctx.group = group
         return tensor
 
@@ -41,28 +41,35 @@ def backward(ctx, grad_output):
         # assert ctx.handle_idx is not None
         group = ctx.group
 
-        if ctx.handle_idx is not None and "fwd." in ctx.handle_idx:
-            handle_idx = ctx.handle_idx.replace("fwd.", "bwd.")
-            # if "bwd.layer_mlp_1_batch_1" == handle_idx:
-            #     from nanotron.parallel.comm import is_async_comm
-            #     async_all_reduce = is_async_comm(handle_idx)
-            # else:
-            #     async_all_reduce = ctx.async_all_reduce
-            from nanotron.parallel.comm import is_async_comm
+        # if ctx.handle_idx is not None and "fwd." in ctx.handle_idx:
+        #     handle_idx = ctx.handle_idx.replace("fwd.", "bwd.")
+        #     # if "bwd.layer_mlp_1_batch_1" == handle_idx:
+        #     #     from nanotron.parallel.comm import is_async_comm
+        #     #     async_all_reduce = is_async_comm(handle_idx)
+        #     # else:
+        #     #     async_all_reduce = ctx.async_all_reduce
+        #     # from nanotron.parallel.comm import is_async_comm
+        #     from nanotron.parallel.tensor_parallel.domino import is_async_comm
 
-            async_all_reduce = is_async_comm(handle_idx)
-        else:
-            handle_idx = ctx.handle_idx
-            async_all_reduce = ctx.async_all_reduce
+        #     async_all_reduce = is_async_comm(handle_idx)
+        # else:
+        #     handle_idx = ctx.handle_idx
+        #     async_all_reduce = ctx.async_all_reduce
+
+        # if handle_idx is not None and "bwd." in handle_idx and async_all_reduce is True:
+        #     assert 1 == 1
+
+        op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else ctx.op_name
+        async_all_reduce = is_async_comm(op_name) if ctx.op_name is not None else ctx.async_all_reduce
 
-        if handle_idx is not None and "bwd." in handle_idx and async_all_reduce is True:
+        if op_name is not None and "layer_mlp_27_batch_1" in op_name:
             assert 1 == 1
 
         from nanotron.constants import _AUTOGRAD_RUNS
 
-        _AUTOGRAD_RUNS.append(handle_idx)
+        _AUTOGRAD_RUNS.append(ctx.op_name)
 
-        return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx), None, None, None
+        return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, op_name), None, None, None
 
 
 def is_last_batch_of_attn(x):
@@ -79,39 +86,45 @@ class DifferentiableAllReduceSum(torch.autograd.Function):
 
     @staticmethod
     def forward(
-        ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx: Optional[int] = None
+        ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, op_name: str = None
     ) -> Tuple[torch.Tensor, Optional["dist.Work"]]:
         ctx.async_all_reduce = async_all_reduce
 
         if group.size() == 1:
             return tensor
 
-        if handle_idx == "bwd.layer_mlp_1_batch_0":
-            assert 1 == 1
-
-        id(tensor)
-        if async_all_reduce is True:
-            # if isinstance(handle_idx, str):
-            #     do_async = is_last_batch_of_attn(handle_idx) is False
-            # else:
-            #     do_async = async_all_reduce
-            from nanotron.parallel.comm import is_async_comm
-
-            do_async = is_async_comm(handle_idx)
-
-            handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async)
-            if do_async:
-                if "bwd" in handle_idx:
-                    assert 1 == 1
-
-                # # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx
-                # if handle_idx is not None and "bwd." in handle_idx:
-                #     AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle)
-                # else:
-                #     AsyncCommBucket.add(orig_id, handle)
-                # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx
-                assert handle_idx is not None
-                AsyncCommBucket.add(handle_idx, handle)
+        # if handle_idx == "bwd.layer_mlp_1_batch_0":
+        #     assert 1 == 1
+
+        # id(tensor)
+        # if async_all_reduce is True:
+        #     # if isinstance(handle_idx, str):
+        #     #     do_async = is_last_batch_of_attn(handle_idx) is False
+        #     # else:
+        #     #     do_async = async_all_reduce
+        #     # from nanotron.parallel.comm import is_async_comm
+        #     from nanotron.parallel.tensor_parallel.domino import is_async_comm
+
+        #     do_async = is_async_comm(handle_idx)
+
+        #     handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async)
+        #     if do_async:
+        #         if "bwd" in handle_idx:
+        #             assert 1 == 1
+
+        #         # # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx
+        #         # if handle_idx is not None and "bwd." in handle_idx:
+        #         #     AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle)
+        #         # else:
+        #         #     AsyncCommBucket.add(orig_id, handle)
+        #         # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx
+        #         assert handle_idx is not None
+        #         AsyncCommBucket.add(handle_idx, handle)
+        # else:
+        #     dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
+        if async_all_reduce:
+            handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True)
+            AsyncCommBucket.add(op_name, handle)
         else:
             dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
 
@@ -202,15 +215,15 @@ def backward(ctx, grad_output):
 
 
 def differentiable_identity(
-    tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, handle_idx=None
+    tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, op_name: str = None
 ):
-    return DifferentiableIdentity.apply(tensor, group, async_all_reduce, handle_idx)
+    return DifferentiableIdentity.apply(tensor, group, async_all_reduce, op_name)
 
 
 def differentiable_all_reduce_sum(
-    tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, handle_idx=None
+    tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, op_name: str = None
 ):
-    return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce, handle_idx)
+    return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce, op_name)
 
 
 def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None):
diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index f0ca3a0d..6d69408c 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -437,13 +437,13 @@ def column_linear(
     async_communication: bool,
     tp_recompute_allgather: bool = True,
     async_all_reduce: bool = False,
-    handle_idx: Optional[int] = None,
+    op_name: Optional[str] = None,
 ):
     if async_communication:
         return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather)
 
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
-        input = differentiable_identity(input, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx)
+        input = differentiable_identity(input, group=group, async_all_reduce=async_all_reduce, op_name=op_name)
         return F.linear(input, weight, bias)
     if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
         return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(
@@ -592,7 +592,7 @@ def row_linear(
     # TODO(xrsrke): use less confusing names for these arguments
     async_communication: bool,
     async_all_reduce: bool,
-    handle_idx=None,
+    op_name: Optional[str] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Future]]:
     if async_communication:
         return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)
@@ -601,23 +601,31 @@ def row_linear(
 
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
         # out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce)
-        id(out)
+        # id(out)
         # NOTE: why the id(out) doesn't match the id(out) before the all_reduce?
-        if handle_idx == "fwd.layer_attn_0_batch_0":
+        if op_name == "fwd.layer_attn_0_batch_0":
             assert 1 == 1
 
-        out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx)
+        if op_name == "fwd.layer_mlp_0_batch_1":
+            assert 1 == 1
+
+        if op_name == "fwd.layer_attn_0_batch_0":
+            assert 1 == 1
+
+        out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, op_name=op_name)
         if async_all_reduce:
             from nanotron.parallel.comm import AsyncCommBucket
 
             # work = AsyncCommBucket.get(orig_out_id)
             # work = AsyncCommBucket.pop(orig_out_id)
             # if handle_idx == "fwd.layer_mlp_1_batch_0":
-            if handle_idx == "fwd.layer_attn_0_batch_0":
+            if op_name == "fwd.layer_attn_0_batch_0":
                 assert 1 == 1
 
-            work = AsyncCommBucket.pop(handle_idx)
+            work = AsyncCommBucket.pop(op_name)
             assert 1 == 1
+        else:
+            work = None
     elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
         assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode."
         out = differentiable_reduce_scatter_sum(out, group=group)
diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py
index 4fea1838..2e6fd5a4 100644
--- a/src/nanotron/parallel/tensor_parallel/nn.py
+++ b/src/nanotron/parallel/tensor_parallel/nn.py
@@ -31,6 +31,7 @@
     differentiable_identity,
     differentiable_reduce_scatter_sum,
 )
+from nanotron.parallel.tensor_parallel.domino import is_async_comm
 from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
 from nanotron.parallel.tensor_parallel.functional import (
     column_linear,
@@ -52,7 +53,6 @@ def __init__(
         async_communication: bool = False,
         contiguous_chunks: Optional[Tuple[int, ...]] = None,
         tp_recompute_allgather: bool = True,
-        # handle_idx: Optional[int] = None,
     ):
         self.pg = pg
         self.world_size = pg.size()
@@ -73,7 +73,6 @@ def __init__(
 
         self.mode = mode
         self.async_communication = async_communication
-        # self.handle_idx = handle_idx
 
         if contiguous_chunks is not None:
             assert (
@@ -87,7 +86,7 @@ def __init__(
             split_config=split_config,
         )
 
-    def forward(self, x: torch.Tensor, async_all_reduce=None, handle_idx=None) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, op_name: str = None) -> torch.Tensor:
         return column_linear(
             input=x,
             weight=self.weight,
@@ -96,8 +95,8 @@ def forward(self, x: torch.Tensor, async_all_reduce=None, handle_idx=None) -> to
             tp_mode=self.mode,
             async_communication=self.async_communication,
             tp_recompute_allgather=self.tp_recompute_allgather,
-            async_all_reduce=async_all_reduce,
-            handle_idx=handle_idx,
+            async_all_reduce=False if op_name is None else is_async_comm(op_name),
+            op_name=op_name,
         )
 
     def extra_repr(self) -> str:
@@ -115,7 +114,6 @@ def __init__(
         device=None,
         dtype=None,
         async_communication: bool = False,
-        # async_all_reduce: bool = False,
         contiguous_chunks: Optional[Tuple[int, ...]] = None,
     ):
         self.pg = pg
@@ -138,7 +136,7 @@ def __init__(
         )
         self.mode = mode
         self.async_communication = async_communication
-        # self.async_all_reduce = async_all_reduce
+
         if self.mode is TensorParallelLinearMode.ALL_REDUCE and self.async_communication:
             raise ValueError("async_communication is not supported for ALL_REDUCE mode")
 
@@ -164,7 +162,7 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig):
                 )
             setattr(self, name, new_param)
 
-    def forward(self, x: torch.Tensor, async_all_reduce, handle_idx=None) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, op_name: str = None) -> torch.Tensor:
         return row_linear(
             input=x,
             weight=self.weight,
@@ -172,8 +170,8 @@ def forward(self, x: torch.Tensor, async_all_reduce, handle_idx=None) -> torch.T
             group=self.pg,
             tp_mode=self.mode,
             async_communication=self.async_communication,
-            async_all_reduce=async_all_reduce,
-            handle_idx=handle_idx,
+            async_all_reduce=False if op_name is None else is_async_comm(op_name),
+            op_name=op_name,
         )
 
     def extra_repr(self) -> str:

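For reference, a minimal usage sketch of the op_name-driven forward introduced above (illustrative only: `row_proj` and the input `x` stand for an already-constructed TensorParallelRowLinear and its sharded input, with the TP process group initialized):

    # With this patch, row_linear returns (output, work). An op_name that
    # is_async_comm() accepts makes the TP all-reduce asynchronous and leaves a
    # dist.Work handle to wait on; op_name=None keeps the blocking all-reduce.
    out, work = row_proj(x, op_name="fwd.layer_attn_0_batch_0")  # async: work is a dist.Work handle
    out, work = row_proj(x)                                      # sync: work is None
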
From da948dfab4690c13435e1c0ab7ca0ae957bb0451 Mon Sep 17 00:00:00 2001
From: Phuc Nguyen <b3f0cus@icloud.com>
Date: Tue, 11 Feb 2025 16:48:10 +0000
Subject: [PATCH 11/17] non_async_last_batch_mlp_and_non_async_first_batch_attn

---
 .../parallel/tensor_parallel/domino.py        | 60 +++++++++++++++++++
 tests/test_domino.py                          | 19 ++++++
 2 files changed, 79 insertions(+)
 create mode 100644 src/nanotron/parallel/tensor_parallel/domino.py
 create mode 100644 tests/test_domino.py

diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py
new file mode 100644
index 00000000..7fb68abc
--- /dev/null
+++ b/src/nanotron/parallel/tensor_parallel/domino.py
@@ -0,0 +1,60 @@
+import re
+
+import torch
+
+from nanotron.parallel.comm import AsyncCommBucket
+
+
+def is_async_comm(op_name: str):
+    """
+    There are two operations that we can't overlap with compute:
+    for the forward pass: the last micro-batch of the MLP layer
+    for the backward pass: the first micro-batch of the attention layer
+    """
+    NON_ASYNC_HANDLE_IDX = [
+        # "fwd.layer_attn_{}_batch_0",
+        # "fwd.layer_mlp_{}_batch_0",
+        "fwd.layer_mlp_{}_batch_1",
+        # "bwd.layer_mlp_{}_batch_1",
+        "bwd.layer_attn_{}_batch_0",
+    ]
+
+    patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX]  # Replace {} with regex for numbers
+    regex = re.compile("^(" + "|".join(patterns) + ")$")  # Combine patterns into a single regex
+    not_async = bool(regex.match(op_name))
+    return not not_async
+
+
+class WaitComm(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, wait_handle_idx, comm_stream):
+        ctx.wait_handle_idx = wait_handle_idx
+        ctx.comm_stream = comm_stream
+        return input
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # import pydevd
+        # pydevd.settrace(suspend=False, trace_only_current_thread=True)
+
+        if "bwd.layer_mlp_1_batch_0" == ctx.wait_handle_idx:
+            assert 1 == 1
+
+        if is_async_comm(ctx.wait_handle_idx):
+            from nanotron.constants import _AUTOGRAD_RUNS
+
+            _AUTOGRAD_RUNS.append(f"wait_{ctx.wait_handle_idx}")
+            handle = AsyncCommBucket.pop(ctx.wait_handle_idx)
+            assert handle is not None
+            handle.wait()
+            torch.cuda.default_stream().wait_stream(ctx.comm_stream)
+            # assert handle.is_completed() is True, f"ctx.wait_handle_idx: {ctx.wait_handle_idx}"
+        else:
+
+            from nanotron import constants
+
+            # if dist.get_rank() == 0:
+            #     constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx)
+            constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx)
+
+        return grad_output, None, None
diff --git a/tests/test_domino.py b/tests/test_domino.py
new file mode 100644
index 00000000..8f474ff8
--- /dev/null
+++ b/tests/test_domino.py
@@ -0,0 +1,19 @@
+import pytest
+from nanotron.parallel.tensor_parallel.domino import is_async_comm
+
+
+@pytest.mark.parametrize(
+    "op_name, expected",
+    [
+        ("fwd.layer_attn_1_batch_0", True),
+        ("fwd.layer_attn_1_batch_1", True),
+        ("fwd.layer_mlp_1_batch_0", True),
+        ("fwd.layer_mlp_1_batch_1", False),
+        ("bwd.layer_mlp_1_batch_1", True),
+        ("bwd.layer_mlp_1_batch_0", True),
+        ("bwd.layer_attn_1_batch_1", True),
+        ("bwd.layer_attn_1_batch_0", False),
+    ],
+)
+def test_is_async_comm(op_name, expected):
+    assert is_async_comm(op_name) == expected

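To make the rule encoded by is_async_comm concrete, here is a small sketch of how the formatted op names map onto it (assumes the module from this patch is importable; the two template strings are local to the sketch and mirror the op names used in llama.py):

    from nanotron.parallel.tensor_parallel.domino import is_async_comm

    FWD_MLP = "fwd.layer_mlp_{}_batch_{}"     # illustrative local copies of the op-name templates
    BWD_ATTN = "bwd.layer_attn_{}_batch_{}"

    assert is_async_comm(FWD_MLP.format(3, 0))       # first MLP micro-batch: overlapped
    assert not is_async_comm(FWD_MLP.format(3, 1))   # last forward MLP micro-batch: stays synchronous
    assert not is_async_comm(BWD_ATTN.format(3, 0))  # first backward attention micro-batch: stays synchronous
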
From aa3e97393058c0b9bcc7bf65dccccc2c8a3f5c91 Mon Sep 17 00:00:00 2001
From: Phuc Nguyen <b3f0cus@icloud.com>
Date: Wed, 12 Feb 2025 13:21:51 +0000
Subject: [PATCH 12/17] backup before refactoring

---
 examples/config_llama_domino.yaml             |  98 ++++++++++++++++
 src/nanotron/models/llama.py                  | 107 ++----------------
 .../distributed_differentiable_primitives.py  | 105 ++++++-----------
 .../parallel/tensor_parallel/domino.py        |  15 ++-
 src/nanotron/trainer.py                       |   6 -
 5 files changed, 149 insertions(+), 182 deletions(-)
 create mode 100644 examples/config_llama_domino.yaml

diff --git a/examples/config_llama_domino.yaml b/examples/config_llama_domino.yaml
new file mode 100644
index 00000000..b9811fdd
--- /dev/null
+++ b/examples/config_llama_domino.yaml
@@ -0,0 +1,98 @@
+checkpoints:
+  checkpoint_interval: 1000
+  checkpoints_path: checkpoints
+  checkpoints_path_is_shared_file_system: false
+  resume_checkpoint_path: null
+  save_initial_state: false
+data_stages:
+- data:
+    dataset:
+      dataset_overwrite_cache: false
+      dataset_processing_num_proc_per_process: 1
+      hf_dataset_config_name: null
+      hf_dataset_or_datasets: roneneldan/TinyStories
+      hf_dataset_splits: train
+      text_column_name: text
+    num_loading_workers: 1
+    seed: 42
+  name: Stable Training Stage
+  start_training_step: 1
+general:
+  benchmark_csv_path: null
+  consumed_train_samples: null
+  ignore_sanity_checks: true
+  project: nanotron_domino
+  run: config_llama_domino
+  seed: 42
+  step: null
+lighteval: null
+logging:
+  iteration_step_info_interval: 1
+  log_level: info
+  log_level_replica: info
+model:
+  ddp_bucket_cap_mb: 25
+  dtype: bfloat16
+  init_method:
+    std: 0.025
+  make_vocab_size_divisible_by: 1
+  model_config:
+    bos_token_id: 128000
+    eos_token_id: 128001
+    hidden_act: silu
+    hidden_size: 4096
+    initializer_range: 0.02
+    intermediate_size: 16384
+    is_llama_config: true
+    max_position_embeddings: 4096
+    num_attention_heads: 32
+    num_hidden_layers: 32
+    num_key_value_heads: 8
+    pad_token_id: null
+    pretraining_tp: 1
+    rms_norm_eps: 1.0e-05
+    rope_scaling: null
+    tie_word_embeddings: true
+    use_cache: true
+    vocab_size: 128256
+optimizer:
+  accumulate_grad_in_fp32: true
+  clip_grad: 1.0
+  learning_rate_scheduler:
+    learning_rate: 0.0003
+    lr_decay_starting_step: null
+    lr_decay_steps: 1000
+    lr_decay_style: cosine
+    lr_warmup_steps: 500
+    lr_warmup_style: linear
+    min_decay_lr: 1.0e-05
+  optimizer_factory:
+    adam_beta1: 0.9
+    adam_beta2: 0.95
+    adam_eps: 1.0e-08
+    name: adamW
+    torch_adam_is_fused: true
+  weight_decay: 0.01
+  zero_stage: 0
+parallelism:
+  dp: 1
+  pp: 1
+  tp: 8
+  expert_parallel_size: 1
+  pp_engine: 1f1b
+  tp_linear_async_communication: false
+  tp_mode: ALL_REDUCE
+  domino:
+    num_input_batches: 2
+tokenizer:
+  tokenizer_max_length: null
+  tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B
+  tokenizer_revision: null
+tokens:
+  batch_accumulation_per_replica: 1
+  limit_test_batches: 0
+  limit_val_batches: 0
+  micro_batch_size: 2
+  sequence_length: 4096
+  train_steps: 1500
+  val_check_interval: -1
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index a6554bd8..b6785604 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -35,7 +35,13 @@
 from nanotron.parallel.parameters import NanotronParameter
 from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
 from nanotron.parallel.pipeline_parallel.p2p import P2P
-from nanotron.parallel.tensor_parallel.domino import WaitComm
+from nanotron.parallel.tensor_parallel.domino import (
+    BWD_ATTN_HANDLE_IDX,
+    BWD_MLP_HANDLE_IDX,
+    FWD_ATTN_HANDLE_IDX,
+    FWD_MLP_HANDLE_IDX,
+    WaitComm,
+)
 from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
 from nanotron.parallel.tensor_parallel.nn import (
     TensorParallelColumnLinear,
@@ -51,11 +57,6 @@
 
 DOMINO_COMM_STREAM = "domino_comm_stream_{}"
 
-FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}"
-FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}"
-BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}"
-BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}"
-
 
 class RotaryEmbedding(nn.Module):
     def __init__(self, dim: int, end: int, theta: float = 10000.0):
@@ -743,96 +744,6 @@ def __init__(
 
         self.layer_idx = layer_idx
 
-    # def _core_forward(
-    #     self,
-    #     hidden_states: Union[torch.Tensor, TensorPointer],
-    #     sequence_mask: Union[torch.Tensor, TensorPointer],
-    # ) -> List[Union[torch.Tensor, TensorPointer]]:
-    #     from nanotron import constants
-
-    #     num_input_batches = self.parallel_config.domino.num_input_batches
-    #     orig_sequence_mask = sequence_mask
-
-    #     assert num_input_batches == 2
-    #     hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1)
-    #     sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0)
-
-    #     hidden_states0, hidden_states1 = hidden_states
-    #     sequence_mask0, sequence_mask1 = sequence_mask
-
-    #     residual0 = hidden_states0
-    #     residual1 = hidden_states1
-
-    #     hidden_states0 = self.input_layernorm(hidden_states0)
-    #     hidden_states1 = self.input_layernorm(hidden_states1)
-
-    #     attn_output0 = self.attn(
-    #         hidden_states=hidden_states0,
-    #         sequence_mask=sequence_mask0,
-    #         handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0),
-    #     )
-    #     # attn_output0["hidden_states"] = WaitComm.apply(
-    #     #     attn_output0["hidden_states"],
-    #     #     BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1),
-    #     # )
-
-    #     attn_output1 = self.attn(
-    #         hidden_states=hidden_states1,
-    #         sequence_mask=sequence_mask1,
-    #         handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1),
-    #     )
-    #     # attn_output1["hidden_states"] = WaitComm.apply(
-    #     #     attn_output1["hidden_states"],
-    #     #     BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
-    #     # )
-
-    #     comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()]
-    #     with torch.cuda.stream(comm_stream):
-    #         attn_output0["work"].wait()
-
-    #     hidden_states0 = attn_output0["hidden_states"] + residual0
-    #     residual0 = hidden_states0
-    #     hidden_states0 = self.post_attention_layernorm(hidden_states0)
-    #     hidden_states0 = WaitComm.apply(
-    #         hidden_states0,
-    #         BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1),
-    #     ) # new
-
-    #     mlp_output0 = self.mlp(
-    #         hidden_states=hidden_states0,
-    #         handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
-    #     )
-    #     # mlp_output0["hidden_states"] = WaitComm.apply(
-    #     #     mlp_output0["hidden_states"],
-    #     #     BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1),
-    #     # )
-
-    #     with torch.cuda.stream(comm_stream):
-    #         attn_output1["work"].wait()
-
-    #     hidden_states1 = attn_output1["hidden_states"] + residual1
-    #     residual1 = hidden_states1
-    #     hidden_states1 = self.post_attention_layernorm(hidden_states1)
-    #     hidden_states1 = WaitComm.apply(
-    #         hidden_states1,
-    #         BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
-    #     )
-
-    #     mlp_output1 = self.mlp(
-    #         hidden_states=hidden_states1,
-    #         handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1),
-    #     )
-
-    #     with torch.cuda.stream(comm_stream):
-    #         mlp_output0["work"].wait()
-    #         mlp_output1["work"].wait()
-
-    #     hidden_states0 = mlp_output0["hidden_states"] + residual0
-    #     hidden_states1 = mlp_output1["hidden_states"] + residual1
-
-    #     hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1)
-    #     return hidden_states, orig_sequence_mask
-
     def _core_forward(
         self,
         hidden_states: Union[torch.Tensor, TensorPointer],
@@ -908,12 +819,10 @@ def _core_forward(
 
         with torch.cuda.stream(comm_stream):
             mlp_output0["work"].wait()
-            # mlp_output1["work"].wait()
-
             mlp_output0["work"].is_completed()
-            # mlp_output1["work"].is_completed()
 
         torch.cuda.current_stream().wait_stream(comm_stream)
+        # torch.cuda.synchronize()
 
         hidden_states0 = mlp_output0["hidden_states"] + residual0
         hidden_states1 = mlp_output1["hidden_states"] + residual1
diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index 5ac3bedf..5446f571 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -27,58 +27,39 @@ class DifferentiableIdentity(torch.autograd.Function):
     """All-reduce gradients in a differentiable fashion"""
 
     @staticmethod
-    def forward(ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, op_name: str = None):
+    def forward(
+        ctx,
+        tensor,
+        group: Optional[ProcessGroup],
+        async_all_reduce: bool,
+        op_name: str = None,
+        comm_stream: torch.cuda.Stream = None,
+    ):
+        ctx.group = group
         ctx.async_all_reduce = async_all_reduce
         ctx.op_name = op_name
-        ctx.group = group
+        ctx.comm_stream = comm_stream
         return tensor
 
     @staticmethod
     def backward(ctx, grad_output):
         # import pydevd
         # pydevd.settrace(suspend=False, trace_only_current_thread=True)
-        # NOTE: lm_head is TensorParallelColumnLinear, and it doesn't do async
-        # assert ctx.handle_idx is not None
-        group = ctx.group
-
-        # if ctx.handle_idx is not None and "fwd." in ctx.handle_idx:
-        #     handle_idx = ctx.handle_idx.replace("fwd.", "bwd.")
-        #     # if "bwd.layer_mlp_1_batch_1" == handle_idx:
-        #     #     from nanotron.parallel.comm import is_async_comm
-        #     #     async_all_reduce = is_async_comm(handle_idx)
-        #     # else:
-        #     #     async_all_reduce = ctx.async_all_reduce
-        #     # from nanotron.parallel.comm import is_async_comm
-        #     from nanotron.parallel.tensor_parallel.domino import is_async_comm
-
-        #     async_all_reduce = is_async_comm(handle_idx)
-        # else:
-        #     handle_idx = ctx.handle_idx
-        #     async_all_reduce = ctx.async_all_reduce
-
-        # if handle_idx is not None and "bwd." in handle_idx and async_all_reduce is True:
-        #     assert 1 == 1
-
-        op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else ctx.op_name
-        async_all_reduce = is_async_comm(op_name) if ctx.op_name is not None else ctx.async_all_reduce
-
-        if op_name is not None and "layer_mlp_27_batch_1" in op_name:
-            assert 1 == 1
-
         from nanotron.constants import _AUTOGRAD_RUNS
 
         _AUTOGRAD_RUNS.append(ctx.op_name)
 
-        return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, op_name), None, None, None
-
+        group = ctx.group
 
-def is_last_batch_of_attn(x):
-    import re
+        op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else ctx.op_name
+        async_all_reduce = is_async_comm(op_name) if ctx.op_name is not None else ctx.async_all_reduce
 
-    pattern = r"layer_attn_\d+_batch_0"
-    if re.match(pattern, x):
-        return True
-    return False
+        return (
+            DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, op_name, ctx.comm_stream),
+            None,
+            None,
+            None,
+        )
 
 
 class DifferentiableAllReduceSum(torch.autograd.Function):
@@ -86,47 +67,25 @@ class DifferentiableAllReduceSum(torch.autograd.Function):
 
     @staticmethod
     def forward(
-        ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, op_name: str = None
+        ctx,
+        tensor,
+        group: Optional[ProcessGroup],
+        async_all_reduce: bool,
+        op_name: str = None,
+        comm_stream: torch.cuda.Stream = None,
     ) -> Tuple[torch.Tensor, Optional["dist.Work"]]:
         ctx.async_all_reduce = async_all_reduce
+        ctx.comm_stream = comm_stream
 
         if group.size() == 1:
             return tensor
 
-        # if handle_idx == "bwd.layer_mlp_1_batch_0":
-        #     assert 1 == 1
-
-        # id(tensor)
-        # if async_all_reduce is True:
-        #     # if isinstance(handle_idx, str):
-        #     #     do_async = is_last_batch_of_attn(handle_idx) is False
-        #     # else:
-        #     #     do_async = async_all_reduce
-        #     # from nanotron.parallel.comm import is_async_comm
-        #     from nanotron.parallel.tensor_parallel.domino import is_async_comm
-
-        #     do_async = is_async_comm(handle_idx)
-
-        #     handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async)
-        #     if do_async:
-        #         if "bwd" in handle_idx:
-        #             assert 1 == 1
-
-        #         # # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx
-        #         # if handle_idx is not None and "bwd." in handle_idx:
-        #         #     AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle)
-        #         # else:
-        #         #     AsyncCommBucket.add(orig_id, handle)
-        #         # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx
-        #         assert handle_idx is not None
-        #         AsyncCommBucket.add(handle_idx, handle)
-        # else:
-        #     dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
-        if async_all_reduce:
-            handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True)
-            AsyncCommBucket.add(op_name, handle)
-        else:
-            dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
+        with torch.cuda.stream(comm_stream):
+            if async_all_reduce:
+                handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True)
+                AsyncCommBucket.add(op_name, handle)
+            else:
+                dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
 
         return tensor
 
diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py
index 7fb68abc..d864bcbf 100644
--- a/src/nanotron/parallel/tensor_parallel/domino.py
+++ b/src/nanotron/parallel/tensor_parallel/domino.py
@@ -4,6 +4,11 @@
 
 from nanotron.parallel.comm import AsyncCommBucket
 
+FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}"
+FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}"
+BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}"
+BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}"
+
 
 def is_async_comm(op_name: str):
     """
@@ -40,6 +45,9 @@ def backward(ctx, grad_output):
         if "bwd.layer_mlp_1_batch_0" == ctx.wait_handle_idx:
             assert 1 == 1
 
+        if "bwd.layer_mlp_0_batch_1" == ctx.wait_handle_idx:
+            assert 1 == 1
+
         if is_async_comm(ctx.wait_handle_idx):
             from nanotron.constants import _AUTOGRAD_RUNS
 
@@ -48,13 +56,12 @@ def backward(ctx, grad_output):
             assert handle is not None
             handle.wait()
             torch.cuda.default_stream().wait_stream(ctx.comm_stream)
-            # assert handle.is_completed() is True, f"ctx.wait_handle_idx: {ctx.wait_handle_idx}"
         else:
-
             from nanotron import constants
 
-            # if dist.get_rank() == 0:
-            #     constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx)
             constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx)
 
+        # if "bwd.layer_mlp_0_batch_1" == ctx.wait_handle_idx:
+        #     assert AsyncCommBucket._copy_async_op.get(ctx.wait_handle_idx).is_completed() is True
+
         return grad_output, None, None
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index dfec12c0..fb05ebf5 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -564,9 +564,6 @@ def training_step(
             self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer
         )
 
-        if dist.get_rank() == 0:
-            assert 1 == 1
-
         # Apply gradient
         self.optimizer.step()
         self.optimizer.zero_grad()
@@ -583,11 +580,8 @@ def training_step(
 
         from nanotron.parallel.comm import AsyncCommBucket
 
-        # import torch.distributed as dist
-
         not_finished = []
         for k, v in AsyncCommBucket._copy_async_op.items():
-            # assert v.is_completed(), f"AsyncCommBucket._copy_async_op: {AsyncCommBucket._copy_async_op}"
             if v.is_completed() is not True:
                 not_finished.append((k, v))
 

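The producer/consumer pattern this patch converges on is: issue the all-reduce on a dedicated communication stream, park the handle in a bucket, and have the consumer wait on both the handle and the stream. A minimal sketch, assuming an initialized NCCL process group with one GPU per rank (the names `handles`, `launch`, and `wait_on` are illustrative, not nanotron APIs):

    import torch
    import torch.distributed as dist

    handles = {}                        # plays the role of AsyncCommBucket
    comm_stream = torch.cuda.Stream()   # one side stream per device

    def launch(tensor: torch.Tensor, op_name: str) -> None:
        # producer: enqueue the async all-reduce on the side stream so compute
        # on the default stream can overlap the communication
        with torch.cuda.stream(comm_stream):
            handles[op_name] = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=True)

    def wait_on(op_name: str) -> None:
        # consumer: block on the NCCL work, then make the default stream order
        # itself after the communication stream before touching the tensor again
        handles.pop(op_name).wait()
        torch.cuda.default_stream().wait_stream(comm_stream)
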
From ea09a25ef6c8493cc7be2779ce9a581ba27b7dca Mon Sep 17 00:00:00 2001
From: Phuc Nguyen <b3f0cus@icloud.com>
Date: Wed, 12 Feb 2025 13:39:08 +0000
Subject: [PATCH 13/17] refactor

---
 examples/config_llama_domino.yaml             |  2 +-
 src/nanotron/optim/gradient_accumulator.py    |  4 ----
 src/nanotron/parallel/comm.py                 | 10 ++++++++
 .../parallel/pipeline_parallel/engine.py      |  3 ---
 .../distributed_differentiable_primitives.py  |  8 -------
 .../parallel/tensor_parallel/domino.py        | 23 +------------------
 .../parallel/tensor_parallel/functional.py    | 22 +-----------------
 src/nanotron/sanity_checks.py                 |  4 ++++
 src/nanotron/trainer.py                       | 16 ++-----------
 9 files changed, 19 insertions(+), 73 deletions(-)

diff --git a/examples/config_llama_domino.yaml b/examples/config_llama_domino.yaml
index b9811fdd..30f59161 100644
--- a/examples/config_llama_domino.yaml
+++ b/examples/config_llama_domino.yaml
@@ -20,7 +20,7 @@ data_stages:
 general:
   benchmark_csv_path: null
   consumed_train_samples: null
-  ignore_sanity_checks: true
+  ignore_sanity_checks: false
   project: nanotron_domino
   run: config_llama_domino
   seed: 42
diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py
index b5ef7d89..2e940744 100644
--- a/src/nanotron/optim/gradient_accumulator.py
+++ b/src/nanotron/optim/gradient_accumulator.py
@@ -202,10 +202,6 @@ def build_grad_buffers(
         return fp32_grad_buffers, contiguous_buffer_f32_gradients
 
     def backward(self, loss: torch.Tensor):
-        if not isinstance(loss, torch.Tensor):
-            assert 1 == 1
-            raise NotImplementedError("Not implemented yet")
-
         result = loss.backward()
 
         for name, elt in self.fp32_grad_buffers.items():
diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py
index 6dbb041f..63736718 100644
--- a/src/nanotron/parallel/comm.py
+++ b/src/nanotron/parallel/comm.py
@@ -57,6 +57,16 @@ def wait(tensor_id: int):
         work = AsyncCommBucket._async_op.pop(tensor_id)
         work.wait()
 
+    @staticmethod
+    def is_all_completed() -> bool:
+        assert len(AsyncCommBucket._async_op) == 0, "there are still async ops that haven't executed"
+
+        not_finished = []
+        for k, v in AsyncCommBucket._copy_async_op.items():
+            if v.is_completed() is not True:
+                not_finished.append((k, v))
+        return len(not_finished) == 0
+
     @staticmethod
     def clear_all():
         AsyncCommBucket._async_op.clear()
diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py
index 8160f302..076943c7 100644
--- a/src/nanotron/parallel/pipeline_parallel/engine.py
+++ b/src/nanotron/parallel/pipeline_parallel/engine.py
@@ -84,9 +84,6 @@ def backward(
             if grad_accumulator is None:
                 sum(activations).backward()
             else:
-                # if not isinstance(activations, torch.Tensor):
-                #     raise NotImplementedError("Only support sum of tensors for now")
-
                 grad_accumulator.backward(sum(activations))
 
         # TODO @nouamane: this fixes interleaved afab but makes 1f1b hang
diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index 5446f571..1254fdb1 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -43,17 +43,9 @@ def forward(
 
     @staticmethod
     def backward(ctx, grad_output):
-        # import pydevd
-        # pydevd.settrace(suspend=False, trace_only_current_thread=True)
-        from nanotron.constants import _AUTOGRAD_RUNS
-
-        _AUTOGRAD_RUNS.append(ctx.op_name)
-
         group = ctx.group
-
         op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else ctx.op_name
         async_all_reduce = is_async_comm(op_name) if ctx.op_name is not None else ctx.async_all_reduce
-
         return (
             DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, op_name, ctx.comm_stream),
             None,
diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py
index d864bcbf..e35cac06 100644
--- a/src/nanotron/parallel/tensor_parallel/domino.py
+++ b/src/nanotron/parallel/tensor_parallel/domino.py
@@ -39,29 +39,8 @@ def forward(ctx, input, wait_handle_idx, comm_stream):
 
     @staticmethod
     def backward(ctx, grad_output):
-        # import pydevd
-        # pydevd.settrace(suspend=False, trace_only_current_thread=True)
-
-        if "bwd.layer_mlp_1_batch_0" == ctx.wait_handle_idx:
-            assert 1 == 1
-
-        if "bwd.layer_mlp_0_batch_1" == ctx.wait_handle_idx:
-            assert 1 == 1
-
         if is_async_comm(ctx.wait_handle_idx):
-            from nanotron.constants import _AUTOGRAD_RUNS
-
-            _AUTOGRAD_RUNS.append(f"wait_{ctx.wait_handle_idx}")
-            handle = AsyncCommBucket.pop(ctx.wait_handle_idx)
-            assert handle is not None
-            handle.wait()
+            AsyncCommBucket.wait(ctx.wait_handle_idx)
             torch.cuda.default_stream().wait_stream(ctx.comm_stream)
-        else:
-            from nanotron import constants
-
-            constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx)
-
-        # if "bwd.layer_mlp_0_batch_1" == ctx.wait_handle_idx:
-        #     assert AsyncCommBucket._copy_async_op.get(ctx.wait_handle_idx).is_completed() is True
 
         return grad_output, None, None
diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index 6d69408c..915a2c31 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -19,6 +19,7 @@
 from torch.nn import functional as F
 
 import nanotron.distributed as dist
+from nanotron.parallel.comm import AsyncCommBucket
 from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
     differentiable_all_reduce_sum,
     differentiable_identity,
@@ -600,30 +601,9 @@ def row_linear(
     out = F.linear(input, weight, bias)
 
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
-        # out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce)
-        # id(out)
-        # NOTE: why the id(out) doesn't match the id(out) before the all_reduce?
-        if op_name == "fwd.layer_attn_0_batch_0":
-            assert 1 == 1
-
-        if op_name == "fwd.layer_mlp_0_batch_1":
-            assert 1 == 1
-
-        if op_name == "fwd.layer_attn_0_batch_0":
-            assert 1 == 1
-
         out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, op_name=op_name)
         if async_all_reduce:
-            from nanotron.parallel.comm import AsyncCommBucket
-
-            # work = AsyncCommBucket.get(orig_out_id)
-            # work = AsyncCommBucket.pop(orig_out_id)
-            # if handle_idx == "fwd.layer_mlp_1_batch_0":
-            if op_name == "fwd.layer_attn_0_batch_0":
-                assert 1 == 1
-
             work = AsyncCommBucket.pop(op_name)
-            assert 1 == 1
         else:
             work = None
     elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
diff --git a/src/nanotron/sanity_checks.py b/src/nanotron/sanity_checks.py
index 56ef1e2e..9d1a1589 100644
--- a/src/nanotron/sanity_checks.py
+++ b/src/nanotron/sanity_checks.py
@@ -10,6 +10,7 @@
 from nanotron.models import NanotronModel
 from nanotron.optim.gradient_accumulator import GradientAccumulator
 from nanotron.parallel import ParallelContext
+from nanotron.parallel.comm import AsyncCommBucket
 from nanotron.parallel.tied_parameters import get_tied_id_to_param
 
 logger = get_logger(__name__)
@@ -239,6 +240,9 @@ def before_optim_step_sanity_checks(
         # SANITY CHECK: run model specific sanity checks
         unwrapped_model.before_optim_step_sanity_checks()
 
+        # SANITY CHECK: for domino
+        assert AsyncCommBucket.is_all_completed(), "There are still async ops that haven't finished"
+
 
 def after_optim_step_sanity_checks(
     config: Config,
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index fb05ebf5..35686590 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -61,6 +61,7 @@
 from nanotron.models.starcoder2 import Starcoder2ForTraining
 from nanotron.optim.clip_grads import clip_grad_norm
 from nanotron.parallel import ParallelContext
+from nanotron.parallel.comm import AsyncCommBucket
 from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp
 from nanotron.parallel.parameters import NanotronParameter, sanity_check
 from nanotron.parallel.pipeline_parallel.engine import (
@@ -563,6 +564,7 @@ def training_step(
         before_optim_step_sanity_checks(
             self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer
         )
+        AsyncCommBucket.clear_all()
 
         # Apply gradient
         self.optimizer.step()
@@ -578,20 +580,6 @@ def training_step(
 
         self.post_train_step()
 
-        from nanotron.parallel.comm import AsyncCommBucket
-
-        not_finished = []
-        for k, v in AsyncCommBucket._copy_async_op.items():
-            if v.is_completed() is not True:
-                not_finished.append((k, v))
-
-        # if dist.get_rank() == 0 and constants._NOT_BWD_ASYNC_OPS:
-        #     assert 1 == 1
-
-        assert len(not_finished) == 0, f"len={len(not_finished)}, AsyncCommBucket._copy_async_op: {not_finished}"
-        assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}"
-        AsyncCommBucket.clear_all()
-
         return outputs, loss_avg
 
     def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]:

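After this refactor the per-step bookkeeping boils down to the two calls below (an illustrative condensation, assuming the AsyncCommBucket from this patch; in the actual code the check lives in before_optim_step_sanity_checks and the reset in training_step):

    from nanotron.parallel.comm import AsyncCommBucket

    def end_of_step_domino_checks() -> None:
        # every async all-reduce registered during fwd/bwd must have been popped and completed
        assert AsyncCommBucket.is_all_completed(), "There are still async ops that haven't finished"
        # start the next iteration from a clean slate
        AsyncCommBucket.clear_all()
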
From 7761d82401d7f3e9eb1a106007c0da2e1f6460e0 Mon Sep 17 00:00:00 2001
From: Phuc Nguyen <b3f0cus@icloud.com>
Date: Wed, 12 Feb 2025 16:35:19 +0000
Subject: [PATCH 14/17] refactor

---
 examples/config_llama_domino.yaml             |   2 +-
 src/nanotron/config/parallelism_config.py     |   8 +
 src/nanotron/constants.py                     |   7 -
 src/nanotron/models/llama.py                  | 171 +++++++++---------
 src/nanotron/parallel/comm.py                 |   6 +-
 .../parallel/tensor_parallel/domino.py        |   8 +-
 src/nanotron/trainer.py                       |   5 +
 tests/helpers/llama.py                        |  49 ++---
 tests/test_base_model.py                      |  32 +++-
 tests/test_domino.py                          |  52 ++++++
 10 files changed, 214 insertions(+), 126 deletions(-)

diff --git a/examples/config_llama_domino.yaml b/examples/config_llama_domino.yaml
index 30f59161..b9811fdd 100644
--- a/examples/config_llama_domino.yaml
+++ b/examples/config_llama_domino.yaml
@@ -20,7 +20,7 @@ data_stages:
 general:
   benchmark_csv_path: null
   consumed_train_samples: null
-  ignore_sanity_checks: false
+  ignore_sanity_checks: true
   project: nanotron_domino
   run: config_llama_domino
   seed: 42
diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py
index 2701bf9c..07688959 100644
--- a/src/nanotron/config/parallelism_config.py
+++ b/src/nanotron/config/parallelism_config.py
@@ -25,6 +25,7 @@ class DominoArgs:
 
     def __post_init__(self):
         assert self.num_input_batches > 1, "In order to enable domino mode, set num_input_batches > 1"
+        assert self.num_input_batches == 2, "Domino currently only supports num_input_batches == 2"
 
 
 @dataclass
@@ -68,3 +69,10 @@ def __post_init__(self):
             self.pp_engine = cast_str_to_pipeline_engine(self.pp_engine)
         if isinstance(self.tp_mode, str):
             self.tp_mode = TensorParallelLinearMode[self.tp_mode.upper()]
+
+        if self.is_domino_enabled is True:
+            assert self.tp > 1, "Domino requires TP > 1"
+
+    @property
+    def is_domino_enabled(self) -> bool:
+        return True if self.domino else False
diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py
index 3fe440a8..580bd99d 100644
--- a/src/nanotron/constants.py
+++ b/src/nanotron/constants.py
@@ -10,10 +10,3 @@
 
 CHECKPOINT_FILE_NAME = "checkpoint_metadata.json"
 MODEL_CONFIG_FILE_NAME = "model_config.json"
-
-
-CUDA_STREAMS = {}
-
-CLOCK = 0
-_AUTOGRAD_RUNS = []
-_NOT_BWD_ASYNC_OPS = []
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index b6785604..ca0b50c6 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -30,8 +30,7 @@
 from nanotron.nn.activations import ACT2FN
 from nanotron.nn.layer_norm import TritonRMSNorm
 from nanotron.parallel import ParallelContext
-
-# from nanotron.parallel.comm import WaitComm
+from nanotron.parallel.comm import CudaStreamManager
 from nanotron.parallel.parameters import NanotronParameter
 from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
 from nanotron.parallel.pipeline_parallel.p2p import P2P
@@ -55,8 +54,6 @@
 
 logger = logging.get_logger(__name__)
 
-DOMINO_COMM_STREAM = "domino_comm_stream_{}"
-
 
 class RotaryEmbedding(nn.Module):
     def __init__(self, dim: int, end: int, theta: float = 10000.0):
@@ -248,11 +245,10 @@ def __init__(
             mode=tp_mode,
             bias=False,
             async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
-            # async_all_reduce=parallel_config.domino.num_input_batches > 1,
         )
         self.split_silu_mul = GLUActivation(config.hidden_act)
 
-    def forward(self, hidden_states, op_name):  # [seq_length, batch_size, hidden_dim]
+    def forward(self, hidden_states, op_name: str = None):  # [seq_length, batch_size, hidden_dim]
         merged_states = self.gate_up_proj(hidden_states, op_name=op_name)
         hidden_states, work = self.down_proj(self.split_silu_mul(merged_states), op_name=op_name)
         return {"hidden_states": hidden_states, "work": work}
@@ -347,7 +343,6 @@ def __init__(
         parallel_config: Optional[ParallelismArgs],
         tp_pg: dist.ProcessGroup,
         layer_idx: int,
-        async_all_reduce: bool = False,
     ):
         from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
 
@@ -431,7 +426,6 @@ def __init__(
             mode=tp_mode,
             bias=False,
             async_communication=tp_linear_async_communication,
-            # async_all_reduce=async_all_reduce,
         )
 
         self.attention = CoreAttention(
@@ -448,7 +442,7 @@ def forward(
         self,
         hidden_states,  # [seq_length, batch_size, hidden_size]
         sequence_mask,  # [batch_size, seq_length]
-        op_name,
+        op_name: str = None,
     ):
         from flash_attn import bert_padding
         from flash_attn.flash_attn_interface import (
@@ -707,7 +701,7 @@ def forward(
         return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask}
 
 
-class LlamaDecoderLayer(nn.Module):
+class _BaseLlamaDecoderLayer(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
@@ -723,7 +717,6 @@ def __init__(
             parallel_config=parallel_config,
             tp_pg=tp_pg,
             layer_idx=layer_idx,
-            async_all_reduce=parallel_config.domino.num_input_batches > 1,
         )
 
         self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -731,31 +724,93 @@ def __init__(
 
         self.recompute_layer = parallel_config.recompute_layer
         self.parallel_config = parallel_config
+        self.layer_idx = layer_idx
 
-        # if parallel_config.domino is not None and parallel_config.domino.num_input_batches > 1:
-        #     from nanotron.parallel.comm import CudaStreamManager
-        #     # NOTE: we use different cuda streams for different gpus, so it can overlaps the communication
-        #     CudaStreamManager.create(DOMINO_COMM_STREAM.format(torch.cuda.current_device()))
-        num_gpus = torch.cuda.device_count()
-        for i in range(num_gpus):
-            from nanotron import constants
+    def _checkpointed_forward(
+        self,
+        hidden_states: torch.Tensor,
+        sequence_mask: torch.Tensor,
+    ) -> List[torch.Tensor]:
+        return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask)
 
-            constants.CUDA_STREAMS[i] = torch.cuda.Stream(device=torch.device(f"cuda:{i}"))
+    def forward(
+        self,
+        hidden_states: Union[torch.Tensor, TensorPointer],
+        sequence_mask: Union[torch.Tensor, TensorPointer],
+    ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
 
-        self.layer_idx = layer_idx
+        if self.recompute_layer and not isinstance(hidden_states, TensorPointer):
+            hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask)
+        else:
+            hidden_states, sequence_mask = self._core_forward(hidden_states, sequence_mask)
+
+        return {
+            "hidden_states": hidden_states,
+            "sequence_mask": sequence_mask,
+        }
 
+
+class LlamaDecoderLayer(_BaseLlamaDecoderLayer):
     def _core_forward(
         self,
         hidden_states: Union[torch.Tensor, TensorPointer],
         sequence_mask: Union[torch.Tensor, TensorPointer],
     ) -> List[Union[torch.Tensor, TensorPointer]]:
-        from nanotron import constants
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
+        hidden_states = output["hidden_states"]
+        hidden_states = hidden_states + residual
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
+        hidden_states = hidden_states + residual
+
+        return hidden_states, output["sequence_mask"]
 
+
+class Embedding(nn.Module, AttachableStore):
+    def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]):
+        super().__init__()
+        self.token_embedding = TensorParallelEmbedding(
+            num_embeddings=config.vocab_size,
+            embedding_dim=config.hidden_size,
+            padding_idx=config.pad_token_id,
+            pg=tp_pg,
+            mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE,
+        )
+        self.pg = tp_pg
+
+    def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor):  # [batch_size, seq_length]
+        store = self.get_local_store()
+        if store is not None:
+            if "past_length" in store:
+                past_length = store["past_length"]
+            else:
+                past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0])
+
+            cumsum_mask = input_mask.cumsum(-1, dtype=torch.long)
+            # Store new past_length in store
+            store["past_length"] = past_length + cumsum_mask[:, -1]
+
+        # Format input in `[seq_length, batch_size]` to support high TP with low batch_size
+        input_ids = input_ids.transpose(0, 1)
+        input_embeds = self.token_embedding(input_ids)
+        return {"input_embeds": input_embeds}
+
+
+class DominoLlamaDecoderLayer(_BaseLlamaDecoderLayer):
+    def _core_forward(
+        self,
+        hidden_states: Union[torch.Tensor, TensorPointer],
+        sequence_mask: Union[torch.Tensor, TensorPointer],
+    ) -> List[Union[torch.Tensor, TensorPointer]]:
         num_input_batches = self.parallel_config.domino.num_input_batches
         orig_sequence_mask = sequence_mask
-        comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()]
+        comm_stream = CudaStreamManager.get(f"comm_stream_{torch.cuda.current_device()}")
 
-        assert num_input_batches == 2
         hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1)
         sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0)
 
@@ -788,20 +843,13 @@ def _core_forward(
         hidden_states0 = attn_output0["hidden_states"] + residual0
         residual0 = hidden_states0
         hidden_states0 = self.post_attention_layernorm(hidden_states0)
-        hidden_states0 = WaitComm.apply(
-            hidden_states0, BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), comm_stream
-        )  # new
+        hidden_states0 = WaitComm.apply(hidden_states0, BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), comm_stream)
 
         mlp_output0 = self.mlp(
             hidden_states=hidden_states0,
             op_name=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
         )
 
-        # attn_output1["hidden_states"], mlp_output0["hidden_states"] = depend(
-        #     run_after=attn_output1["hidden_states"],
-        #     run_before=mlp_output0["hidden_states"]
-        # )
-
         with torch.cuda.stream(comm_stream):
             attn_output1["work"].wait()
             attn_output1["work"].is_completed()
@@ -822,70 +870,14 @@ def _core_forward(
             mlp_output0["work"].is_completed()
 
         torch.cuda.current_stream().wait_stream(comm_stream)
-        # torch.cuda.synchronize()
 
         hidden_states0 = mlp_output0["hidden_states"] + residual0
         hidden_states1 = mlp_output1["hidden_states"] + residual1
 
-        # hidden_states0, hidden_states1 = depend(run_after=hidden_states0, run_before=hidden_states1)
-
         hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1)
 
         return hidden_states, orig_sequence_mask
 
-    def _checkpointed_forward(
-        self,
-        hidden_states: torch.Tensor,
-        sequence_mask: torch.Tensor,
-    ) -> List[torch.Tensor]:
-        return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask)
-
-    def forward(
-        self,
-        hidden_states: Union[torch.Tensor, TensorPointer],
-        sequence_mask: Union[torch.Tensor, TensorPointer],
-    ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
-
-        if self.recompute_layer and not isinstance(hidden_states, TensorPointer):
-            hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask)
-        else:
-            hidden_states, sequence_mask = self._core_forward(hidden_states, sequence_mask)
-
-        return {
-            "hidden_states": hidden_states,
-            "sequence_mask": sequence_mask,
-        }
-
-
-class Embedding(nn.Module, AttachableStore):
-    def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]):
-        super().__init__()
-        self.token_embedding = TensorParallelEmbedding(
-            num_embeddings=config.vocab_size,
-            embedding_dim=config.hidden_size,
-            padding_idx=config.pad_token_id,
-            pg=tp_pg,
-            mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE,
-        )
-        self.pg = tp_pg
-
-    def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor):  # [batch_size, seq_length]
-        store = self.get_local_store()
-        if store is not None:
-            if "past_length" in store:
-                past_length = store["past_length"]
-            else:
-                past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0])
-
-            cumsum_mask = input_mask.cumsum(-1, dtype=torch.long)
-            # Store new past_length in store
-            store["past_length"] = past_length + cumsum_mask[:, -1]
-
-        # Format input in `[seq_length, batch_size]` to support high TP with low batch_size
-        input_ids = input_ids.transpose(0, 1)
-        input_embeds = self.token_embedding(input_ids)
-        return {"input_embeds": input_embeds}
-
 
 class LlamaModel(nn.Module):
     """Build pipeline graph"""
@@ -896,6 +888,8 @@ def __init__(
         parallel_context: ParallelContext,
         parallel_config: Optional[ParallelismArgs],
     ):
+        # from nanotron.parallel.tensor_parallel.domino import DominoLlamaDecoderLayer
+
         super().__init__()
 
         # Declare all the nodes
@@ -931,7 +925,7 @@ def __init__(
             [
                 PipelineBlock(
                     p2p=self.p2p,
-                    module_builder=LlamaDecoderLayer,
+                    module_builder=DominoLlamaDecoderLayer if parallel_config.is_domino_enabled else LlamaDecoderLayer,
                     module_kwargs={
                         "config": config,
                         "parallel_config": parallel_config,
@@ -992,7 +986,6 @@ def forward_with_hidden_states(
         input_mask: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
     ):
         # all tensors are optional as most ranks don't need anything from the dataloader.
-
         output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask)
 
         hidden_encoder_states = {
diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py
index 63736718..248b966c 100644
--- a/src/nanotron/parallel/comm.py
+++ b/src/nanotron/parallel/comm.py
@@ -8,12 +8,14 @@ class CudaStreamManager:
     _streams: Dict[str, "torch.cuda.Stream"] = {}
 
     @staticmethod
-    def create(name: str):
+    def create(name: str, device: torch.device = None):
         assert name not in CudaStreamManager._streams
-        CudaStreamManager._streams[name] = torch.cuda.Stream()
+        CudaStreamManager._streams[name] = torch.cuda.Stream(device=device)
 
     @staticmethod
     def get(name: str):
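+        # Lazily create the stream on first use so callers don't need to call create() beforehand.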
+        if name not in CudaStreamManager._streams:
+            CudaStreamManager.create(name)
         return CudaStreamManager._streams.get(name)
 
     @contextmanager
diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py
index e35cac06..26050c3a 100644
--- a/src/nanotron/parallel/tensor_parallel/domino.py
+++ b/src/nanotron/parallel/tensor_parallel/domino.py
@@ -4,6 +4,11 @@
 
 from nanotron.parallel.comm import AsyncCommBucket
 
+# from nanotron.models.llama import _BaseLlamaDecoderLayer
+# from nanotron.parallel.pipeline_parallel.block import TensorPointer
+# from nanotron.parallel.comm import CudaStreamManager
+
+
 FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}"
 FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}"
 BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}"
@@ -17,10 +22,7 @@ def is_async_comm(op_name: str):
     for the backward pass: the first micro-batch of the attention layer
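+
+    Illustrative examples (handle names follow the templates defined above):
+        >>> is_async_comm("fwd.layer_attn_0_batch_0")
+        True
+        >>> is_async_comm("bwd.layer_attn_0_batch_0")
+        False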
     """
     NON_ASYNC_HANDLE_IDX = [
-        # "fwd.layer_attn_{}_batch_0",
-        # "fwd.layer_mlp_{}_batch_0",
         "fwd.layer_mlp_{}_batch_1",
-        # "bwd.layer_mlp_{}_batch_1",
         "bwd.layer_attn_{}_batch_0",
     ]
 
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 35686590..f5370525 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -444,6 +444,11 @@ def train(
         # free memory
         gc.collect()
         torch.cuda.empty_cache()
+
+        # num_gpus = torch.cuda.device_count()
+        # for i in range(num_gpus):
+        #     CudaStreamManager.create(f"comm_stream_{i}", device=torch.device(f"cuda:{i}"))
+
         with prof:
             for self.iteration_step in range(self.initial_iter_step, self.last_iter_step + 1):
                 if isinstance(prof, torch.profiler.profile):
diff --git a/tests/helpers/llama.py b/tests/helpers/llama.py
index 3f94031f..8aae7669 100644
--- a/tests/helpers/llama.py
+++ b/tests/helpers/llama.py
@@ -1,5 +1,6 @@
 import torch
 from nanotron.config import (
+    AdamWOptimizerArgs,
     AllForwardAllBackwardPipelineEngine,
     CheckpointsArgs,
     Config,
@@ -46,7 +47,19 @@
 )
 
 
-def get_llama_training_config(model_config: ModelArgs):
+def get_parallel_config(parallel_context: ParallelContext):
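+    """Build a ParallelismArgs whose parallel sizes mirror the given ParallelContext."""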
+    return ParallelismArgs(
+        dp=parallel_context.data_parallel_size,
+        pp=parallel_context.pipeline_parallel_size,
+        tp=parallel_context.tensor_parallel_size,
+        expert_parallel_size=parallel_context.expert_parallel_size,
+        pp_engine=AllForwardAllBackwardPipelineEngine(),
+        tp_mode=TensorParallelLinearMode.ALL_REDUCE,
+        tp_linear_async_communication=False,
+    )
+
+
+def get_llama_training_config(model_config: ModelArgs, parallel_context):
     return Config(
         model=model_config,
         general=GeneralArgs(project="unittest", run="sanity_llama", seed=42),
@@ -54,25 +67,20 @@ def get_llama_training_config(model_config: ModelArgs):
             checkpoints_path="./checkpoints",
             checkpoint_interval=10,
         ),
-        parallelism=ParallelismArgs(
-            dp=1,
-            pp=1,
-            tp=2,
-            expert_parallel_size=2,
-            pp_engine="1f1b",
-            tp_mode="ALL_REDUCE",
-            tp_linear_async_communication=False,
-        ),
+        parallelism=get_parallel_config(parallel_context),
         tokenizer=TokenizerArgs("gpt2"),
         optimizer=OptimizerArgs(
             zero_stage=0,
             weight_decay=0.01,
             clip_grad=1.0,
             accumulate_grad_in_fp32=False,
-            adam_eps=1e-08,
-            adam_beta1=0.9,
-            adam_beta2=0.95,
-            torch_adam_is_fused=True,
+            optimizer_factory=AdamWOptimizerArgs(
+                adam_eps=1e-08,
+                adam_beta1=0.9,
+                adam_beta2=0.95,
+                torch_adam_is_fused=True,
+                name="adamW",
+            ),
             learning_rate_scheduler=LRSchedulerArgs(
                 learning_rate=3e-4,
                 lr_warmup_steps=100,
@@ -103,7 +111,10 @@ def get_llama_training_config(model_config: ModelArgs):
 
 
 def create_llama_from_config(
-    model_config: LlamaConfig, device: torch.device, parallel_context: ParallelContext
+    model_config: LlamaConfig,
+    parallel_config: ParallelismArgs,
+    device: torch.device,
+    parallel_context: ParallelContext,
 ) -> LlamaForTraining:
 
     """
@@ -114,14 +125,6 @@ def create_llama_from_config(
     the model created will have random weights.
     """
 
-    parallel_config = ParallelismArgs(
-        dp=parallel_context.data_parallel_size,
-        pp=parallel_context.pipeline_parallel_size,
-        tp=parallel_context.tensor_parallel_size,
-        pp_engine=AllForwardAllBackwardPipelineEngine(),
-        tp_mode=TensorParallelLinearMode.ALL_REDUCE,
-        tp_linear_async_communication=False,
-    )
     model = build_model(
         model_builder=lambda: LlamaForTraining(
             config=model_config,
diff --git a/tests/test_base_model.py b/tests/test_base_model.py
index b4759905..410e302c 100644
--- a/tests/test_base_model.py
+++ b/tests/test_base_model.py
@@ -10,7 +10,6 @@
 
 
 @pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 2, 2)])
-@pytest.mark.skip
 @rerun_if_address_is_in_use()
 def test_get_named_modules_in_pp_rank(tp: int, dp: int, pp: int):
     model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG)
@@ -43,3 +42,34 @@ def _test_get_named_modules_in_pp_rank(
         # not PipelineBlock
         assert isinstance(module, nn.Module)
         assert name not in modules_that_not_in_current_pp_rank
+
+
+@pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 2, 1)])
+@rerun_if_address_is_in_use()
+def test_llama_model(tp: int, dp: int, pp: int):
+    BATCH_SIZE, SEQ_LEN = 10, 128
+    model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG)
+
+    init_distributed(tp=tp, dp=dp, pp=pp)(_test_llama_model)(
+        model_args=model_args, batch_size=BATCH_SIZE, seq_len=SEQ_LEN
+    )
+
+
+def _test_llama_model(
+    parallel_context: ParallelContext,
+    model_args: ModelArgs,
+    batch_size: int,
+    seq_len: int,
+):
+    # Build the config inside the distributed worker so it can mirror the parallel context.
+    config = get_llama_training_config(model_args, parallel_context)
+
+    llama_model = create_llama_from_config(
+        model_config=config.model.model_config,
+        parallel_config=config.parallelism,
+        device=torch.device("cuda"),
+        parallel_context=parallel_context,
+    )
+    llama_model.init_model_randomly(config=config)
+
+    input_ids = torch.randint(0, config.model.model_config.vocab_size, size=(batch_size, seq_len), device="cuda")
+    input_mask = torch.ones_like(input_ids)
+    outputs = llama_model(input_ids, input_mask, input_mask, input_mask)
+
+    assert list(outputs.keys()) == ["loss"]
+    assert isinstance(outputs["loss"], torch.Tensor)
diff --git a/tests/test_domino.py b/tests/test_domino.py
index 8f474ff8..977e5407 100644
--- a/tests/test_domino.py
+++ b/tests/test_domino.py
@@ -1,4 +1,14 @@
+from copy import deepcopy
+
 import pytest
+import torch
+from helpers.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config
+from helpers.utils import init_distributed, rerun_if_address_is_in_use
+from nanotron.config import ModelArgs, RandomInit
+from nanotron.config.parallelism_config import DominoArgs
+from nanotron.models.llama import DominoLlamaDecoderLayer
+from nanotron.parallel import ParallelContext
+from nanotron.parallel.comm import AsyncCommBucket
 from nanotron.parallel.tensor_parallel.domino import is_async_comm
 
 
@@ -17,3 +27,45 @@
 )
 def test_is_async_comm(op_name, expected):
     assert is_async_comm(op_name) == expected
+
+
+@pytest.mark.parametrize("tp,dp,pp", [(2, 2, 1)])
+@rerun_if_address_is_in_use()
+def test_domino_model(tp: int, dp: int, pp: int):
+    BATCH_SIZE, SEQ_LEN = 10, 128
+
+    model_config = deepcopy(TINY_LLAMA_CONFIG)
+    model_config.num_hidden_layers = 28
+    model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=model_config)
+
+    init_distributed(tp=tp, dp=dp, pp=pp)(_test_domino_model)(
+        model_args=model_args, batch_size=BATCH_SIZE, seq_len=SEQ_LEN
+    )
+
+
+def _test_domino_model(
+    parallel_context: ParallelContext,
+    model_args: ModelArgs,
+    batch_size: int,
+    seq_len: int,
+):
+    config = get_llama_training_config(model_args, parallel_context)
+    config.parallelism.domino = DominoArgs(num_input_batches=2)
+
+    llama_model = create_llama_from_config(
+        model_config=config.model.model_config,
+        parallel_config=config.parallelism,
+        device=torch.device("cuda"),
+        parallel_context=parallel_context,
+    )
+    llama_model.init_model_randomly(config=config)
+
+    for m in llama_model.model.decoder:
+        assert isinstance(m.pp_block, DominoLlamaDecoderLayer)
+
+    input_ids = torch.randint(0, config.model.model_config.vocab_size, size=(batch_size, seq_len), device="cuda")
+    input_mask = torch.ones_like(input_ids)
+    outputs = llama_model(input_ids, input_mask, input_mask, input_mask)
+
+    assert isinstance(outputs["loss"], torch.Tensor)
+    assert AsyncCommBucket.is_all_completed()

From b11e48fc0de4a86855d27c116365604d28832466 Mon Sep 17 00:00:00 2001
From: Phuc Nguyen <b3f0cus@icloud.com>
Date: Wed, 12 Feb 2025 16:37:06 +0000
Subject: [PATCH 15/17] remove dependency.py and comments

---
 src/nanotron/parallel/dependency.py           | 102 ------------------
 .../parallel/tensor_parallel/domino.py        |   4 -
 src/nanotron/trainer.py                       |   4 -
 3 files changed, 110 deletions(-)
 delete mode 100644 src/nanotron/parallel/dependency.py

diff --git a/src/nanotron/parallel/dependency.py b/src/nanotron/parallel/dependency.py
deleted file mode 100644
index 6a633d8a..00000000
--- a/src/nanotron/parallel/dependency.py
+++ /dev/null
@@ -1,102 +0,0 @@
-from typing import Dict, Tuple
-
-import torch
-from torch import Tensor
-
-_phonies: Dict[Tuple[torch.device, bool], Tensor] = {}
-
-
-def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor:
-    """Gets a phony. Phony is tensor without space. It is useful to make
-    arbitrary dependency in a autograd graph because it doesn't require any
-    gradient accumulation.
-
-    .. note::
-
-        Phonies for each device are cached. If an autograd function gets a phony
-        internally, the phony must be detached to be returned. Otherwise, the
-        autograd engine will mutate the cached phony in-place::
-
-            class Phonify(torch.autograd.Function):
-                @staticmethod
-                def forward(ctx, input):
-                    phony = get_phony(input.device, requires_grad=False)
-                    return phony.detach()  # detach() is necessary.
-
-    """
-    key = (device, requires_grad)
-
-    try:
-        phony = _phonies[key]
-    except KeyError:
-        with torch.cuda.stream(torch.cuda.default_stream(device)):
-            phony = torch.empty(0, device=device, requires_grad=requires_grad)
-
-        _phonies[key] = phony
-
-    return phony
-
-
-def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
-    """Branches out from an autograd lane of the given tensor."""
-    if torch.is_grad_enabled() and input.requires_grad:
-        input, phony = Fork.apply(input)
-    else:
-        phony = get_phony(input.device, requires_grad=False)
-
-    return input, phony
-
-
-class Fork(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
-        phony = get_phony(input.device, requires_grad=False)
-        return input, phony.detach()
-
-    @staticmethod
-    def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor:  # type: ignore
-        # import pydevd
-        # pydevd.settrace(suspend=False, trace_only_current_thread=True)
-        return grad_input
-
-
-def join(input: Tensor, phony: Tensor) -> Tensor:
-    """Merges two autograd lanes."""
-    if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad):
-        input = Join.apply(input, phony)
-
-    return input
-
-
-class Join(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor:  # type: ignore
-        return input
-
-    @staticmethod
-    def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]:  # type: ignore
-        # import pydevd
-        # pydevd.settrace(suspend=False, trace_only_current_thread=True)
-        return grad_input, None
-
-
-# def depend(fork_from, join_to) -> None:
-#     # Ensure that batches[i-1] is executed after batches[i] in
-#     # # backpropagation by an explicit dependency.
-#     # if i != 0:
-#     #     depend(batches[i-1], batches[i])
-#     # depend(run_after, run_before)
-#     fork_from, phony = fork(fork_from)
-#     join_to = join(join_to, phony)
-#     return fork_from, join_to
-
-
-def depend(run_after, run_before) -> None:
-    # Ensure that batches[i-1] is executed after batches[i] in
-    # # backpropagation by an explicit dependency.
-    # if i != 0:
-    #     depend(batches[i-1], batches[i])
-    # depend(run_after, run_before)
-    run_after, phony = fork(run_after)
-    run_before = join(run_before, phony)
-    return run_after, run_before
diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py
index 26050c3a..d41fceb8 100644
--- a/src/nanotron/parallel/tensor_parallel/domino.py
+++ b/src/nanotron/parallel/tensor_parallel/domino.py
@@ -4,10 +4,6 @@
 
 from nanotron.parallel.comm import AsyncCommBucket
 
-# from nanotron.models.llama import _BaseLlamaDecoderLayer
-# from nanotron.parallel.pipeline_parallel.block import TensorPointer
-# from nanotron.parallel.comm import CudaStreamManager
-
 
 FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}"
 FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}"
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index f5370525..d0eddf39 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -445,10 +445,6 @@ def train(
         gc.collect()
         torch.cuda.empty_cache()
 
-        # num_gpus = torch.cuda.device_count()
-        # for i in range(num_gpus):
-        #     CudaStreamManager.create(f"comm_stream_{i}", device=torch.device(f"cuda:{i}"))
-
         with prof:
             for self.iteration_step in range(self.initial_iter_step, self.last_iter_step + 1):
                 if isinstance(prof, torch.profiler.profile):

From 893ff076348a8bd537b69356cafebbf2d65ea873 Mon Sep 17 00:00:00 2001
From: Phuc Nguyen <b3f0cus@icloud.com>
Date: Fri, 14 Feb 2025 15:56:33 +0000
Subject: [PATCH 16/17] add unit tests for async bucket and WaitComm

---
 src/nanotron/models/llama.py                  |   3 +-
 src/nanotron/parallel/comm.py                 |  51 +++++---
 .../parallel/tensor_parallel/domino.py        |  21 ----
 tests/test_comm.py                            | 117 ++++++++++++++++++
 tests/test_domino.py                          |   2 +-
 5 files changed, 156 insertions(+), 38 deletions(-)
 create mode 100644 tests/test_comm.py

diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index ca0b50c6..431cfc2d 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -30,7 +30,7 @@
 from nanotron.nn.activations import ACT2FN
 from nanotron.nn.layer_norm import TritonRMSNorm
 from nanotron.parallel import ParallelContext
-from nanotron.parallel.comm import CudaStreamManager
+from nanotron.parallel.comm import CudaStreamManager, WaitComm
 from nanotron.parallel.parameters import NanotronParameter
 from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
 from nanotron.parallel.pipeline_parallel.p2p import P2P
@@ -39,7 +39,6 @@
     BWD_MLP_HANDLE_IDX,
     FWD_ATTN_HANDLE_IDX,
     FWD_MLP_HANDLE_IDX,
-    WaitComm,
 )
 from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
 from nanotron.parallel.tensor_parallel.nn import (
diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py
index 248b966c..1f3b043d 100644
--- a/src/nanotron/parallel/comm.py
+++ b/src/nanotron/parallel/comm.py
@@ -3,6 +3,8 @@
 
 import torch
 
+from nanotron.parallel.tensor_parallel.domino import is_async_comm
+
 
 class CudaStreamManager:
     _streams: Dict[str, "torch.cuda.Stream"] = {}
@@ -38,30 +40,35 @@ class AsyncCommBucket:
     _copy_async_op: Dict[int, "dist.Work"] = {}
 
     @staticmethod
-    def add(tensor_id: int, work: "dist.Work"):
-        assert (
-            tensor_id not in AsyncCommBucket._async_op
-        ), f"tensor_id: {tensor_id}, keys: {AsyncCommBucket._async_op.keys()}"
-        AsyncCommBucket._async_op[tensor_id] = work
-        AsyncCommBucket._copy_async_op[tensor_id] = work
+    def add(op_name: str, work: "dist.Work"):
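+        """Register a pending async collective under a unique op name."""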
+        assert op_name not in AsyncCommBucket._async_op, f"Operation with name: {op_name} already exists"
+        AsyncCommBucket._async_op[op_name] = work
+        AsyncCommBucket._copy_async_op[op_name] = work
 
     @staticmethod
-    def get(tensor_id: int):
-        return AsyncCommBucket._async_op.get(tensor_id)
+    def get(op_name: str):
+        if op_name not in AsyncCommBucket._async_op:
+            raise KeyError(f"Operation with name: {op_name} doesn't exist")
+
+        return AsyncCommBucket._async_op.get(op_name)
 
     @staticmethod
-    def pop(tensor_id: int):
-        assert tensor_id in AsyncCommBucket._async_op, f"tensor_id: {tensor_id}"
-        return AsyncCommBucket._async_op.pop(tensor_id)
+    def pop(op_name: str):
+        if op_name not in AsyncCommBucket._async_op:
+            raise KeyError(f"Operation with name: {op_name} doesn't exist")
+
+        return AsyncCommBucket._async_op.pop(op_name)
 
     @staticmethod
-    def wait(tensor_id: int):
-        work = AsyncCommBucket._async_op.pop(tensor_id)
+    def wait(op_name: str):
+        """Wait and remove the operation from the bucket"""
+        work = AsyncCommBucket.pop(op_name)
         work.wait()
 
     @staticmethod
     def is_all_completed() -> bool:
-        assert len(AsyncCommBucket._async_op) == 0, "there are still some async ops haven't executed"
+        if len(AsyncCommBucket._async_op) != 0:
+            return False
 
         not_finished = []
         for k, v in AsyncCommBucket._copy_async_op.items():
@@ -73,3 +80,19 @@ def is_all_completed() -> bool:
     def clear_all():
         AsyncCommBucket._async_op.clear()
         AsyncCommBucket._copy_async_op.clear()
+
+
+class WaitComm(torch.autograd.Function):
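+    """Identity in the forward pass. In the backward pass, if ``wait_handle_idx``
+    names an async op, wait on it via ``AsyncCommBucket`` and make the default
+    stream wait on ``comm_stream`` before gradients flow further back."""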
+    @staticmethod
+    def forward(ctx, input: torch.Tensor, wait_handle_idx: str, comm_stream: torch.cuda.Stream):
+        ctx.wait_handle_idx = wait_handle_idx
+        ctx.comm_stream = comm_stream
+        return input
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        if is_async_comm(ctx.wait_handle_idx):
+            AsyncCommBucket.wait(ctx.wait_handle_idx)
+            torch.cuda.default_stream().wait_stream(ctx.comm_stream)
+
+        return grad_output, None, None
diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py
index d41fceb8..0f4c5a12 100644
--- a/src/nanotron/parallel/tensor_parallel/domino.py
+++ b/src/nanotron/parallel/tensor_parallel/domino.py
@@ -1,10 +1,5 @@
 import re
 
-import torch
-
-from nanotron.parallel.comm import AsyncCommBucket
-
-
 FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}"
 FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}"
 BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}"
@@ -26,19 +21,3 @@ def is_async_comm(op_name: str):
     regex = re.compile("^(" + "|".join(patterns) + ")$")  # Combine patterns into a single regex
     not_async = bool(regex.match(op_name))
     return not not_async
-
-
-class WaitComm(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input, wait_handle_idx, comm_stream):
-        ctx.wait_handle_idx = wait_handle_idx
-        ctx.comm_stream = comm_stream
-        return input
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        if is_async_comm(ctx.wait_handle_idx):
-            AsyncCommBucket.wait(ctx.wait_handle_idx)
-            torch.cuda.default_stream().wait_stream(ctx.comm_stream)
-
-        return grad_output, None, None
diff --git a/tests/test_comm.py b/tests/test_comm.py
new file mode 100644
index 00000000..286f3ebe
--- /dev/null
+++ b/tests/test_comm.py
@@ -0,0 +1,117 @@
+import pytest
+import torch
+import torch.distributed as dist
+from helpers.utils import (
+    init_distributed,
+    rerun_if_address_is_in_use,
+)
+from nanotron.parallel import ParallelContext
+from nanotron.parallel.comm import AsyncCommBucket, WaitComm
+
+
+class MockWork:
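+    """Minimal stand-in for a ``torch.distributed`` work handle, so bucket behaviour can be tested without real collectives."""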
+    def __init__(self):
+        self.completed = False
+        self.wait_called = False
+
+    def wait(self):
+        self.wait_called = True
+        self.completed = True
+
+    def is_completed(self):
+        return self.completed
+
+
+@rerun_if_address_is_in_use()
+def test_add_async_op_to_bucket():
+    init_distributed(tp=2, dp=1, pp=1)(_test_add_async_op_to_bucket)()
+
+
+def _test_add_async_op_to_bucket(parallel_context: ParallelContext):
+    OP_NAME = "test"
+    tensor = torch.randn(1, device="cuda")
+    work = dist.all_reduce(tensor, async_op=True)
+
+    AsyncCommBucket.add(OP_NAME, work)
+
+    assert AsyncCommBucket.get(OP_NAME) is work
+
+
+@rerun_if_address_is_in_use()
+def test_wait_async_op_to_bucket():
+    init_distributed(tp=2, dp=1, pp=1)(_test_wait_async_op_to_bucket)()
+
+
+def _test_wait_async_op_to_bucket(parallel_context: ParallelContext):
+    OP_NAME = "test"
+    work = MockWork()
+
+    AsyncCommBucket.add(OP_NAME, work)
+    assert work.is_completed() is False
+
+    AsyncCommBucket.wait(OP_NAME)
+    assert work.is_completed()
+    with pytest.raises(KeyError):
+        AsyncCommBucket.get(OP_NAME)
+
+
+@rerun_if_address_is_in_use()
+def test_is_all_completed_in_async_bucket():
+    init_distributed(tp=2, dp=1, pp=1)(_test_is_all_completed_in_async_bucket)()
+
+
+def _test_is_all_completed_in_async_bucket(parallel_context: ParallelContext):
+    OP_NAME = "test"
+    work = MockWork()
+
+    AsyncCommBucket.add(OP_NAME, work)
+    assert AsyncCommBucket.is_all_completed() is False
+
+    AsyncCommBucket.wait(OP_NAME)
+    assert AsyncCommBucket.is_all_completed() is True
+
+
+@rerun_if_address_is_in_use()
+def test_clear_ops_in_async_bucket():
+    init_distributed(tp=2, dp=1, pp=1)(_test_clear_ops_in_async_bucket)()
+
+
+def _test_clear_ops_in_async_bucket(parallel_context: ParallelContext):
+    tensor1 = torch.randn(1, device="cuda")
+    tensor2 = torch.randn(1, device="cuda")
+    tensor3 = torch.randn(1, device="cuda")
+
+    AsyncCommBucket.add("test1", dist.all_reduce(tensor1, async_op=True))
+    AsyncCommBucket.add("test2", dist.all_reduce(tensor2, async_op=True))
+    AsyncCommBucket.add("test3", dist.all_reduce(tensor3, async_op=True))
+
+    assert AsyncCommBucket.is_all_completed() is False
+
+    AsyncCommBucket.clear_all()
+    assert AsyncCommBucket.is_all_completed() is True
+    with pytest.raises(KeyError):
+        AsyncCommBucket.get("test1")
+
+
+@rerun_if_address_is_in_use()
+def test_wait_comm():
+    init_distributed(tp=2, dp=1, pp=1)(_test_wait_comm)()
+
+
+def _test_wait_comm(parallel_context: ParallelContext):
+    tensor = torch.randn(1, device="cuda", requires_grad=True)
+    OP_NAME = "test"
+
+    comm_stream = torch.cuda.Stream()
+
+    with torch.cuda.stream(comm_stream):
+        work = MockWork()
+        AsyncCommBucket.add(OP_NAME, work)
+
+    output = WaitComm.apply(tensor, OP_NAME, comm_stream)
+    assert work.is_completed() is False
+
+    # NOTE: we test that it waits for the async op to complete
+    # automatically in autograd
+    (output + 1).backward()
+    assert work.is_completed()
diff --git a/tests/test_domino.py b/tests/test_domino.py
index 977e5407..44d9d98a 100644
--- a/tests/test_domino.py
+++ b/tests/test_domino.py
@@ -68,4 +68,4 @@ def _test_domino_model(
     outputs = llama_model(input_ids, input_mask, input_mask, input_mask)
 
     assert isinstance(outputs["loss"], torch.Tensor)
-    assert AsyncCommBucket.is_all_completed()
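+    # Every async all-reduce queued by domino must have been waited on (and drained from the bucket) by the end of the forward pass.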
+    assert AsyncCommBucket.is_all_completed() is True

From 14a0e4e6add14a0c7b318c03583e240911ce02ce Mon Sep 17 00:00:00 2001
From: "phuc.nguyen@huggingface.co" <phuc_nguyen@ip-26-0-160-40.ec2.internal>
Date: Fri, 21 Feb 2025 12:51:02 +0000
Subject: [PATCH 17/17] directly take is_async_comm from op_name

---
 src/nanotron/models/llama.py                  |  1 +
 .../distributed_differentiable_primitives.py  | 21 +++++++------------
 .../parallel/tensor_parallel/functional.py    |  9 ++++----
 src/nanotron/parallel/tensor_parallel/nn.py   |  5 +----
 src/nanotron/sanity_checks.py                 |  5 ++++-
 src/nanotron/trainer.py                       | 10 +++++++++
 6 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 431cfc2d..ea27a97e 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -866,6 +866,7 @@ def _core_forward(
 
         with torch.cuda.stream(comm_stream):
             mlp_output0["work"].wait()
             mlp_output0["work"].is_completed()
 
         torch.cuda.current_stream().wait_stream(comm_stream)
diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index 1254fdb1..3ba071eb 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -31,12 +31,10 @@ def forward(
         ctx,
         tensor,
         group: Optional[ProcessGroup],
-        async_all_reduce: bool,
         op_name: str = None,
         comm_stream: torch.cuda.Stream = None,
     ):
         ctx.group = group
-        ctx.async_all_reduce = async_all_reduce
         ctx.op_name = op_name
         ctx.comm_stream = comm_stream
         return tensor
@@ -45,9 +43,8 @@ def forward(
     def backward(ctx, grad_output):
         group = ctx.group
         op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else ctx.op_name
-        async_all_reduce = is_async_comm(op_name) if ctx.op_name is not None else ctx.async_all_reduce
         return (
-            DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, op_name, ctx.comm_stream),
+            DifferentiableAllReduceSum.apply(grad_output, group, op_name, ctx.comm_stream),
             None,
             None,
             None,
@@ -62,16 +59,16 @@ def forward(
         ctx,
         tensor,
         group: Optional[ProcessGroup],
-        async_all_reduce: bool,
         op_name: str = None,
         comm_stream: torch.cuda.Stream = None,
     ) -> Tuple[torch.Tensor, Optional["dist.Work"]]:
-        ctx.async_all_reduce = async_all_reduce
+        ctx.op_name = op_name
         ctx.comm_stream = comm_stream
 
         if group.size() == 1:
             return tensor
 
+        async_all_reduce = is_async_comm(op_name) if op_name is not None else False
         with torch.cuda.stream(comm_stream):
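+            # Launch the all-reduce on the dedicated comm stream; only ops whose name is flagged async return immediately.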
             if async_all_reduce:
                 handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True)
@@ -165,16 +162,12 @@ def backward(ctx, grad_output):
 # -----------------
 
 
-def differentiable_identity(
-    tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, op_name: str = None
-):
-    return DifferentiableIdentity.apply(tensor, group, async_all_reduce, op_name)
+def differentiable_identity(tensor, group: Optional[ProcessGroup] = None, op_name: str = None):
+    return DifferentiableIdentity.apply(tensor, group, op_name)
 
 
-def differentiable_all_reduce_sum(
-    tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, op_name: str = None
-):
-    return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce, op_name)
+def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None, op_name: str = None):
+    return DifferentiableAllReduceSum.apply(tensor, group, op_name)
 
 
 def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None):
diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index 915a2c31..57a21b58 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -25,6 +25,7 @@
     differentiable_identity,
     differentiable_reduce_scatter_sum,
 )
+from nanotron.parallel.tensor_parallel.domino import is_async_comm
 from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
 from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1
 
@@ -437,14 +438,13 @@ def column_linear(
     tp_mode: TensorParallelLinearMode,
     async_communication: bool,
     tp_recompute_allgather: bool = True,
-    async_all_reduce: bool = False,
     op_name: Optional[str] = None,
 ):
     if async_communication:
         return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather)
 
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
-        input = differentiable_identity(input, group=group, async_all_reduce=async_all_reduce, op_name=op_name)
+        input = differentiable_identity(input, group=group, op_name=op_name)
         return F.linear(input, weight, bias)
     if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
         return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(
@@ -592,7 +592,6 @@ def row_linear(
     tp_mode: TensorParallelLinearMode,
     # TODO(xrsrke): use less confusing names for these arguments
     async_communication: bool,
-    async_all_reduce: bool,
     op_name: Optional[str] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Future]]:
     if async_communication:
@@ -601,7 +600,9 @@ def row_linear(
     out = F.linear(input, weight, bias)
 
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
-        out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, op_name=op_name)
+        out = differentiable_all_reduce_sum(out, group=group, op_name=op_name)
+
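+        # Async vs. sync is decided by the op name; for async ops, pop the pending work handle so the caller can wait on it later.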
+        async_all_reduce = is_async_comm(op_name) if op_name is not None else False
         if async_all_reduce:
             work = AsyncCommBucket.pop(op_name)
         else:
diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py
index 2e6fd5a4..41386d38 100644
--- a/src/nanotron/parallel/tensor_parallel/nn.py
+++ b/src/nanotron/parallel/tensor_parallel/nn.py
@@ -31,7 +31,6 @@
     differentiable_identity,
     differentiable_reduce_scatter_sum,
 )
-from nanotron.parallel.tensor_parallel.domino import is_async_comm
 from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
 from nanotron.parallel.tensor_parallel.functional import (
     column_linear,
@@ -95,7 +94,6 @@ def forward(self, x: torch.Tensor, op_name: str = None) -> torch.Tensor:
             tp_mode=self.mode,
             async_communication=self.async_communication,
             tp_recompute_allgather=self.tp_recompute_allgather,
-            async_all_reduce=False if op_name is None else is_async_comm(op_name),
             op_name=op_name,
         )
 
@@ -170,7 +168,6 @@ def forward(self, x: torch.Tensor, op_name: str = None) -> torch.Tensor:
             group=self.pg,
             tp_mode=self.mode,
             async_communication=self.async_communication,
-            async_all_reduce=False if op_name is None else is_async_comm(op_name),
             op_name=op_name,
         )
 
@@ -296,7 +293,7 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
             out = out * (~input_mask[..., None])
 
         if self.mode is TensorParallelLinearMode.ALL_REDUCE:
-            out = differentiable_all_reduce_sum(out, group=self.pg, async_all_reduce=False)
+            out = differentiable_all_reduce_sum(out, group=self.pg)
         elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER:
             out = differentiable_reduce_scatter_sum(out, group=self.pg)
         else:
diff --git a/src/nanotron/sanity_checks.py b/src/nanotron/sanity_checks.py
index 9d1a1589..2a02d830 100644
--- a/src/nanotron/sanity_checks.py
+++ b/src/nanotron/sanity_checks.py
@@ -241,7 +241,10 @@ def before_optim_step_sanity_checks(
         unwrapped_model.before_optim_step_sanity_checks()
 
         # SANITY CHECK: for domino
-        assert AsyncCommBucket.is_all_completed(), "There are still some async ops haven't finishing"
+        assert AsyncCommBucket.is_all_completed(), "There are still unfinished async ops"
 
 
 def after_optim_step_sanity_checks(
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index d0eddf39..7cac06c6 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -496,6 +496,10 @@ def training_step(
             grad_accumulator=self.grad_accumulator,
         )
 
         if self.iteration_step < self.initial_iter_step + 5:
             log_memory(logger=logger)
 
@@ -565,6 +569,12 @@ def training_step(
         before_optim_step_sanity_checks(
             self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer
         )
+
+        assert AsyncCommBucket.is_all_completed(), "There are still unfinished async ops"
         AsyncCommBucket.clear_all()
 
         # Apply gradient