diff --git a/configs/eleutherai_cluster.yml b/configs/eleutherai_cluster.yml
index 36e75d8b3..3cf5bb007 100644
--- a/configs/eleutherai_cluster.yml
+++ b/configs/eleutherai_cluster.yml
@@ -24,6 +24,7 @@
   "tensorboard_dir": "/mnt/ssd-1/tensorboard",
   "log_dir": "/mnt/ssd-1/logs",
   "wandb_team": "eleutherai",
+  #"wandb_run_name": "experiment"
   "wandb_project": "neox",
   "wandb_group": "example"
 }
diff --git a/megatron/logging.py b/megatron/logging.py
index af8a41fe5..48481c047 100644
--- a/megatron/logging.py
+++ b/megatron/logging.py
@@ -80,8 +80,12 @@ def human_readable_flops(num) -> str:
     return "%.1f%s" % (num, "Yi")


-def get_flops(neox_args, iter_time_s) -> float:
+def get_actual_flops(neox_args, iter_time_s) -> float:
     """
+    This function finds the actual FLOPs achieved, accounting for implementation and hardware details. Also used for HFU.
+
+    For more detail on flop calculations, see https://github.com/EleutherAI/cookbook/tree/main/calc and https://github.com/Zyphra/zcookbook/tree/main/calc
+
     Use FLOPS calculation from Megatron-DeepSpeed:
     https://github.com/microsoft/Megatron-DeepSpeed/blob/cc3a94c636789f74be2bc6cfc62a3d723fd5d749/megatron/utils.py#L253
     They get it from https://arxiv.org/pdf/2104.04473.pdf
@@ -156,6 +160,83 @@ def get_flops(neox_args, iter_time_s) -> float:
     return flops_per_iteration / (iter_time_s * world_size)


+def get_forward_backward_flops(neox_args, iter_time_s) -> float:
+    """
+    This function finds the estimated FLOPs required by a single forward+backward pass, without accounting for implementation and hardware details. Also used for MFU.
+
+    Mostly duplicated from get_actual_flops, differing only in the treatment of activation checkpointing for now; the two may diverge as implementation details accumulate, so keeping them as separate functions is appropriate.
+
+    For more detail on flop calculations, see https://github.com/EleutherAI/cookbook/tree/main/calc and https://github.com/Zyphra/zcookbook/tree/main/calc
+
+    Use FLOPS calculation from Megatron-DeepSpeed:
+    https://github.com/microsoft/Megatron-DeepSpeed/blob/cc3a94c636789f74be2bc6cfc62a3d723fd5d749/megatron/utils.py#L253
+    They get it from https://arxiv.org/pdf/2104.04473.pdf
+    """
+    world_size = torch.distributed.get_world_size()
+    vocab_size = neox_args.padded_vocab_size
+    batch_size = neox_args.train_batch_size
+    seq_len = neox_args.seq_length
+    hidden_size = neox_args.hidden_size
+    num_layers = neox_args.num_layers
+    fwd_bwd_factor = 3  # 1 for fwd, 2 for bwd and weight update
+    if "rwkv" in neox_args.attention_config:
+        num_heads = neox_args.num_attention_heads
+
+        flops_per_iteration = (
+            batch_size
+            * seq_len
+            * (
+                78 * hidden_size * hidden_size * num_layers
+                + 84 * hidden_size * num_layers
+                + 16 * hidden_size
+                + 12 * hidden_size * vocab_size
+                + 18 * hidden_size * hidden_size * num_layers / num_heads
+            )
+        )
+    elif "mamba" in neox_args.attention_config:
+        # from https://github.com/Zyphra/zcookbook/blob/main/calc/calc_mamba_flops.py
+        if neox_args.expansion_factor:
+            d_inner = neox_args.hidden_size * neox_args.expansion_factor
+        elif neox_args.intermediate_size:
+            d_inner = neox_args.intermediate_size
+        else:
+            d_inner = neox_args.hidden_size * 2  # default expansion factor
+        d_state = 16  # TODO make d_state an arg. Currently hardcoded in neox mamba definition and here
+        conv_dimension = 4  # TODO make conv_dimension an arg. Currently hardcoded in neox mamba definition and here
+        dt_rank = math.ceil(neox_args.hidden_size / 16)
+        ssm_flops = (
+            fwd_bwd_factor
+            * d_inner
+            * seq_len
+            * batch_size
+            * (11 * d_state + 4 * dt_rank + 1)
+        )
+        mamba_projectors_flops = (
+            fwd_bwd_factor * seq_len * batch_size * 6 * d_inner * hidden_size
+        )
+        mamba_conv_flops = (
+            fwd_bwd_factor * seq_len * batch_size * 2 * d_inner * conv_dimension
+        )
+        mamba_flops = ssm_flops + mamba_projectors_flops + mamba_conv_flops
+        embedding_flops = 6 * seq_len * batch_size * hidden_size * vocab_size
+        flops_per_iteration = mamba_flops * num_layers + embedding_flops
+    else:
+        flops_per_iteration = (
+            24
+            * fwd_bwd_factor
+            * batch_size
+            * seq_len
+            * num_layers
+            * (hidden_size**2)
+            * (
+                1.0
+                + (seq_len / (6.0 * hidden_size))
+                + (vocab_size / (16.0 * num_layers * hidden_size))
+            )
+        )
+    return flops_per_iteration / (iter_time_s * world_size)
+
+
 def training_log(
     neox_args,
     timers,
@@ -350,6 +431,8 @@ def add_to_logging(name):
         elapsed_time = timers("interval time").elapsed()
         iteration_time = elapsed_time / neox_args.log_interval
         samples_per_sec = neox_args.train_batch_size / iteration_time
+        steps_per_sec = 1 / iteration_time
+        tokens_per_sec = samples_per_sec * neox_args.seq_length
         log_string = " samples/sec: {:.3f} |".format(samples_per_sec)
         tb_wandb_log(
             "runtime/samples_per_sec",
@@ -367,6 +450,22 @@ def add_to_logging(name):
             tensorboard_writer=neox_args.tensorboard_writer,
             comet_experiment=neox_args.comet_experiment,
         )
+        tb_wandb_log(
+            "runtime/steps_per_sec",
+            steps_per_sec,
+            iteration,
+            use_wandb=neox_args.use_wandb,
+            tensorboard_writer=neox_args.tensorboard_writer,
+            comet_experiment=neox_args.comet_experiment,
+        )
+        tb_wandb_log(
+            "runtime/tokens_per_sec",
+            tokens_per_sec,
+            iteration,
+            use_wandb=neox_args.use_wandb,
+            tensorboard_writer=neox_args.tensorboard_writer,
+            comet_experiment=neox_args.comet_experiment,
+        )
         log_string += " iteration {:8d}/{:8d} |".format(
             iteration, neox_args.train_iters
         )
@@ -390,7 +489,7 @@ def add_to_logging(name):
         )

         # log tflop / gpu
-        flops_per_s_per_gpu = get_flops(neox_args, iteration_time)
+        flops_per_s_per_gpu = get_actual_flops(neox_args, iteration_time)

         log_string += (
             f" approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |"
@@ -404,6 +503,39 @@ def add_to_logging(name):
             comet_experiment=neox_args.comet_experiment,
         )

+        if neox_args.peak_theoretical_tflops:
+            # Convert peak theoretical TFLOPS to FLOPS for consistent units
+            peak_theoretical_flops = neox_args.peak_theoretical_tflops * (10**12)
+
+            # Calculate MFU and HFU as percentages
+            mfu = (
+                get_forward_backward_flops(neox_args, iteration_time)
+                / peak_theoretical_flops
+            ) * 100
+            hfu = (flops_per_s_per_gpu / peak_theoretical_flops) * 100
+
+            # Add to log string
+            log_string += f" MFU: {mfu:.2f}% | HFU: {hfu:.2f}% |"
+
+            # Log to tracking systems
+            tb_wandb_log(
+                "runtime/model_flops_utilization",
+                mfu,
+                iteration,
+                use_wandb=neox_args.use_wandb,
+                tensorboard_writer=neox_args.tensorboard_writer,
+                comet_experiment=neox_args.comet_experiment,
+            )
+
+            tb_wandb_log(
+                "runtime/hardware_flops_utilization",
+                hfu,
+                iteration,
+                use_wandb=neox_args.use_wandb,
+                tensorboard_writer=neox_args.tensorboard_writer,
+                comet_experiment=neox_args.comet_experiment,
+            )
+
         for key in total_loss_dict:
             if key not in [skipped_iters_key, got_nan_key]:
                 v = (
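To make the new numbers concrete, the sketch below re-derives the dense-transformer estimate from `get_forward_backward_flops` and turns it into an MFU figure. It is a standalone illustration: the model shapes, iteration time, GPU count, and the 312 TFLOPS peak are assumptions chosen for the example, not values produced by this patch.

```python
# Standalone sketch of the dense-transformer branch of get_forward_backward_flops.
def transformer_fwd_bwd_flops(batch_size, seq_len, num_layers, hidden_size, vocab_size):
    fwd_bwd_factor = 3  # 1x forward + 2x backward
    return (
        24
        * fwd_bwd_factor
        * batch_size
        * seq_len
        * num_layers
        * hidden_size**2
        * (
            1.0
            + seq_len / (6.0 * hidden_size)
            + vocab_size / (16.0 * num_layers * hidden_size)
        )
    )


# Hypothetical 20B-scale shapes and timings (assumptions, not measurements).
flops_per_iter = transformer_fwd_bwd_flops(
    batch_size=1024, seq_len=2048, num_layers=44, hidden_size=6144, vocab_size=50432
)
iter_time_s, world_size = 30.0, 96  # assumed iteration time and GPU count
peak_flops = 312e12                 # e.g. A100 SXM dense BF16 peak, 312 TFLOPS

model_flops_per_s_per_gpu = flops_per_iter / (iter_time_s * world_size)
mfu = model_flops_per_s_per_gpu / peak_flops * 100
print(f"MFU: {mfu:.1f}%")  # -> roughly 30% with these assumptions
```

HFU is computed the same way but from `get_actual_flops`, which additionally counts the recomputation done when activation checkpointing is enabled, so HFU is always greater than or equal to MFU.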
diff --git a/megatron/neox_arguments/__init__.py b/megatron/neox_arguments/__init__.py
index 025464cbf..087dbe6b7 100644
--- a/megatron/neox_arguments/__init__.py
+++ b/megatron/neox_arguments/__init__.py
@@ -18,7 +18,7 @@
 * NeoXArgs.from_ymls(["path_to_yaml1", "path_to_yaml2", ...]): load yaml configuration files and instantiate with the values provided; checks for duplications and unknown arguments are performed
 * NeoXArgs.from_dict({"num_layers": 12, ...}): load attribute values from dict; checks unknown arguments are performed
-* NeoXArgs.consume_deepy_args(): entry point for deepy.py configuring and consuming command line arguments (i.e. user_script, conf_dir, conf_file, wandb_group, wandb_team); neox_args.get_deepspeed_main_args() produces a list of command line arguments to feed to deepspeed.launcher.runner.main
+* NeoXArgs.consume_deepy_args(): entry point for deepy.py configuring and consuming command line arguments (i.e. user_script, conf_dir, conf_file, wandb_group, wandb_run_name, wandb_team); neox_args.get_deepspeed_main_args() produces a list of command line arguments to feed to deepspeed.launcher.runner.main
 * NeoXArgs.consume_neox_args(): In the call stack deepy.py -> deepspeed -> pretrain_gpt2.py; arguments are passed to pretrain_gpt2.py by neox_args.get_deepspeed_main_args(). So produced arguments can be read with consume_neox_args() to instantiate a NeoXArgs instance.
diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py
index 1e5567c80..f3daacd4d 100644
--- a/megatron/neox_arguments/arguments.py
+++ b/megatron/neox_arguments/arguments.py
@@ -335,6 +335,12 @@ def consume_deepy_args(cls, input_args=None):
             default=None,
             help='Weights & Biases group name - used to group together "runs".',
         )
+        group.add_argument(
+            "--wandb_run_name",
+            type=str,
+            default=None,
+            help="Weights & Biases run name for the current experiment.",
+        )
         group.add_argument(
             "--wandb_team",
             type=str,
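As a quick check of the new launcher flag, here is a minimal standalone sketch of the argparse wiring added above; the deepy.py invocation in the comment is hypothetical and the config names are placeholders.

```python
# Hypothetical launcher invocation (config names are placeholders):
#   python ./deepy.py train.py -d configs 125M.yml local_setup.yml --wandb_run_name 125M_rope_ablation
# Minimal sketch of the parsing added in consume_deepy_args:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--wandb_run_name",
    type=str,
    default=None,
    help="Weights & Biases run name for the current experiment.",
)
args = parser.parse_args(["--wandb_run_name", "125M_rope_ablation"])
print(args.wandb_run_name)  # -> 125M_rope_ablation
```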
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index c877c6c78..59e9f93ea 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -624,12 +624,16 @@ class NeoXArgsLogging(NeoXArgsTemplate):
     Logging Arguments
     """

+    ### BEGIN WANDB ARGS ###
     use_wandb: bool = None
     """Flag indicating if wandb is to be used."""

     wandb_group: str = None
     """Weights and Biases group name - used to group together "runs"."""

+    wandb_run_name: str = None
+    """Weights and Biases run name for the current experiment."""
+
     wandb_team: str = None
     """Team name for Weights and Biases."""

@@ -641,6 +645,7 @@ class NeoXArgsLogging(NeoXArgsTemplate):

     wandb_init_all_ranks: bool = False
     """Initialize wandb on all ranks."""
+    ### END WANDB ARGS ###

     git_hash: str = get_git_commit_hash()
     """current git hash of repository"""
@@ -650,6 +655,7 @@ class NeoXArgsLogging(NeoXArgsTemplate):
     Directory to save logs to.
     """

+    ### BEGIN TENSORBOARD ARGS ###
     tensorboard_writer = None
     """
     initialized tensorboard writer
@@ -659,7 +665,9 @@ class NeoXArgsLogging(NeoXArgsTemplate):
     """
     Write TensorBoard logs to this directory.
     """
+    ### END TENSORBOARD ARGS ###

+    ### BEGIN COMET ARGS ###
     use_comet: bool = None
     """Flag indicating if comet is to be used."""

@@ -692,6 +700,12 @@ class NeoXArgsLogging(NeoXArgsTemplate):
     """
     Initialized comet experiment object used to log data
     """
+    ### END COMET ARGS ###
+
+    peak_theoretical_tflops: float = None
+    """
+    The peak theoretical hardware FLOPS used to compute MFU and HFU, in units of teraflops. Automatic detection is more trouble than it's worth, so this is left to the user. A helpful reference table is available at https://github.com/stas00/ml-engineering/tree/master/compute/accelerator#tflops-comparison-table
+    """

     log_interval: int = 100
     """
@@ -711,8 +725,7 @@ class NeoXArgsLogging(NeoXArgsTemplate):
     log_grad_norm: bool = False
     """
     Log the frob norm of the gradients to wandb / tensorboard (useful for debugging).
-    (N.B - this will only work with pp = 0 for now, as we don't have access to the gradients of the model because
-    deepspeed.)
+    (N.B - this will only work with pp = 0 for now, as we don't have access to the gradients of the model because of DeepSpeed.)
     """

     log_optimizer_states: bool = False
@@ -735,6 +748,7 @@ class NeoXArgsLogging(NeoXArgsTemplate):
     Whether to offload the buffered gradients to cpu when measuring gradient noise scale.
     """

+    ### BEGIN PROFILING ARGS ###
     memory_profiling: bool = False
     """
     Whether to take a memory snapshot of the model. Useful for debugging memory issues.
@@ -767,6 +781,7 @@ class NeoXArgsLogging(NeoXArgsTemplate):
     """
     Step to stop profiling at.
     """
+    ### END PROFILING ARGS ###


 @dataclass
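A short worked example of how `peak_theoretical_tflops` feeds the logged percentages; both numbers below are hypothetical, and the peak should be the dense (non-sparsity) figure for your accelerator and precision.

```python
# Hypothetical values, for illustration only.
peak_theoretical_tflops = 312    # e.g. A100 SXM dense BF16 peak, in TFLOPS
achieved_tflops_per_gpu = 150    # the "approx flops per GPU" log value, in TFLOP/s

hfu = achieved_tflops_per_gpu / peak_theoretical_tflops * 100
print(f"HFU: {hfu:.2f}%")        # -> HFU: 48.08%
```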
diff --git a/megatron/utils.py b/megatron/utils.py
index 507c44179..fc2f80dad 100644
--- a/megatron/utils.py
+++ b/megatron/utils.py
@@ -166,12 +166,12 @@ def init_wandb(neox_args):
     neox_args.update_value("use_wandb", use_wandb)
     if neox_args.use_wandb:
         group_name = neox_args.wandb_group
-        name = f"{socket.gethostname()}-{local_rank()}" if group_name else None
+        run_name = neox_args.wandb_run_name
         try:
             wandb.init(
                 project=neox_args.wandb_project,
                 group=group_name,
-                name=name,
+                name=run_name,
                 save_code=False,
                 force=False,
                 entity=neox_args.wandb_team,
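Note that when `wandb_run_name` is left unset, `name=None` is passed through and Weights & Biases falls back to an auto-generated run name (previously the run was named `<hostname>-<local_rank>` whenever a group was set). A minimal sketch of the resulting call, assuming wandb is installed and you are logged in; the project and group values are illustrative:

```python
import wandb

run = wandb.init(
    project="neox",
    group="example",
    name=None,  # wandb_run_name not set -> W&B assigns an auto-generated name
    save_code=False,
    force=False,
)
run.finish()
```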