diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index f4e7628c..5e372825 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -183,6 +183,81 @@ def get_model_config(model_variant): "pad_vocab_size_multiple": 16, "tie_embeddings": False, } + elif model_variant == "mamba_30b_moe": + model_config = { + "d_model": 3072, + "d_intermediate": 1344, + "n_layer": 32, + "vocab_size": 128256, + "ssm_cfg": {"layer": "Mamba2"}, + "attn_layer_idx": [9, 18, 27], + "attn_cfg": { + "causal": True, + "d_conv": 0, + "head_dim": 128, + "num_heads": 24, + "num_heads_kv": 8, + "out_proj_bias": False, + "qkv_proj_bias": False, + "rotary_emb_dim": 64, + }, + "mlp_cfg": {"n_expert": 64, "load_balancing_loss": True, "top_k": 8}, + "rms_norm": True, + "residual_in_fp32": True, + "fused_add_norm": True, + "pad_vocab_size_multiple": 16, + "tie_embeddings": False, + } + elif model_variant == "mamba_120b_moe": + model_config = { + "d_model": 4096, + "d_intermediate": 896, + "n_layer": 40, + "vocab_size": 128256, + "ssm_cfg": {"layer": "Mamba2"}, + "attn_layer_idx": [9, 18, 27, 36], + "attn_cfg": { + "causal": True, + "d_conv": 0, + "head_dim": 128, + "num_heads": 32, + "num_heads_kv": 8, + "out_proj_bias": False, + "qkv_proj_bias": False, + "rotary_emb_dim": 64, + }, + "mlp_cfg": {"n_expert": 256, "load_balancing_loss": True, "top_k": 16}, + "rms_norm": True, + "residual_in_fp32": True, + "fused_add_norm": True, + "pad_vocab_size_multiple": 16, + "tie_embeddings": False, + } + elif model_variant == "mamba_236b_moe": + model_config = { + "d_model": 5120, + "d_intermediate": 1536, + "n_layer": 60, + "vocab_size": 128256, + "ssm_cfg": {"layer": "Mamba2"}, + "attn_layer_idx": [9, 18, 27, 36, 45, 54], + "attn_cfg": { + "causal": True, + "d_conv": 0, + "head_dim": 128, + "num_heads": 40, + "num_heads_kv": 8, + "out_proj_bias": False, + "qkv_proj_bias": False, + "rotary_emb_dim": 64, + }, + "mlp_cfg": {"n_expert": 160, "load_balancing_loss": True, "top_k": 8}, + "rms_norm": True, + "residual_in_fp32": True, + "fused_add_norm": True, + "pad_vocab_size_multiple": 16, + "tie_embeddings": False, + } else: raise ValueError(f"model variant {model_variant} not supported.") diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py index ef421f6f..4c1ff5d6 100644 --- a/fms_fsdp/utils/train_utils.py +++ b/fms_fsdp/utils/train_utils.py @@ -2,6 +2,8 @@ from dataclasses import asdict from functools import partial +from transformers.models.granitemoe.modeling_granitemoe import load_balancing_loss_func + try: import packaging.version @@ -73,7 +75,7 @@ def train( run["hparams"] = asdict(cfg) model.train() - ddp_stats = torch.zeros(3).to(local_rank) + ddp_stats = torch.zeros(4).to(local_rank) start = time.time() loop_start = time.time() @@ -86,9 +88,20 @@ def train( optimizer.zero_grad() output = model(input) - output = output.logits if hasattr(output, "logits") else output + logits = output.logits if hasattr(output, "logits") else output ce_loss = torch.nn.CrossEntropyLoss() - loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long()) + loss = ce_loss(logits.view(-1, logits.size(-1)), label.view(-1).long()) + ddp_stats[3] += loss.item() + if "moe" in cfg.model_variant: + aux_outputs = output.aux_outputs + if aux_outputs is not None: + top_k = model.config.mlp_cfg.get("top_k", 2) + aux_loss = load_balancing_loss_func( + aux_outputs, + num_experts=model.config.mlp_cfg["n_expert"], + top_k=top_k, + ) + loss += 0.2 * aux_loss loss.backward() ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item() @@ -105,6 +118,7 @@ def train( dist.all_reduce(ddp_stats, op=dist.ReduceOp.SUM) train_loss = ddp_stats[0] / ddp_stats[2] g_norm = ddp_stats[1] / ddp_stats[2] + original_loss = ddp_stats[3] / ddp_stats[2] elapsed_time = time.time() - loop_start world_size = int(os.environ["WORLD_SIZE"]) new_tokens_seen = ( @@ -113,6 +127,7 @@ def train( if rank == 0: total_tokens_seen = tokens_seen + new_tokens_seen current_loss = train_loss.item() + current_original_loss = original_loss.item() current_lr = scheduler.get_last_lr()[0] current_gnorm = g_norm.item() current_step_time = (time.time() - start) / cfg.report_interval @@ -132,6 +147,7 @@ def train( print("step:", batch_idx) print("loss:", current_loss) + print("original loss:", current_original_loss) print("LR:", current_lr) print("tokens seen:", total_tokens_seen) print("gradient norm:", current_gnorm) @@ -149,6 +165,7 @@ def train( vals_to_track = { "learning rate": current_lr, "loss": current_loss, + "original_loss": current_original_loss, "gradient norm": current_gnorm, "token seen": total_tokens_seen, "current throughput (token per gpu per sec)": current_throughput, diff --git a/main_training_mamba.py b/main_training_mamba.py index 3619ea25..87b76ff0 100644 --- a/main_training_mamba.py +++ b/main_training_mamba.py @@ -64,7 +64,8 @@ def main(**kwargs): # get model config_data = get_model_config(cfg.model_variant) mamba_config = MambaConfig(**config_data) - model = MambaLMHeadModel(mamba_config) + with torch.device("meta"): + model = MambaLMHeadModel(mamba_config) if rank == 0: total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) @@ -89,7 +90,7 @@ def main(**kwargs): use_orig_params=cfg.use_torch_compile, device_id=torch.cuda.current_device(), limit_all_gathers=True, - param_init_fn=param_init_fn, + param_init_fn=lambda x: x.to_empty(device=torch.cuda.current_device(), recurse=False), ) # fsdp activation checkpointing