diff --git a/calc/calc_transformer_flops.py b/calc/calc_transformer_flops.py
index 64d550d..c61ef02 100644
--- a/calc/calc_transformer_flops.py
+++ b/calc/calc_transformer_flops.py
@@ -63,11 +63,13 @@ def config_parser():
                         action='store_false',
                         help='Whether Megatron-style activation checkpointing is being used',
                         dest='checkpoint_activations')
+    parser.add_argument("--infer", "-i",
+                        action='store_true',
+                        help='Pass to calculate FLOPs for inference-only workload (no backward pass)')
     return parser
 
 # calculates the flops of a model given its hparams
 def calc_params(args):
-    assert args.topk <= args.num_experts, "You cannot route to more experts than you have!"
     assert args.num_layers % args.expert_interval == 0, "Require for simplicity that we don't have hanging dense layers"
 
@@ -79,6 +81,10 @@ def calc_params(args):
     iter_factor = 3
     if args.checkpoint_activations:
         iter_factor += 1
+    # If inference-only, no bwd pass or activation ckpting necessary
+    # This assumes simply running a single forward pass ('prefill' stage of decoding) and no subsequent autoregressively generated tokens.
+    if args.infer:
+        iter_factor = 1
 
     qkv_flops = int(iter_factor * 2 * (1 + 2 * args.kv_size_ratio) * args.num_layers * args.tokens * args.hidden_size * args.hidden_size)
     attention_matrix_flops = iter_factor * 2 * args.num_layers * args.tokens * args.sequence_length * args.hidden_size
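
For context, the `iter_factor` this patch touches counts how many times each matmul's FLOPs are charged: the forward pass counts once and the backward pass roughly twice more (hence 3 for training), Megatron-style activation checkpointing adds one extra forward-pass recomputation, and with `--infer` only the single prefill forward pass is counted. Below is a minimal standalone sketch of that logic together with the two formulas visible in the diff; the hyperparameter values are made up for illustration (GPT-2-small-like sizes, `kv_size_ratio = 1.0` for standard multi-head attention) and are not taken from the patch or the script's defaults.

```python
# Illustrative sketch of the FLOP multiplier and the two attention formulas
# shown in the diff; example values below are assumptions, not script defaults.

def iter_factor(infer: bool, checkpoint_activations: bool) -> int:
    if infer:
        # Inference-only: a single prefill forward pass, no backward pass,
        # no activation recomputation.
        return 1
    factor = 3  # forward (1x) + backward (~2x)
    if checkpoint_activations:
        factor += 1  # activation checkpointing re-runs the forward pass
    return factor

# Assumed example hyperparameters (roughly GPT-2-small sized).
num_layers, hidden_size, sequence_length = 12, 768, 2048
tokens = sequence_length   # one sequence's worth of tokens
kv_size_ratio = 1.0        # standard MHA: K/V projections same width as Q

f = iter_factor(infer=True, checkpoint_activations=False)
qkv_flops = int(f * 2 * (1 + 2 * kv_size_ratio) * num_layers * tokens * hidden_size * hidden_size)
attention_matrix_flops = f * 2 * num_layers * tokens * sequence_length * hidden_size
print(f"qkv_flops={qkv_flops:.3e}  attention_matrix_flops={attention_matrix_flops:.3e}")
```

Because `iter_factor` simply scales every term, switching the same configuration from training to `--infer` divides the reported FLOPs by 3 (or by 4 when activation checkpointing was enabled).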