Add HuggingFace arg so that arch is automatic #39

Merged (12 commits) on Aug 19, 2024
10 changes: 6 additions & 4 deletions calc/README.md
@@ -100,14 +100,16 @@ Example with pythia 6.9B: python calc_transformer_mem.py --num-layers=32 --sequence-length=2048 --num-attention-heads=32 --hidden-size=4096 --batch-size-per-gpu=8 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=2 --num-gpus=128
Example with pythia 12B: python calc_transformer_mem.py --num-layers=36 --sequence-length=2048 --num-attention-heads=40 --hidden-size=5120 --batch-size-per-gpu=8 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=4 --num-gpus=256
Example with default 20B: python calc_transformer_mem.py --num-layers=44 --sequence-length=2048 --num-attention-heads=64 --hidden-size=6144 --batch-size-per-gpu=1 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=1 --num-gpus=1

usage: calc_transformer_mem.py [-h] [--num-gpus NUM_GPUS] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--pipeline-parallel-size PIPELINE_PARALLEL_SIZE] [--partition-activations] [--zero-stage {0,1,2,3}] [--zero-allgather-bucket-size ZERO_ALLGATHER_BUCKET_SIZE]
[--zero3-max-live-params ZERO3_MAX_LIVE_PARAMS] [--checkpoint-activations] [--batch-size-per-gpu BATCH_SIZE_PER_GPU] [--sequence-length SEQUENCE_LENGTH] [--vocab-size VOCAB_SIZE] [--hidden-size HIDDEN_SIZE]
[--num-attention-heads NUM_ATTENTION_HEADS] [--num-layers NUM_LAYERS] [--ffn-expansion-factor FFN_EXPANSION_FACTOR] [--num-mlp-linears NUM_MLP_LINEARS] [--infer] [--kv-size-ratio KV_SIZE_RATIO] [--output-tokens OUTPUT_TOKENS]
[--disable-mixed-precision] [--high-prec-bytes-per-val HIGH_PREC_BYTES_PER_VAL] [--low-prec-bytes-per-val LOW_PREC_BYTES_PER_VAL] [--bytes-per-grad-ele BYTES_PER_GRAD_ELE] [--num-experts NUM_EXPERTS]
usage: calc_transformer_mem.py [-h] [--hf_model_name_or_path HF_MODEL_NAME_OR_PATH] [--num-gpus NUM_GPUS] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--pipeline-parallel-size PIPELINE_PARALLEL_SIZE] [--partition-activations] [--zero-stage {0,1,2,3}]
[--zero-allgather-bucket-size ZERO_ALLGATHER_BUCKET_SIZE] [--zero3-max-live-params ZERO3_MAX_LIVE_PARAMS] [--checkpoint-activations] [--batch-size-per-gpu BATCH_SIZE_PER_GPU] [--sequence-length SEQUENCE_LENGTH] [--vocab-size VOCAB_SIZE]
[--hidden-size HIDDEN_SIZE] [--num-attention-heads NUM_ATTENTION_HEADS] [--num-layers NUM_LAYERS] [--ffn-expansion-factor FFN_EXPANSION_FACTOR] [--num-mlp-linears NUM_MLP_LINEARS] [--infer] [--kv-size-ratio KV_SIZE_RATIO]
[--output-tokens OUTPUT_TOKENS] [--disable-mixed-precision] [--high-prec-bytes-per-val HIGH_PREC_BYTES_PER_VAL] [--low-prec-bytes-per-val LOW_PREC_BYTES_PER_VAL] [--bytes-per-grad-ele BYTES_PER_GRAD_ELE] [--num-experts NUM_EXPERTS]
[--expert-parallelism EXPERT_PARALLELISM] [--misc-mem-gib MISC_MEM_GIB]

options:
-h, --help show this help message and exit
--hf_model_name_or_path HF_MODEL_NAME_OR_PATH
Name of the HuggingFace Hub repository or the local file path for it
--num-gpus NUM_GPUS Number of GPUs used for training
--tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE
Tensor parallel degree (1 if not used)
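For reference, a hypothetical invocation of the new flag (it assumes the Hub repository EleutherAI/pythia-6.9b is reachable; any architecture flags passed explicitly on the command line still override the values read from its config):

Example with a HuggingFace model: python calc_transformer_mem.py --hf_model_name_or_path=EleutherAI/pythia-6.9b --num-gpus=128 --tensor-parallel-size=2 --pipeline-parallel-size=1 --zero-stage=1 --batch-size-per-gpu=8 --checkpoint-activations --partition-activations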
176 changes: 151 additions & 25 deletions calc/calc_transformer_mem.py
@@ -1,10 +1,15 @@
# By Quentin Anthony and Hailey Schoelkopf
# By Quentin Anthony, Hailey Schoelkopf, Bhavnick Minhas

import argparse
import math

# Helper function to pretty-print message sizes

### Begin Helper Functions ###

def convert_params(params):
'''
Helper function to pretty-print message sizes
'''
if params == 0:
return "0"
size_name = ("", "K", "M", "B", "T", "P", "E", "Z", "Y")
@@ -13,85 +18,162 @@ def convert_params(params):
s = round(params / p, 2)
return "%s %s" % (s, size_name[i])

def set_defaults(args):
'''
Sets the default values for the arguments that are not provided
'''
for key, value in DEFAULTS.items():
if getattr(args, key) is None:
setattr(args, key, value)
return args

def set_if_none(args, key, config, config_key):
'''
Sets the argument from the HF config (or the hard-coded default) if it was not provided on the command line
'''
if getattr(args, key) is None:
setattr(args, key, config.get(config_key, DEFAULTS[key]))
else:
print(f"overriding HF {config_key} config value ({config[config_key]}) with provided value ({getattr(args, key)})")
return args

def get_hf_model_args(args):
'''
Updates the args with HuggingFace model config values
'''
# Check if the name is not None
if args.hf_model_name_or_path is not None:
try:
from transformers import AutoConfig
config = AutoConfig.from_pretrained(args.hf_model_name_or_path,
trust_remote_code=True).to_dict()
except OSError as e:
print("An OSError has been raised. Commonly due to a model Repository name or path not found. Are you sure it exists?")
print('Full error: ')
raise e
except ImportError as e:
print('If you would like to calculate memory from an HF model, you must install HF transformers with pip install transformers')
print('Full error: ')
raise e

# Now that config has been retrieved, we update the args with the config values

arch = config['model_type']

# Separate handling for gpt2, whose config uses different key names
if arch.lower()=='gpt2':
args.num_layers = config.get("n_layer", args.num_layers)
args.num_attention_heads = config.get("n_head", args.num_attention_heads)
args.hidden_size = config.get("n_embd", args.hidden_size)
args.vocab_size = config.get("vocab_size", args.vocab_size)

else:
set_if_none(args, "num_layers", config, "num_hidden_layers")
set_if_none(args, "num_attention_heads", config, "num_attention_heads")
set_if_none(args, "hidden_size", config, "hidden_size")

config["ffn_expansion_factor"] = config.get("intermediate_size", args.hidden_size) / args.hidden_size
set_if_none(args, "ffn_expansion_factor", config, "ffn_expansion_factor")

# config["num_key_value_heads"] = config.get("num_key_value_heads", config["num_attention_heads"])
# set_if_none(args, "num_key_value_heads", config, "num_key_value_heads")

set_if_none(args, "vocab_size", config, "vocab_size")
set_if_none(args, "sequence_length", config, "max_position_embeddings")

# Set the default values regardless
set_defaults(args)

return args

### End Helper Functions ###
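To illustrate how the helpers above fit together, here is a minimal, simplified sketch of the precedence they implement: an explicit CLI value wins, then the HF config value, then the hard-coded default. It is not the patch itself; the real set_if_none also prints a notice when a CLI value overrides an HF value, and the config dict here is a stand-in for AutoConfig.from_pretrained(...).to_dict():

```python
import argparse

DEFAULTS = {"num_layers": 44, "hidden_size": 6144}

def set_if_none(args, key, config, config_key):
    # Fill the argument from the (toy) HF config only if the user left it unset
    if getattr(args, key) is None:
        setattr(args, key, config.get(config_key, DEFAULTS[key]))
    return args

config = {"num_hidden_layers": 32}            # stand-in for a fetched HF config
args = argparse.Namespace(num_layers=None, hidden_size=4096)

set_if_none(args, "num_layers", config, "num_hidden_layers")   # unset on the CLI -> pulled from config
set_if_none(args, "hidden_size", config, "hidden_size")        # set on the CLI -> kept

print(args.num_layers, args.hidden_size)  # 32 4096
```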

### Begin Argument Parsing ###

def config_parser():
parser = argparse.ArgumentParser()
# HuggingFace Settings
parser.add_argument("--hf_model_name_or_path",
type=str,
default=None,
help="Name of the HuggingFace Hub repository or the local file path for it")
# Distributed Settings
parser.add_argument("--num-gpus",
type=int,
default=1,
default=None,
help='Number of GPUs used for training')
parser.add_argument("--tensor-parallel-size", "-tp",
type=int,
default=1,
default=None,
help='Tensor parallel degree (1 if not used)')
parser.add_argument("--pipeline-parallel-size", "-pp",
type=int,
default=1,
default=None,
help='Pipeline parallel degree (1 if not used)')
parser.add_argument("--partition-activations", "-pa",
action="store_true",
help='Whether we use ZeRO-R to partition activation memory across tensor-parallel degree')
parser.add_argument("--zero-stage", "-z",
type=int,
default=1,
default=None,
choices=[0,1,2,3],
help='Stage of the ZeRO optimizer')
parser.add_argument("--zero-allgather-bucket-size", "-zbs",
type=int,
default=5e8,
default=None,
help='Size of allgather buckets used by ZeRO')
parser.add_argument("--zero3-max-live-params", "-zmlp",
type=int,
default=1e9,
default=None,
help='Maximum number of parameters ZeRO3 keeps in GPU memory')
# Training settings
parser.add_argument("--checkpoint-activations", "-ca",
action="store_true",
help='Whether Megatron-style activation checkpointing is being used')
parser.add_argument("--batch-size-per-gpu", "-b",
type=int,
default=1,
default=None,
help='Batch size per GPU')
parser.add_argument("--sequence-length", "-s",
type=int,
default=2048,
default=None,
help='Sequence length used for training')
parser.add_argument("--vocab-size", "-v",
type=int,
default=51200,
default=None,
help='How many tokens are in the embedding layer')
# Model settings
parser.add_argument("--hidden-size", "-hs",
type=int,
default=6144,
default=None,
help='Dimension of the model\'s hidden size')
parser.add_argument("--num-attention-heads", "-a",
type=int,
default=64,
default=None,
help='Number of attention heads used in model')
parser.add_argument("--num-layers", "-l",
type=int,
default=44,
default=None,
help='Number of transformer layers used in model')
parser.add_argument("--ffn-expansion-factor", "-ff",
type=int,
default=4,
default=None,
help='How much the MLP hidden size expands')
parser.add_argument("--num-mlp-linears", "-nl",
type=int,
default=2,
default=None,
help='How many linear layers per MLP block')
# Inference settings
parser.add_argument("--infer",
action="store_true",
help="whether we're doing inference")
parser.add_argument("--kv-size-ratio", "-kv",
type=float,
default=1.0,
default=None,
help='Ratio of total query heads to key/value heads. 1.0 for MHA, 1/num_attention_heads for MQA.')
parser.add_argument("--output-tokens", "-o",
type=int,
default=1,
default=None,
help='Number of tokens to autoregressively generate.')
# Precision settings
parser.add_argument("--disable-mixed-precision",
@@ -100,37 +182,79 @@ def config_parser():
dest='is_mixed_precision')
parser.add_argument("--high-prec-bytes-per-val",
type=int,
default=4,
default=None,
help='The high-precision bytes per value (parameter, optimizer state, etc) in mixed precision')
parser.add_argument("--low-prec-bytes-per-val",
type=int,
default=2,
default=None,
help='The low-precision bytes per value (parameter, optimizer state, etc) in mixed precision')
parser.add_argument("--bytes-per-grad-ele",
type=int,
default=4,
default=None,
help='The precision of gradient elements as bytes per value')
# MoE Settings
parser.add_argument("--num-experts",
type=int,
default=0,
default=None,
help='Number of experts')
parser.add_argument("--expert-parallelism", "-ep",
type=int,
default=1,
default=None,
help='How many ways are the experts sharded across ranks')
# Miscellaneous memory (good for accounting for implementation-dependent fudge factors)
parser.add_argument("--misc-mem-gib",
type=int,
default=0,
default=None,
help='Miscellaneous memory overhead per GPU by DL framework(s), communication libraries, etc')

return parser

DEFAULTS = {
# Distributed Settings
"num_gpus" : 1,
"tensor_parallel_size" : 1,
"pipeline_parallel_size" : 1,
"partition_activations" : False,
"zero_stage" : 1,
"zero_allgather_bucket_size" : 5e8,
"zero3_max_live_params" : 1e9,
# Training Settings
"checkpoint_activations" : False,
"batch_size_per_gpu" : 1,
"sequence_length" : 2048,
"vocab_size" : 51200,
# Model Settings
"hidden_size" : 6144,
"num_attention_heads" : 64,
"num_layers" : 44,
"ffn_expansion_factor" : 4,
"num_mlp_linears": 2,
# Inference Settings
"infer" : False,
"kv_size_ratio" : 1.0,
"output_tokens" : 1,
# Precision Settings
"is_mixed_precision" : True,
"high_prec_bytes_per_val" : 4,
"low_prec_bytes_per_val" : 2,
"bytes_per_grad_ele" : 4,
# MoE Settings
"num_experts" : 0,
"expert_parallelism" : 1,
# Miscellaneous Memory
"misc_mem_gib" : 0
}
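One way to read the switch of every argparse default to None: None acts as a sentinel, so the script can distinguish "not supplied on the command line" from "supplied", and only then fall back to the HF config or to DEFAULTS. A hypothetical check of that behaviour, using config_parser and set_defaults as defined above (not part of the diff):

```python
args = config_parser().parse_args([])   # nothing supplied on the command line
assert args.num_layers is None          # sentinel value: "not provided"
set_defaults(args)
assert args.num_layers == 44            # filled in from DEFAULTS afterwards
```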

### End Argument Parsing ###

### Begin Memory Calculation ###

# Calculates the total memory necessary for model training or inference
def calc_mem(args):

# Pull architecture values from the HF model config if a model name/path was provided
args = get_hf_model_args(args)

dp_degree = args.num_gpus / (args.tensor_parallel_size * args.pipeline_parallel_size)

# Compute total parameters from the config
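To make the data-parallel arithmetic above concrete, a small worked example with values borrowed from the pythia 6.9B invocation in the README (not part of the patch):

```python
# 128 GPUs split across tensor-parallel degree 2 and pipeline-parallel degree 1
num_gpus, tensor_parallel_size, pipeline_parallel_size = 128, 2, 1
dp_degree = num_gpus / (tensor_parallel_size * pipeline_parallel_size)
print(dp_degree)  # 64.0 data-parallel model replicas
```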
@@ -287,9 +411,11 @@ def calc_mem(args):
else:
print(f'\nTotal GPU Memory Required to Store a Complete Model Replica for Training: {single_replica_mem_gib:.2f} GiB')

### End Memory Calculation ###

if __name__ == "__main__":
print('\nExample with pythia 6.9B: python calc_transformer_mem.py --num-layers=32 --sequence-length=2048 --num-attention-heads=32 --hidden-size=4096 --batch-size-per-gpu=8 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=2 --num-gpus=128')
print('Example with pythia 12B: python calc_transformer_mem.py --num-layers=36 --sequence-length=2048 --num-attention-heads=40 --hidden-size=5120 --batch-size-per-gpu=8 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=4 --num-gpus=256')
print('Example with default 20B: python calc_transformer_mem.py --num-layers=44 --sequence-length=2048 --num-attention-heads=64 --hidden-size=6144 --batch-size-per-gpu=1 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=1 --num-gpus=1\n')
args = config_parser().parse_args()
calc_mem(args)