diff --git a/examples/domino/config_domino.yaml b/examples/domino/config_domino.yaml
new file mode 100644
index 00000000..9887ad26
--- /dev/null
+++ b/examples/domino/config_domino.yaml
@@ -0,0 +1,108 @@
+checkpoints:
+  checkpoint_interval: 10000
+  checkpoints_path: checkpoints
+  checkpoints_path_is_shared_file_system: false
+  resume_checkpoint_path: null
+  load_lr_scheduler: false
+  load_optimizer: false
+  save_final_state: true
+  save_initial_state: false
+data_stages:
+- data:
+    dataset:
+      dataset_folder:
+      - /fsx/loubna/datasets/llama_tokenized/other_sources/wiki
+      token_size_in_bytes: 4
+      tokenizer_name: meta-llama/Llama-3.2-1B
+      vocab_size: 128256
+    num_loading_workers: 8
+    seed: 42
+  name: Training Stage
+  start_training_step: 1
+general:
+  benchmark_csv_path: null
+  consumed_train_samples: null
+  ignore_sanity_checks: true
+  project: nanotron_domino
+  run: domino_config
+  seed: 6
+  step: null
+logging:
+  iteration_step_info_interval: 1
+  log_level: info
+  log_level_replica: info
+model:
+  ddp_bucket_cap_mb: 25
+  dtype: bfloat16
+  init_method:
+    std: 0.041666666666666664
+  make_vocab_size_divisible_by: 1
+  model_config:
+    bos_token_id: 128000
+    eos_token_id: 128001
+    hidden_act: silu
+    hidden_size: 4096
+    initializer_range: 0.02
+    intermediate_size: 14336
+    is_llama_config: true
+    max_position_embeddings: 4096
+    num_attention_heads: 32
+    num_hidden_layers: 20
+    num_key_value_heads: 8
+    pad_token_id: null
+    pretraining_tp: 2
+    rms_norm_eps: 1.0e-05
+    rope_interleaved: false
+    rope_scaling:
+      factor: 32.0
+      high_freq_factor: 4.0
+      low_freq_factor: 1.0
+      original_max_position_embeddings: 4096
+      rope_type: llama3
+    rope_theta: 500000.0
+    tie_word_embeddings: true
+    use_cache: true
+    vocab_size: 128256
+optimizer:
+  accumulate_grad_in_fp32: true
+  clip_grad: 1.0
+  learning_rate_scheduler:
+    learning_rate: 0.00005
+    lr_decay_starting_step: 50000
+    lr_decay_steps: 10000
+    lr_decay_style: linear
+    lr_warmup_steps: 1000
+    lr_warmup_style: linear
+    min_decay_lr: 0
+  optimizer_factory:
+    adam_beta1: 0.9
+    adam_beta2: 0.95
+    adam_eps: 1.0e-08
+    name: adamW
+    torch_adam_is_fused: true
+  weight_decay: 0.01
+  zero_stage: 1
+parallelism:
+  dp: 1
+  expert_parallel_size: 1
+  pp: 1
+  pp_engine: 1f1b
+  recompute_layer: false
+  tp: 8
+  tp_linear_async_communication: false
+  tp_mode: ALL_REDUCE
+  tp_recompute_allgather: false
+  domino:
+    num_input_batches: 2
+tokenizer:
+  tokenizer_max_length: null
+  tokenizer_name_or_path: meta-llama/Llama-3.2-1B
+  tokenizer_revision: null
+tokens:
+  batch_accumulation_per_replica: 2
+  limit_test_batches: 0
+  limit_val_batches: 0
+  micro_batch_size: 8
+  sequence_length: 4096
+  train_steps: 15000
+  val_check_interval: -1
diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py
index 48aa941e..72f667ab 100644
--- a/src/nanotron/config/parallelism_config.py
+++ b/src/nanotron/config/parallelism_config.py
@@ -11,6 +11,23 @@
 from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
 
 
+@dataclass
+class DominoArgs:
+    """
+    Domino: Eliminating Communication in LLM Training via Generic Tensor Slicing and Overlapping
+    https://arxiv.org/abs/2409.15241
+    """
+
+    # NOTE: if the number of input batches is 1,
+    # it's equivalent to non-domino mode
+    # so if you want to enable domino mode, set this to > 1
+    num_input_batches: int
+
+    def __post_init__(self):
+        assert self.num_input_batches > 1, "In order to enable domino mode, set num_input_batches > 1"
+        assert self.num_input_batches == 2, "Currently parallelism only supports 2 batches for Domino"
+
+
 @dataclass
 class ParallelismArgs:
     """Arguments related to TP/PP/DP
@@ -39,6 +56,7 @@ class ParallelismArgs:
 
     expert_parallel_size: int = 1
     context_parallel_size: int = 1
+    domino: Optional[DominoArgs] = None
 
     def __post_init__(self):
         # Conservative defaults
@@ -53,3 +71,19 @@ def __post_init__(self):
             self.pp_engine = cast_str_to_pipeline_engine(self.pp_engine)
         if isinstance(self.tp_mode, str):
             self.tp_mode = TensorParallelLinearMode[self.tp_mode.upper()]
+
+        if self.is_domino_enabled is True:
+            assert self.tp > 1, "Domino requires TP > 1"
+            # NOTE: For DoMiNo, since we overlap the communication,
+            # it doesn't matter whether it's all_reduce or reduce_scatter,
+            # so we have only supported and tested all_reduce so far,
+            # but in principle, it should work with reduce_scatter as well
+            assert (
+                self.tp_linear_async_communication is False
+            ), "Domino requires TP linear async communication to be False"
+            # TODO: support REDUCE_SCATTER mode for Domino
+            assert self.tp_mode == TensorParallelLinearMode.ALL_REDUCE, "Domino requires TP mode to be ALL_REDUCE"
+
+    @property
+    def is_domino_enabled(self) -> bool:
+        return True if self.domino else False
diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py
index 580bd99d..61878181 100644
--- a/src/nanotron/constants.py
+++ b/src/nanotron/constants.py
@@ -10,3 +10,7 @@
 
 CHECKPOINT_FILE_NAME = "checkpoint_metadata.json"
 MODEL_CONFIG_FILE_NAME = "model_config.json"
+
+
+### FOR COMMUNICATION ###
+CUDA_STREAM_COMM_NAME = "comm_stream_{}"
diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py
index 9cbf53f7..5234a0e7 100644
--- a/src/nanotron/helpers.py
+++ b/src/nanotron/helpers.py
@@ -518,6 +518,7 @@ def get_profiler(config: Config):
             record_shapes=config.profiler.record_shapes,
             profile_memory=config.profiler.profile_memory,
             with_stack=config.profiler.with_stack,
+            with_modules=True,
         )
     else:
         prof = contextlib.nullcontext()
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index db820644..4d7de8a6 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -34,9 +34,17 @@
 from nanotron.nn.activations import ACT2FN
 from nanotron.nn.layer_norm import TritonRMSNorm
 from nanotron.parallel import ParallelContext
+from nanotron.parallel.comm import CudaStreamManager, insert_backward_sync_to_tensor
 from nanotron.parallel.parameters import NanotronParameter
 from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
 from nanotron.parallel.pipeline_parallel.p2p import P2P
+from nanotron.parallel.tensor_parallel.domino import (
+    BWD_ATTN_OP_NAME,
+    BWD_MLP_OP_NAME,
+    FWD_ATTN_OP_NAME,
+    FWD_MLP_OP_NAME,
+    OpNameContext,
+)
 from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
 from nanotron.parallel.tensor_parallel.nn import (
     TensorParallelColumnLinear,
@@ -211,6 +219,7 @@ def __init__(
         config: LlamaConfig,
         parallel_config: Optional[ParallelismArgs],
         tp_pg: dist.ProcessGroup,
+        stream_manager: Optional[CudaStreamManager] = None,
     ):
         super().__init__()
 
@@ -233,6 +242,7 @@ def __init__(
             async_communication=tp_linear_async_communication,
             contiguous_chunks=gate_up_contiguous_chunks,
             tp_recompute_allgather=parallel_config.tp_recompute_allgather,
+            stream_manager=stream_manager,
         )
         self.down_proj = TensorParallelRowLinear(
             config.intermediate_size,
@@ -241,10 +251,11 @@ def __init__(
             mode=tp_mode,
             bias=False,
             async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
+            stream_manager=stream_manager,
         )
         self.split_silu_mul = GLUActivation(config.hidden_act)
 
-    def forward(self, hidden_states):  # [seq_length, batch_size, hidden_dim]
+    def forward(self, hidden_states: torch.Tensor):  # [seq_length, batch_size, hidden_dim]
         merged_states = self.gate_up_proj(hidden_states)
         hidden_states = self.down_proj(self.split_silu_mul(merged_states))
         return {"hidden_states": hidden_states}
@@ -339,6 +350,7 @@ def __init__(
         parallel_config: Optional[ParallelismArgs],
         tp_pg: dist.ProcessGroup,
         layer_idx: int,
+        stream_manager: Optional[CudaStreamManager] = None,
     ):
         from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
 
@@ -394,6 +406,7 @@ def __init__(
             async_communication=tp_linear_async_communication,
             contiguous_chunks=qkv_contiguous_chunks,
             tp_recompute_allgather=parallel_config.tp_recompute_allgather,
+            stream_manager=stream_manager,
         )
         # TODO(kunhao): We want to have only one version per device and not one version per layer.
         if config.rope_interleaved:
@@ -422,6 +435,7 @@ def __init__(
             mode=tp_mode,
             bias=False,
             async_communication=tp_linear_async_communication,
+            stream_manager=stream_manager,
         )
 
         self.attention = CoreAttention(
@@ -717,46 +731,33 @@ def _forward_training(self, query_states, key_states, value_states, sequence_mas
         return {"hidden_states": output, "sequence_mask": sequence_mask}
 
 
-class LlamaDecoderLayer(nn.Module):
+class _BaseLlamaDecoderLayer(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
         parallel_config: Optional[ParallelismArgs],
         tp_pg: dist.ProcessGroup,
         layer_idx: int,
+        stream_manager: Optional[CudaStreamManager] = None,
     ):
         super().__init__()
         self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
         self.attn = CausalSelfAttention(
             config=config,
             parallel_config=parallel_config,
             tp_pg=tp_pg,
             layer_idx=layer_idx,
+            stream_manager=stream_manager,
         )
 
         self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
+        self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg, stream_manager=stream_manager)
 
         self.recompute_layer = parallel_config.recompute_layer
-
-    def _core_forward(
-        self,
-        hidden_states: Union[torch.Tensor, TensorPointer],
-        sequence_mask: Union[torch.Tensor, TensorPointer],
-    ) -> List[Union[torch.Tensor, TensorPointer]]:
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
-
-        output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
-        hidden_states = output["hidden_states"]
-        hidden_states = hidden_states + residual
-
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
-        hidden_states = hidden_states + residual
-
-        return hidden_states, output["sequence_mask"]
+        self.parallel_config = parallel_config
+        self.layer_idx = layer_idx
+        self.stream_manager = stream_manager
 
     def _checkpointed_forward(
         self,
@@ -782,6 +783,142 @@ def forward(
         }
 
 
+class DominoLlamaDecoderLayer(_BaseLlamaDecoderLayer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.stream_manager is not None, "DominoLlamaDecoderLayer requires a stream_manager"
+
+    def _core_forward(
+        self,
+        hidden_states: Union[torch.Tensor, TensorPointer],
+        sequence_mask: Union[torch.Tensor, TensorPointer],
+    ) -> List[Union[torch.Tensor, TensorPointer]]:
+        num_input_batches = self.parallel_config.domino.num_input_batches
+
+        orig_sequence_mask = sequence_mask
+        comm_stream = self.stream_manager.get_default_comm_stream()
+        comm_bucket = self.stream_manager.comm_bucket
+
+        hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1)
+        sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0)
+
+        hidden_states0, hidden_states1 = hidden_states
+        sequence_mask0, sequence_mask1 = sequence_mask
+
+        residual0 = hidden_states0
+        residual1 = hidden_states1
+
+        # TODO: overlap the 'layernorm > attn' of the second batch
+        # with the comm of the first batch in both forward and backward
+        hidden_states0 = self.input_layernorm(hidden_states0)
+        hidden_states1 = self.input_layernorm(hidden_states1)
+        hidden_states0 = insert_backward_sync_to_tensor(
+            hidden_states0,
+            BWD_ATTN_OP_NAME.format(self.layer_idx, 1),
+            self.stream_manager,
+        )
+        hidden_states1 = insert_backward_sync_to_tensor(
+            hidden_states1,
+            BWD_MLP_OP_NAME.format(self.layer_idx, 0),
+            self.stream_manager,
+        )
+
+        # TODO: maybe try to bucket all the communication as in DDP,
+        # and do it all at once
+        with OpNameContext(
+            fwd_op_name=FWD_ATTN_OP_NAME.format(self.layer_idx, 0),
+            bwd_op_name=BWD_ATTN_OP_NAME.format(self.layer_idx, 0),
+        ):
+            attn_output0 = self.attn(
+                hidden_states=hidden_states0,
+                sequence_mask=sequence_mask0,
+            )
+
+        with OpNameContext(
+            fwd_op_name=FWD_ATTN_OP_NAME.format(self.layer_idx, 1),
+            bwd_op_name=BWD_ATTN_OP_NAME.format(self.layer_idx, 1),
+        ):
+            attn_output1 = self.attn(
+                hidden_states=hidden_states1,
+                sequence_mask=sequence_mask1,
+            )
+
+        # TODO(xrsrke): double check if we need this explicit synchronization
+        # otherwise, remove it
+        comm_stream.wait_stream(torch.cuda.default_stream())
+        with torch.cuda.stream(comm_stream):
+            comm_bucket.wait(FWD_ATTN_OP_NAME.format(self.layer_idx, 0))
+
+        torch.cuda.default_stream().wait_stream(comm_stream)
+
+        hidden_states0 = attn_output0["hidden_states"] + residual0
+        residual0 = hidden_states0
+        hidden_states0 = self.post_attention_layernorm(hidden_states0)
+        hidden_states0 = insert_backward_sync_to_tensor(
+            hidden_states0,
+            BWD_MLP_OP_NAME.format(self.layer_idx, 1),
+            self.stream_manager,
+        )
+
+        with OpNameContext(
+            fwd_op_name=FWD_MLP_OP_NAME.format(self.layer_idx, 0),
+            bwd_op_name=BWD_MLP_OP_NAME.format(self.layer_idx, 0),
+        ):
+            mlp_output0 = self.mlp(hidden_states=hidden_states0)
+
+        comm_stream.wait_stream(torch.cuda.default_stream())
+        with torch.cuda.stream(comm_stream):
+            comm_bucket.wait(FWD_ATTN_OP_NAME.format(self.layer_idx, 1))
+
+        torch.cuda.default_stream().wait_stream(comm_stream)
+
+        hidden_states1 = attn_output1["hidden_states"] + residual1
+        residual1 = hidden_states1
+        hidden_states1 = self.post_attention_layernorm(hidden_states1)
+
+        with OpNameContext(
+            fwd_op_name=FWD_MLP_OP_NAME.format(self.layer_idx, 1),
+            bwd_op_name=BWD_MLP_OP_NAME.format(self.layer_idx, 1),
+        ):
+            mlp_output1 = self.mlp(hidden_states=hidden_states1)
+
+        comm_stream.wait_stream(torch.cuda.default_stream())
+        with torch.cuda.stream(comm_stream):
+            comm_bucket.wait(FWD_MLP_OP_NAME.format(self.layer_idx, 0))
+            comm_bucket.wait(FWD_MLP_OP_NAME.format(self.layer_idx, 1))
+
+        torch.cuda.default_stream().wait_stream(comm_stream)
+
+        hidden_states0 = mlp_output0["hidden_states"] + residual0
+        hidden_states1 = mlp_output1["hidden_states"] + residual1
+
+        # TODO: make sure there is no memory overhead,
+        # and try a fixed memory buffer as in section 4.2 of the paper
+        hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1)
+        return hidden_states, orig_sequence_mask
+
+
+class LlamaDecoderLayer(_BaseLlamaDecoderLayer):
+    def _core_forward(
+        self,
+        hidden_states: Union[torch.Tensor, TensorPointer],
+        sequence_mask: Union[torch.Tensor, TensorPointer],
+    ) -> List[Union[torch.Tensor, TensorPointer]]:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
+        hidden_states = output["hidden_states"]
+        hidden_states = hidden_states + residual
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
+        hidden_states = hidden_states + residual
+
+        return hidden_states, output["sequence_mask"]
+
+
 class Embedding(nn.Module, AttachableStore):
     def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]):
         super().__init__()
@@ -820,6 +957,7 @@ def __init__(
         config: LlamaConfig,
         parallel_context: ParallelContext,
         parallel_config: Optional[ParallelismArgs],
+        stream_manager: Optional[CudaStreamManager] = None,
     ):
         super().__init__()
 
@@ -856,12 +994,13 @@ def __init__(
             [
                 PipelineBlock(
                     p2p=self.p2p,
-                    module_builder=LlamaDecoderLayer,
+                    module_builder=DominoLlamaDecoderLayer if parallel_config.is_domino_enabled else LlamaDecoderLayer,
                     module_kwargs={
                         "config": config,
                         "parallel_config": parallel_config,
                         "tp_pg": parallel_context.tp_pg,
                         "layer_idx": layer_idx,
+                        "stream_manager": stream_manager,
                     },
                     module_input_keys={"hidden_states", "sequence_mask"},
                     module_output_keys={"hidden_states", "sequence_mask"},
@@ -924,6 +1063,7 @@ def forward_with_hidden_states(
             "hidden_states": output["input_embeds"],
             "sequence_mask": input_mask,
         }
+
         for encoder_block in self.decoder:
             hidden_encoder_states = encoder_block(**hidden_encoder_states)
 
@@ -1030,10 +1170,15 @@ def __init__(
         parallel_context: ParallelContext,
         parallel_config: Optional[ParallelismArgs],
         random_states: Optional[RandomStates] = None,
+        stream_manager: Optional[CudaStreamManager] = None,
     ):
         super().__init__()
-        self.model = LlamaModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config)
-
+        self.model = LlamaModel(
+            config=config,
+            parallel_context=parallel_context,
+            parallel_config=parallel_config,
+            stream_manager=stream_manager,
+        )
         # Choose the appropriate loss class based on config
         loss_kwargs = {
             "tp_pg": parallel_context.tp_pg,
@@ -1056,6 +1201,7 @@ def __init__(
         self.parallel_context = parallel_context
         self.config = config
         self.parallel_config = parallel_config
+        self.stream_manager = stream_manager
 
     def forward(
         self,
diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py
new file mode 100644
index 00000000..d112076a
--- /dev/null
+++ b/src/nanotron/parallel/comm.py
@@ -0,0 +1,132 @@
+from contextlib import contextmanager
+from typing import Dict
+
+import torch
+
+from nanotron.constants import CUDA_STREAM_COMM_NAME
+from nanotron.parallel.tensor_parallel.domino import is_domino_async_comm
+
+
+class AsyncCommBucket:
+    """
+    Store asynchronous communication operations.
+    """
+
+    def __init__(self):
+        self._async_op: Dict[int, "dist.Work"] = {}
+        self._copy_async_op: Dict[int, "dist.Work"] = {}
+
+    def add(self, op_name: int, work: "dist.Work"):
+        assert op_name not in self._async_op, f"Operation with name: {op_name} already exists"
+        assert work is not None
+        self._async_op[op_name] = work
+        self._copy_async_op[op_name] = work
+
+    def get(self, op_name: str) -> "dist.Work":
+        if op_name not in self._async_op:
+            raise KeyError(f"Operation with name: {op_name} doesn't exist")
+
+        return self._async_op.get(op_name)
+
+    def pop(self, op_name: str) -> "dist.Work":
+        if op_name not in self._async_op:
+            raise KeyError(f"Operation with name: {op_name} doesn't exist")
+
+        return self._async_op.pop(op_name)
+
+    def wait(self, op_name: str):
+        """Wait and remove the operation from the bucket"""
+        work = self.pop(op_name)
+        work.wait()
+
+    def is_all_completed(self) -> bool:
+        if not len(self._async_op) == 0:
+            return False
+
+        not_finished = []
+        for k, v in self._copy_async_op.items():
+            if v.is_completed() is not True:
+                not_finished.append((k, v))
+        return len(not_finished) == 0
+
+    def clear_all(self):
+        self._async_op.clear()
+        self._copy_async_op.clear()
+
+
+class CudaStreamManager:
+    def __init__(self):
+        self._streams: Dict[str, "torch.cuda.Stream"] = {}
+        self.comm_bucket = AsyncCommBucket()
+
+    def init_default_comm_stream(self):
+        """
+        Initialize the default communication stream for the current cuda device.
+        """
+        self.create(CUDA_STREAM_COMM_NAME.format(torch.cuda.current_device()), torch.cuda.current_device())
+
+    def create(self, name: str, device: torch.device):
+        assert name not in self._streams
+        self._streams[name] = torch.cuda.Stream(device=device)
+
+    def get(self, name: str):
+        if name not in self._streams:
+            self.create(name)
+        return self._streams.get(name)
+
+    def get_default_comm_stream(self) -> torch.cuda.Stream:
+        """
+        Return the default communication stream for the current cuda device.
+        """
+        return self.get(CUDA_STREAM_COMM_NAME.format(torch.cuda.current_device()))
+
+    @contextmanager
+    def run_on_stream(self, name: str):
+        stream = self.get(name)
+        with torch.cuda.stream(stream):
+            yield stream
+
+
+class WaitComm(torch.autograd.Function):
+    """
+    Enforce a tensor to wait for the communication operation to finish
+    in torch's autograd graph.
+    """
+
+    @staticmethod
+    def forward(ctx, input: torch.Tensor, op_name: str, comm_stream: torch.cuda.Stream, comm_bucket: AsyncCommBucket):
+        assert isinstance(comm_stream, torch.cuda.Stream)
+        ctx.op_name = op_name
+        ctx.comm_stream = comm_stream
+        ctx.comm_bucket = comm_bucket
+        return input
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        """
+        NOTE: because the communication operation is already being executed,
+        the communication stream doesn't have to wait for the compute stream here,
+        but the compute stream waits for the communication stream
+        before proceeding.
+        """
+        if is_domino_async_comm(ctx.op_name):
+            handle = ctx.comm_bucket.pop(ctx.op_name)
+            handle.wait()
+
+            ctx.comm_stream.synchronize()
+            torch.cuda.default_stream().wait_stream(ctx.comm_stream)
+
+        return grad_output, None, None, None
+
+
+def insert_backward_sync_to_tensor(
+    tensor: torch.Tensor, op_name: str, stream_manager: CudaStreamManager
+) -> torch.Tensor:
+    """
+    Insert a wait on the communication operation with the given op_name into
+    the autograd graph of a tensor.
+    """
+
+    assert isinstance(stream_manager, CudaStreamManager)
+    comm_stream = stream_manager.get(CUDA_STREAM_COMM_NAME.format(torch.cuda.current_device()))
+    return WaitComm.apply(tensor, op_name, comm_stream, stream_manager.comm_bucket)
diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index bd41347a..a5e65258 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -12,43 +12,84 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from contextlib import nullcontext
+from typing import Optional, Tuple
 
 import torch
 from torch import distributed as torch_dist
 
 from nanotron import distributed as dist
 from nanotron.distributed import ProcessGroup
+from nanotron.parallel.comm import CudaStreamManager, is_domino_async_comm
+from nanotron.parallel.tensor_parallel.domino import get_current_bwd_op_name
 
 
 class DifferentiableIdentity(torch.autograd.Function):
     """All-reduce gradients in a differentiable fashion"""
 
     @staticmethod
-    def forward(ctx, tensor, group: Optional[ProcessGroup]):
+    def forward(
+        ctx,
+        tensor: torch.Tensor,
+        group: Optional[ProcessGroup],
+        op_name: Optional[str] = None,
+        stream_manager: Optional[CudaStreamManager] = None,
+    ):
+        ctx.bwd_op_name = get_current_bwd_op_name()
         ctx.group = group
+        ctx.stream_manager = stream_manager
         return tensor
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output: torch.Tensor):
         group = ctx.group
-        return DifferentiableAllReduceSum.apply(grad_output, group), None
+
+        return (
+            DifferentiableAllReduceSum.apply(grad_output, group, ctx.bwd_op_name, ctx.stream_manager),
+            None,
+            None,
+            None,
+            None,
+        )
 
 
 class DifferentiableAllReduceSum(torch.autograd.Function):
     """All-reduce in a differentiable fashion"""
 
     @staticmethod
-    def forward(ctx, tensor, group: Optional[ProcessGroup]):
+    def forward(
+        ctx,
+        tensor: torch.Tensor,
+        group: Optional[ProcessGroup],
+        op_name: Optional[int] = None,
+        stream_manager: Optional[CudaStreamManager] = None,
+    ) -> Tuple[torch.Tensor, Optional["dist.Work"]]:
+        async_all_reduce = is_domino_async_comm(op_name) if op_name is not None else False
+        ctx.async_all_reduce = async_all_reduce
+
         if group.size() == 1:
             return tensor
 
-        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
+        if stream_manager is not None:
+            comm_stream = stream_manager.get_default_comm_stream()
+            comm_stream.wait_stream(torch.cuda.default_stream())
+            comm_context = torch.cuda.stream(comm_stream)
+        else:
+            comm_context = nullcontext()
+
+        with comm_context:
+            if async_all_reduce is True:
+                assert comm_stream is not None
+                handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True)
+                stream_manager.comm_bucket.add(op_name, handle)
+            else:
+                dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
+
         return tensor
 
     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output, None
+        return grad_output, None, None, None, None
 
 
 class DifferentiableAllGather(torch.autograd.Function):
@@ -130,12 +171,22 @@ def backward(ctx, grad_output):
 # -----------------
 
 
-def differentiable_identity(tensor, group: Optional[ProcessGroup] = None):
-    return DifferentiableIdentity.apply(tensor, group)
+def differentiable_identity(
+    tensor,
+    group: Optional[ProcessGroup] = None,
+    op_name: Optional[str] = None,
+    stream_manager: Optional[CudaStreamManager] = None,
+):
+    return DifferentiableIdentity.apply(tensor, group, op_name, stream_manager)
 
 
-def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None):
-    return DifferentiableAllReduceSum.apply(tensor, group)
+def differentiable_all_reduce_sum(
+    tensor,
+    group: Optional[ProcessGroup] = None,
+    op_name: Optional[str] = None,
+    stream_manager: Optional[CudaStreamManager] = None,
+):
+    return DifferentiableAllReduceSum.apply(tensor, group, op_name, stream_manager)
 
 
 def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None):
diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py
new file mode 100644
index 00000000..3d0b4415
--- /dev/null
+++ b/src/nanotron/parallel/tensor_parallel/domino.py
@@ -0,0 +1,73 @@
+"""
+Implementation of communication overlapping
+in the paper "Domino: Eliminating Communication in LLM Training via
+Generic Tensor Slicing and Overlapping"
+https://arxiv.org/abs/2409.15241
+"""
+
+import re
+import threading
+from typing import Optional
+
+FWD_MLP_OP_NAME = "fwd.layer_mlp_{}_batch_{}"
+FWD_ATTN_OP_NAME = "fwd.layer_attn_{}_batch_{}"
+BWD_ATTN_OP_NAME = "bwd.layer_attn_{}_batch_{}"
+BWD_MLP_OP_NAME = "bwd.layer_mlp_{}_batch_{}"
+
+_operation_context = threading.local()
+
+
+def is_domino_async_comm(x: str) -> bool:
+    """
+    Determine whether a module (e.g., mlp, attention)
+    runs all-reduce asynchronously in tensor parallelism
+    based on its module name.
+
+    Currently supports intra-layer communication overlapping
+    as described in Domino's input-splitting approach.
+
+    How do we determine it?
+    + In the forward pass: We run all the forward pass's communication asynchronously
+    diagram: https://imgur.com/a/g5Ou2iZ
+
+    + In the backward pass: We run all backward pass's communication asynchronously
+    except for the first batch's attention module.
+    https://imgur.com/a/MrZb57a
+    """
+    NON_ASYNC_HANDLE_IDX = [
+        "bwd.layer_attn_{}_batch_0",
+    ]
+
+    patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX]  # Replace {} with regex for numbers
+    regex = re.compile("^(" + "|".join(patterns) + ")$")  # Combine patterns into a single regex
+    not_async = bool(regex.match(x))
+    return not not_async
+
+
+class OpNameContext:
+    """Track both forward and backward operation names"""
+
+    def __init__(self, fwd_op_name: str, bwd_op_name: str):
+        self.fwd_op_name = fwd_op_name
+        self.bwd_op_name = bwd_op_name
+        self.prev_fwd = None
+        self.prev_bwd = None
+
+    def __enter__(self):
+        self.prev_fwd = getattr(_operation_context, "current_fwd_op_name", None)
+        self.prev_bwd = getattr(_operation_context, "current_bwd_op_name", None)
+        _operation_context.current_fwd_op_name = self.fwd_op_name
+        _operation_context.current_bwd_op_name = self.bwd_op_name
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        _operation_context.current_fwd_op_name = self.prev_fwd
+        _operation_context.current_bwd_op_name = self.prev_bwd
+
+
+def get_current_fwd_op_name() -> Optional[str]:
+    return getattr(_operation_context, "current_fwd_op_name", None)
+
+
+def get_current_bwd_op_name() -> Optional[str]:
+    return getattr(_operation_context, "current_bwd_op_name", None)
diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index 722bc52f..9ed815a6 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -19,11 +19,13 @@
 from torch.nn import functional as F
 
 import nanotron.distributed as dist
+from nanotron.parallel.comm import CudaStreamManager
 from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
     differentiable_all_reduce_sum,
     differentiable_identity,
     differentiable_reduce_scatter_sum,
 )
+from nanotron.parallel.tensor_parallel.domino import get_current_fwd_op_name
 from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
 from nanotron.parallel.utils import MemoryBuffer
 
@@ -40,7 +42,7 @@ def forward(
         logits_max = torch.max(sharded_logits, dim=-1)[0]
         dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)
         # Subtract the maximum value.
-        sharded_logits = sharded_logits - logits_max.unsqueeze(dim=-1)
+        sharded_logits.sub_(logits_max.unsqueeze(dim=-1))
 
         # Get the shard's indices
         sharded_hidden_size = sharded_logits.shape[-1]
@@ -552,12 +554,14 @@ def column_linear(
     tp_mode: TensorParallelLinearMode,
     async_communication: bool,
     tp_recompute_allgather: bool = True,
+    stream_manager: Optional[CudaStreamManager] = None,
 ):
     if async_communication:
         return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather)
 
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
-        input = differentiable_identity(input, group=group)
+        op_name = get_current_fwd_op_name()
+        input = differentiable_identity(input, group=group, op_name=op_name, stream_manager=stream_manager)
         return F.linear(input, weight, bias)
     if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
         return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(
@@ -702,6 +706,7 @@ def row_linear(
     group: dist.ProcessGroup,
     tp_mode: TensorParallelLinearMode,
     async_communication: bool,
+    stream_manager: Optional[CudaStreamManager] = None,
 ):
     if async_communication:
         return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)
@@ -709,7 +714,8 @@ def row_linear(
     out = F.linear(input, weight, bias)
 
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
-        out = differentiable_all_reduce_sum(out, group=group)
+        op_name = get_current_fwd_op_name()
+        out = differentiable_all_reduce_sum(out, group=group, op_name=op_name, stream_manager=stream_manager)
     elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
         out = differentiable_reduce_scatter_sum(out, group=group)
     else:
diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py
index b738c7b3..28ce424a 100644
--- a/src/nanotron/parallel/tensor_parallel/nn.py
+++ b/src/nanotron/parallel/tensor_parallel/nn.py
@@ -21,6 +21,7 @@
 from nanotron import distributed as dist
 from nanotron.distributed import get_global_rank
 from nanotron.logging import get_logger
+from nanotron.parallel.comm import CudaStreamManager
 from nanotron.parallel.parameters import NanotronParameter
 from nanotron.parallel.sharded_parameters import (
     SplitConfig,
@@ -56,6 +57,7 @@ def __init__(
         async_communication: bool = False,
         contiguous_chunks: Optional[Tuple[int, ...]] = None,
         tp_recompute_allgather: bool = True,
+        stream_manager: Optional[CudaStreamManager] = None,
     ):
         self.pg = pg
         self.world_size = pg.size()
@@ -76,6 +78,7 @@ def __init__(
 
         self.mode = mode
         self.async_communication = async_communication
+        self.stream_manager = stream_manager
 
         if self.world_size > 1:
             assert (
@@ -94,7 +97,10 @@ def __init__(
             split_config=split_config,
         )
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
         return column_linear(
             input=x,
             weight=self.weight,
@@ -103,6 +109,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             tp_mode=self.mode,
             async_communication=self.async_communication,
             tp_recompute_allgather=self.tp_recompute_allgather,
+            stream_manager=self.stream_manager,
         )
 
     def extra_repr(self) -> str:
@@ -121,6 +128,7 @@ def __init__(
         dtype=None,
         async_communication: bool = False,
         contiguous_chunks: Optional[Tuple[int, ...]] = None,
+        stream_manager: Optional[CudaStreamManager] = None,
     ):
         self.pg = pg
         self.world_size = pg.size()
@@ -129,6 +137,7 @@ def __init__(
 
         self.in_features = in_features // self.world_size
         self.out_features = out_features
+        self.stream_manager = stream_manager
 
         # No need to shard the bias term, only rank 0 would have it
         bias = dist.get_rank(self.pg) == 0 and bias
@@ -175,6 +184,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             group=self.pg,
             tp_mode=self.mode,
             async_communication=self.async_communication,
+            stream_manager=self.stream_manager,
         )
 
     def extra_repr(self) -> str:
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 4cb272c4..08de40ed 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -64,6 +64,7 @@
 from nanotron.models.starcoder2 import Starcoder2ForTraining
 from nanotron.optim.clip_grads import clip_grad_norm
 from nanotron.parallel import ParallelContext
+from nanotron.parallel.comm import CudaStreamManager
 from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp
 from nanotron.parallel.parameters import NanotronParameter, sanity_check
 from nanotron.parallel.pipeline_parallel.engine import (
@@ -609,6 +610,7 @@ def training_step(
 
         self.post_train_step()
 
+        self.stream_manager.comm_bucket.clear_all()
         return outputs, loss_avg, z_loss_avg
 
     def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]:
@@ -742,12 +744,16 @@ def _init_model_instance(self) -> NanotronModel:
             model_config_cls in CONFIG_TO_MODEL_CLASS
         ), f"Unsupported model config {model_config_cls}. Only {CONFIG_TO_MODEL_CLASS.keys()} are supported"
 
+        self.stream_manager = CudaStreamManager()
+        self.stream_manager.init_default_comm_stream()
+
         model = self._init_model(
             model_builder=lambda: CONFIG_TO_MODEL_CLASS[model_config_cls](
                 config=self.model_config,
                 parallel_context=self.parallel_context,
                 parallel_config=self.config.parallelism,
                 random_states=self.random_states,
+                stream_manager=self.stream_manager,
             ),
         )
         return model
diff --git a/tests/helpers/llama_helper.py b/tests/helpers/llama_helper.py
index 7334f857..13c66658 100644
--- a/tests/helpers/llama_helper.py
+++ b/tests/helpers/llama_helper.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 from nanotron.config import (
     AdamWOptimizerArgs,
@@ -20,6 +22,7 @@
 from nanotron.config.config import PretrainDatasetsArgs
 from nanotron.models import build_model
 from nanotron.models.llama import LlamaForTraining
+from nanotron.parallel.comm import CudaStreamManager
 from nanotron.parallel.context import ParallelContext
 from nanotron.trainer import mark_tied_parameters
 
@@ -47,15 +50,21 @@
 )
 
 
-def get_llama_training_config(model_config: ModelArgs):
-    return Config(
-        model=model_config,
-        general=GeneralArgs(project="unittest", run="sanity_llama", seed=42),
-        checkpoints=CheckpointsArgs(
-            checkpoints_path="./checkpoints",
-            checkpoint_interval=10,
-        ),
-        parallelism=ParallelismArgs(
+def get_parallel_config(parallel_context: ParallelContext):
+    return ParallelismArgs(
+        dp=parallel_context.data_parallel_size,
+        pp=parallel_context.pipeline_parallel_size,
+        tp=parallel_context.tensor_parallel_size,
+        expert_parallel_size=parallel_context.expert_parallel_size,
+        pp_engine=AllForwardAllBackwardPipelineEngine(),
+        tp_mode=TensorParallelLinearMode.ALL_REDUCE,
+        tp_linear_async_communication=False,
+    )
+
+
+def get_llama_training_config(model_config: ModelArgs, parallel_context: Optional[ParallelContext] = None):
+    if parallel_context is None:
+        parallel_config = ParallelismArgs(
             dp=1,
             pp=1,
             tp=2,
@@ -63,7 +72,18 @@ def get_llama_training_config(model_config: ModelArgs):
             pp_engine="1f1b",
             tp_mode="ALL_REDUCE",
             tp_linear_async_communication=False,
+        )
+    else:
+        parallel_config = get_parallel_config(parallel_context)
+
+    return Config(
+        model=model_config,
+        general=GeneralArgs(project="unittest", run="sanity_llama", seed=42),
+        checkpoints=CheckpointsArgs(
+            checkpoints_path="./checkpoints",
+            checkpoint_interval=10,
         ),
+        parallelism=parallel_config,
         tokenizer=TokenizerArgs("gpt2"),
         optimizer=OptimizerArgs(
             optimizer_factory=AdamWOptimizerArgs(
@@ -106,7 +126,11 @@ def get_llama_training_config(model_config: ModelArgs):
 
 
 def create_llama_from_config(
-    model_config: LlamaConfig, device: torch.device, parallel_context: ParallelContext
+    model_config: LlamaConfig,
+    device: torch.device,
+    parallel_context: ParallelContext,
+    parallel_config: Optional[ParallelismArgs] = None,
+    stream_manager: Optional[CudaStreamManager] = None,
 ) -> LlamaForTraining:
 
     """
@@ -117,20 +141,23 @@
     the model created will have random weights.
     """
 
-    parallel_config = ParallelismArgs(
-        dp=parallel_context.data_parallel_size,
-        pp=parallel_context.pipeline_parallel_size,
-        tp=parallel_context.tensor_parallel_size,
-        pp_engine=AllForwardAllBackwardPipelineEngine(),
-        tp_mode=TensorParallelLinearMode.ALL_REDUCE,
-        tp_linear_async_communication=False,
-    )
+    if parallel_config is None:
+        parallel_config = ParallelismArgs(
+            dp=parallel_context.data_parallel_size,
+            pp=parallel_context.pipeline_parallel_size,
+            tp=parallel_context.tensor_parallel_size,
+            pp_engine=AllForwardAllBackwardPipelineEngine(),
+            tp_mode=TensorParallelLinearMode.ALL_REDUCE,
+            tp_linear_async_communication=False,
+        )
+
     model = build_model(
         model_builder=lambda: LlamaForTraining(
             config=model_config,
             parallel_context=parallel_context,
             parallel_config=parallel_config,
             random_states=None,
+            stream_manager=stream_manager,
         ),
         parallel_context=parallel_context,
         dtype=torch.bfloat16,
diff --git a/tests/test_base_model.py b/tests/test_base_model.py
index b3474bc2..81cdd83e 100644
--- a/tests/test_base_model.py
+++ b/tests/test_base_model.py
@@ -2,7 +2,7 @@
 import torch
 import torch.distributed as dist
 from helpers.utils import init_distributed, rerun_if_address_is_in_use
-from nanotron.config import Config, ModelArgs, RandomInit
+from nanotron.config import ModelArgs, RandomInit
 from nanotron.parallel import ParallelContext
 from nanotron.parallel.pipeline_parallel.block import PipelineBlock
 from torch import nn
@@ -11,21 +11,21 @@
 
 
 @pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 2, 2)])
-@pytest.mark.skip
 @rerun_if_address_is_in_use()
 def test_get_named_modules_in_pp_rank(tp: int, dp: int, pp: int):
     model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG)
-    config = get_llama_training_config(model_args)
 
-    init_distributed(tp=tp, dp=dp, pp=pp)(_test_get_named_modules_in_pp_rank)(config=config)
+    init_distributed(tp=tp, dp=dp, pp=pp)(_test_get_named_modules_in_pp_rank)(model_args=model_args)
 
 
 def _test_get_named_modules_in_pp_rank(
     parallel_context: ParallelContext,
-    config: Config,
+    model_args: ModelArgs,
 ):
+    config = get_llama_training_config(model_args, parallel_context)
     model = create_llama_from_config(
         model_config=config.model.model_config,
+        parallel_config=config.parallelism,
         device=torch.device("cuda"),
         parallel_context=parallel_context,
     )
@@ -44,3 +44,37 @@ def _test_get_named_modules_in_pp_rank(
         # not PipelineBlock
         assert isinstance(module, nn.Module)
         assert name not in modules_that_not_in_current_pp_rank
+
+
+@pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 2, 1)])
+@rerun_if_address_is_in_use()
+def test_llama_model(tp: int, dp: int, pp: int):
+    BATCH_SIZE, SEQ_LEN = 10, 128
+    model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG)
+
+    init_distributed(tp=tp, dp=dp, pp=pp)(_test_llama_model)(
+        model_args=model_args, batch_size=BATCH_SIZE, seq_len=SEQ_LEN
+    )
+
+
+def _test_llama_model(
+    parallel_context: ParallelContext,
+    model_args: ModelArgs,
+    batch_size: int,
+    seq_len: int,
+):
+    config = get_llama_training_config(model_args, parallel_context)
+    llama_model = create_llama_from_config(
+        model_config=config.model.model_config,
+        parallel_config=config.parallelism,
+        device=torch.device("cuda"),
+        parallel_context=parallel_context,
+    )
+    llama_model.init_model_randomly(config=config)
+
+    input_ids = torch.randint(0, config.model.model_config.vocab_size, size=(batch_size, seq_len), device="cuda")
+    input_mask = torch.ones_like(input_ids)
+    outputs = llama_model(input_ids, input_mask, input_mask, input_mask)
+
+    assert list(outputs.keys()) == ["loss"]
+    assert isinstance(outputs["loss"], torch.Tensor)
diff --git a/tests/test_comm.py b/tests/test_comm.py
new file mode 100644
index 00000000..39514a38
--- /dev/null
+++ b/tests/test_comm.py
@@ -0,0 +1,127 @@
+import pytest
+import torch
+import torch.distributed as dist
+from helpers.utils import (
+    init_distributed,
+    rerun_if_address_is_in_use,
+)
+from nanotron.parallel import ParallelContext
+from nanotron.parallel.comm import AsyncCommBucket, CudaStreamManager, insert_backward_sync_to_tensor
+
+
+class MockWork:
+    def __init__(self):
+        self.completed = False
+        self.wait_called = False
+
+    def wait(self):
+        self.wait_called = True
+        self.completed = True
+
+    def is_completed(self):
+        return self.completed
+
+
+def test_cuda_stream_manager():
+    manager = CudaStreamManager()
+    manager.create("test", torch.device("cuda"))
+
+    stream = manager.get("test")
+    assert isinstance(stream, torch.cuda.Stream)
+
+
+@rerun_if_address_is_in_use()
+def test_add_async_op_to_bucket():
+    init_distributed(tp=2, dp=1, pp=1)(_test_add_async_op_to_bucket)()
+
+
+def _test_add_async_op_to_bucket(parallel_context: ParallelContext):
+    OP_NAME = "test"
+    tensor = torch.randn(1, device="cuda")
+    work = dist.all_reduce(tensor, async_op=True)
+
+    comm_bucket = AsyncCommBucket()
+    comm_bucket.add(OP_NAME, work)
+
+    assert comm_bucket.get(OP_NAME) is work
+
+
+@rerun_if_address_is_in_use()
+def test_wait_async_op_to_bucket():
+    init_distributed(tp=2, dp=1, pp=1)(_test_wait_async_op_to_bucket)()
+
+
+def _test_wait_async_op_to_bucket(parallel_context: ParallelContext):
+    OP_NAME = "test"
+    work = MockWork()
+    comm_bucket = AsyncCommBucket()
+
+    comm_bucket.add(OP_NAME, work)
+    assert work.is_completed() is False
+
+    comm_bucket.wait(OP_NAME)
+    assert work.is_completed()
+    with pytest.raises(KeyError):
+        comm_bucket.get(OP_NAME)
+
+
+@rerun_if_address_is_in_use()
+def test_is_all_completed_in_async_bucket():
+    init_distributed(tp=2, dp=1, pp=1)(_test_is_all_completed_in_async_bucket)()
+
+
+def _test_is_all_completed_in_async_bucket(parallel_context: ParallelContext):
+    OP_NAME = "test"
+    work = MockWork()
+    comm_bucket = AsyncCommBucket()
+
+    comm_bucket.add(OP_NAME, work)
+    assert comm_bucket.is_all_completed() is False
+
+    comm_bucket.wait(OP_NAME)
+    assert comm_bucket.is_all_completed() is True
+
+
+@rerun_if_address_is_in_use()
+def test_clear_ops_in_async_bucket():
+    init_distributed(tp=2, dp=1, pp=1)(_test_clear_ops_in_async_bucket)()
+
+
+def _test_clear_ops_in_async_bucket(parallel_context: ParallelContext):
+    comm_bucket = AsyncCommBucket()
+
+    comm_bucket.add("test1", MockWork())
+    comm_bucket.add("test2", MockWork())
+    comm_bucket.add("test3", MockWork())
+
+    assert comm_bucket.is_all_completed() is False
+
+    comm_bucket.clear_all()
+    assert comm_bucket.is_all_completed() is True
+    with pytest.raises(KeyError):
+        comm_bucket.get("test1")
+
+
+@rerun_if_address_is_in_use()
+def test_wait_comm():
+    init_distributed(tp=2, dp=1, pp=1)(_test_wait_comm)()
+
+
+def _test_wait_comm(parallel_context: ParallelContext):
+    OP_NAME = "test"
+    tensor = torch.randn(1, device="cuda", requires_grad=True)
+    stream_manager = CudaStreamManager()
+
+    comm_stream = torch.cuda.Stream()
+
+    with torch.cuda.stream(comm_stream):
+        work = MockWork()
+        stream_manager.comm_bucket.add(OP_NAME, work)
+
+    output = insert_backward_sync_to_tensor(tensor, OP_NAME, stream_manager)
+    assert work.is_completed() is False
+
+    # NOTE: we test that it waits for the async op to complete
+    # automatically in autograd
+    (output + 1).backward()
+    assert work.is_completed()
diff --git a/tests/test_domino.py b/tests/test_domino.py
new file mode 100644
index 00000000..843437a0
--- /dev/null
+++ b/tests/test_domino.py
@@ -0,0 +1,126 @@
+import os
+from copy import deepcopy
+
+import pytest
+import torch
+from helpers.llama_helper import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config
+from helpers.utils import init_distributed, rerun_if_address_is_in_use
+from nanotron.config import ModelArgs, RandomInit
+from nanotron.config.parallelism_config import DominoArgs
+from nanotron.models.llama import DominoLlamaDecoderLayer
+from nanotron.parallel import ParallelContext
+from nanotron.parallel.comm import CudaStreamManager
+from nanotron.parallel.tensor_parallel.domino import (
+    OpNameContext,
+    get_current_bwd_op_name,
+    get_current_fwd_op_name,
+    is_domino_async_comm,
+)
+
+
+@pytest.mark.parametrize(
+    "op_name, expected",
+    [
+        ("fwd.layer_attn_1_batch_0", True),
+        ("fwd.layer_attn_1_batch_1", True),
+        ("fwd.layer_mlp_1_batch_0", True),
+        ("fwd.layer_mlp_1_batch_1", True),
+        ("bwd.layer_mlp_1_batch_1", True),
+        ("bwd.layer_mlp_1_batch_0", True),
+        ("bwd.layer_attn_1_batch_1", True),
+        ("bwd.layer_attn_1_batch_0", False),
+    ],
+)
+def test_is_domino_async_comm(op_name, expected):
+    assert is_domino_async_comm(op_name) == expected
+
+
+@pytest.mark.parametrize("tp,dp,pp", [(2, 2, 1)])
+@rerun_if_address_is_in_use()
+def test_domino_model(tp: int, dp: int, pp: int):
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+    BATCH_SIZE, SEQ_LEN = 10, 128
+
+    model_config = deepcopy(TINY_LLAMA_CONFIG)
+    model_config.num_hidden_layers = 28
+    model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=model_config)
+
+    init_distributed(tp=tp, dp=dp, pp=pp)(_test_domino_model)(
+        model_args=model_args, batch_size=BATCH_SIZE, seq_len=SEQ_LEN
+    )
+
+
+def _test_domino_model(
+    parallel_context: ParallelContext,
+    model_args: ModelArgs,
+    batch_size: int,
+    seq_len: int,
+):
+    config = get_llama_training_config(model_args, parallel_context)
+    config.parallelism.domino = DominoArgs(num_input_batches=2)
+    stream_manager = CudaStreamManager()
+    stream_manager.init_default_comm_stream()
+
+    llama_model = create_llama_from_config(
+        model_config=config.model.model_config,
+        parallel_config=config.parallelism,
+        device=torch.device("cuda"),
+        parallel_context=parallel_context,
+        stream_manager=stream_manager,
+    )
+    llama_model.init_model_randomly(config=config)
+
+    for m in llama_model.model.decoder:
+        assert isinstance(m.pp_block, DominoLlamaDecoderLayer)
+
+    input_ids = torch.randint(0, config.model.model_config.vocab_size, size=(batch_size, seq_len), device="cuda")
+    input_mask = torch.ones_like(input_ids)
+    outputs = llama_model(input_ids, input_mask, input_mask, input_mask)
+
+    assert isinstance(outputs["loss"], torch.Tensor)
+    assert stream_manager.comm_bucket.is_all_completed() is True
+
+
+### OpNameContext tests ###
+
+
+def test_op_name_context_reentry():
+    assert get_current_fwd_op_name() is None
+    assert get_current_bwd_op_name() is None
+    context = OpNameContext(fwd_op_name="fwd.reusable_op", bwd_op_name="bwd.reusable_op")
+
+    with context:
+        assert get_current_fwd_op_name() == "fwd.reusable_op"
+        assert get_current_bwd_op_name() == "bwd.reusable_op"
+
+    assert get_current_fwd_op_name() is None
+    assert get_current_bwd_op_name() is None
+
+    with context:
+        assert get_current_fwd_op_name() == "fwd.reusable_op"
+        assert get_current_bwd_op_name() == "bwd.reusable_op"
+
+    assert get_current_fwd_op_name() is None
+    assert get_current_bwd_op_name() is None
+
+
+def test_deeply_nested_contexts():
+    with OpNameContext(fwd_op_name="fwd.level1", bwd_op_name="bwd.level1"):
+        assert get_current_fwd_op_name() == "fwd.level1"
+
+        with OpNameContext(fwd_op_name="fwd.level2", bwd_op_name="bwd.level2"):
+            assert get_current_fwd_op_name() == "fwd.level2"
+
+        assert get_current_fwd_op_name() == "fwd.level1"
+
+
+def test_multiple_sequential_contexts():
+    assert get_current_fwd_op_name() is None
+
+    with OpNameContext(fwd_op_name="fwd.first_op", bwd_op_name="bwd.first_op"):
+        assert get_current_fwd_op_name() == "fwd.first_op"
+
+    with OpNameContext(fwd_op_name="fwd.second_op", bwd_op_name="bwd.second_op"):
+        assert get_current_fwd_op_name() == "fwd.second_op"
+
+    assert get_current_fwd_op_name() is None