File tree Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Original file line number Diff line number Diff line change @@ -66,11 +66,13 @@ def config_parser():
66
66
action = 'store_false' ,
67
67
help = 'Whether Megatron-style activation checkpointing is being used' ,
68
68
dest = 'checkpoint_activations' )
69
+ parser .add_argument ("--infer" , "-i" ,
70
+ action = 'store_true' ,
71
+ help = 'Pass to calculate FLOPs for inference-only workload (no backward pass)' )
69
72
return parser
70
73
71
74
# calculates the flops of a model given its hparams
72
75
def calc_params (args ):
73
-
74
76
assert args .topk <= args .num_experts , "You cannot route to more experts than you have!"
75
77
assert args .num_layers % args .expert_interval == 0 , "Require for simplicity that we don't have hanging dense layers"
76
78
@@ -82,6 +84,10 @@ def calc_params(args):
82
84
iter_factor = 3
83
85
if args .checkpoint_activations :
84
86
iter_factor += 1
87
+ # If inference-only, no bwd pass or activation ckpting necessary
88
+ # This assumes simply running a single forward pass ('prefill' stage of decoding) and no subsequent autoregressively generated tokens.
89
+ if args .infer :
90
+ iter_factor = 1
85
91
86
92
qkv_flops = int (iter_factor * 2 * (1 + 2 * args .kv_size_ratio ) * args .num_layers * args .tokens * args .hidden_size * args .hidden_size )
87
93
attention_matrix_flops = iter_factor * 2 * args .num_layers * args .tokens * args .sequence_length * args .hidden_size
You can’t perform that action at this time.
0 commit comments