
Commit 56aeee1

Add simple inference FLOP counter to calc_transformer_flops.py (#31)
* add --infer arg to flops calculator
* add comment
* fix comment
* Update calc_transformer_flops.py
Parent commit: 939fa3c

1 file changed: 7 additions, 1 deletion


calc/calc_transformer_flops.py

Lines changed: 7 additions & 1 deletion
@@ -66,11 +66,13 @@ def config_parser():
                         action='store_false',
                         help='Whether Megatron-style activation checkpointing is being used',
                         dest='checkpoint_activations')
+    parser.add_argument("--infer", "-i",
+                        action='store_true',
+                        help='Pass to calculate FLOPs for inference-only workload (no backward pass)')
     return parser
 
 # calculates the flops of a model given its hparams
 def calc_params(args):
-
     assert args.topk <= args.num_experts, "You cannot route to more experts than you have!"
     assert args.num_layers % args.expert_interval == 0, "Require for simplicity that we don't have hanging dense layers"
 
@@ -82,6 +84,10 @@ def calc_params(args):
     iter_factor = 3
     if args.checkpoint_activations:
         iter_factor += 1
+    # If inference-only, no bwd pass or activation ckpting necessary
+    # This assumes simply running a single forward pass ('prefill' stage of decoding) and no subsequent autoregressively generated tokens.
+    if args.infer:
+        iter_factor = 1
 
     qkv_flops = int(iter_factor * 2 * (1 + 2 * args.kv_size_ratio) * args.num_layers * args.tokens * args.hidden_size * args.hidden_size)
     attention_matrix_flops = iter_factor * 2 * args.num_layers * args.tokens * args.sequence_length * args.hidden_size
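For context, iter_factor is the multiplier applied to every forward-pass FLOP term in calc_params: training counts roughly 3x the forward FLOPs (forward plus backward), Megatron-style activation checkpointing adds one extra forward recompute, and the new --infer flag collapses the factor to 1 for a single forward (prefill-only) pass. Below is a minimal standalone sketch of that logic and the two FLOP terms visible in this diff; it is not the script itself, and the hyperparameter values in the usage lines are illustrative only.

# Sketch of the iter_factor logic and the per-component FLOP terms shown in the diff.
def transformer_flops(num_layers, hidden_size, sequence_length, tokens,
                      kv_size_ratio=1.0, checkpoint_activations=False, infer=False):
    # Training counts ~3x the forward FLOPs (fwd + bwd).
    iter_factor = 3
    # Activation checkpointing recomputes the forward pass, adding 1x.
    if checkpoint_activations:
        iter_factor += 1
    # Inference-only: a single forward (prefill) pass, no backward, no recompute.
    if infer:
        iter_factor = 1

    # QKV projection FLOPs and attention-score FLOPs, as in the file above.
    qkv_flops = int(iter_factor * 2 * (1 + 2 * kv_size_ratio)
                    * num_layers * tokens * hidden_size * hidden_size)
    attention_matrix_flops = (iter_factor * 2 * num_layers * tokens
                              * sequence_length * hidden_size)
    return qkv_flops, attention_matrix_flops

# Example (illustrative values): same model, training run vs. one inference prefill.
train_qkv, train_attn = transformer_flops(32, 4096, 2048, tokens=300e9)
infer_qkv, infer_attn = transformer_flops(32, 4096, 2048, tokens=2048, infer=True)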
