Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Domino #279

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions examples/config_llama_domino.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
checkpoints:
checkpoint_interval: 1000
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_initial_state: false
data_stages:
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: roneneldan/TinyStories
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Stable Training Stage
start_training_step: 1
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: nanotron_domino
run: config_llama_domino
seed: 42
step: null
lighteval: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
std: 0.025
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 128000
eos_token_id: 128001
hidden_act: silu
hidden_size: 4096
initializer_range: 0.02
intermediate_size: 16384
is_llama_config: true
max_position_embeddings: 4096
num_attention_heads: 32
num_hidden_layers: 32
num_key_value_heads: 8
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
tie_word_embeddings: true
use_cache: true
vocab_size: 128256
optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
lr_decay_starting_step: null
lr_decay_steps: 1000
lr_decay_style: cosine
lr_warmup_steps: 500
lr_warmup_style: linear
min_decay_lr: 1.0e-05
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 1
pp: 1
tp: 8
expert_parallel_size: 1
pp_engine: 1f1b
tp_linear_async_communication: false
tp_mode: ALL_REDUCE
domino:
num_input_batches: 2
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 1
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 2
sequence_length: 4096
train_steps: 1500
val_check_interval: -1
25 changes: 25 additions & 0 deletions src/nanotron/config/parallelism_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,23 @@
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode


@dataclass
class DominoArgs:
"""
Domino: Eliminating Communication in LLM Training via Generic Tensor Slicing and Overlapping
https://arxiv.org/abs/2409.15241
"""

# NOTE: if the number of input batches is 1,
# it's equivalent to non-domino mode
# so if you want to enable domino mode, set this to > 1
num_input_batches: int

def __post_init__(self):
assert self.num_input_batches > 1, "In order to enable domino mode, set num_input_batches > 1"
assert self.num_input_batches == 2, "Currently parallelism only supports 2 batches for Domino"


@dataclass
class ParallelismArgs:
"""Arguments related to TP/PP/DP
Expand All @@ -37,6 +54,7 @@ class ParallelismArgs:
tp_recompute_allgather: bool = True

expert_parallel_size: int = 1
domino: Optional[DominoArgs] = None

def __post_init__(self):
# Conservative defaults
Expand All @@ -51,3 +69,10 @@ def __post_init__(self):
self.pp_engine = cast_str_to_pipeline_engine(self.pp_engine)
if isinstance(self.tp_mode, str):
self.tp_mode = TensorParallelLinearMode[self.tp_mode.upper()]

if self.is_domino_enabled is True:
assert self.tp > 1, "Domino requires TP > 1"

@property
def is_domino_enabled(self) -> bool:
return True if self.domino else False
2 changes: 2 additions & 0 deletions src/nanotron/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,9 @@ def get_profiler(config: Config):
on_trace_ready=on_trace_ready,
# record_shapes=True,
# profile_memory=True,
with_flops=True,
with_stack=True,
with_modules=True,
)
else:
prof = contextlib.nullcontext()
Expand Down
Loading
Loading