From 7d349b5e677a3b56aaf22bcf3b5b6ad638798dd8 Mon Sep 17 00:00:00 2001 From: haileyschoelkopf Date: Mon, 19 Feb 2024 10:22:18 -0500 Subject: [PATCH 1/4] add --infer arg to flops calculator --- calc/calc_transformer_flops.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/calc/calc_transformer_flops.py b/calc/calc_transformer_flops.py index 64d550d..a5ecd98 100644 --- a/calc/calc_transformer_flops.py +++ b/calc/calc_transformer_flops.py @@ -63,11 +63,14 @@ def config_parser(): action='store_false', help='Whether Megatron-style activation checkpointing is being used', dest='checkpoint_activations') + parser.add_argument("--infer", "-i", + action='store_true', + help='Pass to calculate FLOPs for inference-only workload (no backward pass)') return parser # calculates the flops of a model given its hparams def calc_params(args): - + print(args) assert args.topk <= args.num_experts, "You cannot route to more experts than you have!" assert args.num_layers % args.expert_interval == 0, "Require for simplicity that we don't have hanging dense layers" @@ -79,6 +82,9 @@ def calc_params(args): iter_factor = 3 if args.checkpoint_activations: iter_factor += 1 + # If inference-only, no bwd pass or activation ckpting necessary + if args.infer: + iter_factor = 1 qkv_flops = int(iter_factor * 2 * (1 + 2 * args.kv_size_ratio) * args.num_layers * args.tokens * args.hidden_size * args.hidden_size) attention_matrix_flops = iter_factor * 2 * args.num_layers * args.tokens * args.sequence_length * args.hidden_size From 47eae1aa2f4891a2fae2146576826dfb4712856b Mon Sep 17 00:00:00 2001 From: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com> Date: Mon, 19 Feb 2024 20:04:11 -0500 Subject: [PATCH 2/4] add comment --- calc/calc_transformer_flops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/calc/calc_transformer_flops.py b/calc/calc_transformer_flops.py index a5ecd98..b3c711b 100644 --- a/calc/calc_transformer_flops.py +++ b/calc/calc_transformer_flops.py @@ -70,7 +70,6 @@ def config_parser(): # calculates the flops of a model given its hparams def calc_params(args): - print(args) assert args.topk <= args.num_experts, "You cannot route to more experts than you have!" assert args.num_layers % args.expert_interval == 0, "Require for simplicity that we don't have hanging dense layers" @@ -83,6 +82,8 @@ def calc_params(args): if args.checkpoint_activations: iter_factor += 1 # If inference-only, no bwd pass or activation ckpting necessary + # This assumes simply running a single forward pass ('prefill' stage of decoding) and no generated tokens. + # Or, if using a KV cache, this flop count will also be accurate. if args.infer: iter_factor = 1 From c20995a9ba817948c50fa28323baa9d2417486d8 Mon Sep 17 00:00:00 2001 From: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com> Date: Mon, 19 Feb 2024 20:10:23 -0500 Subject: [PATCH 3/4] fix comment --- calc/calc_transformer_flops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/calc/calc_transformer_flops.py b/calc/calc_transformer_flops.py index b3c711b..0f63a38 100644 --- a/calc/calc_transformer_flops.py +++ b/calc/calc_transformer_flops.py @@ -83,7 +83,6 @@ def calc_params(args): iter_factor += 1 # If inference-only, no bwd pass or activation ckpting necessary # This assumes simply running a single forward pass ('prefill' stage of decoding) and no generated tokens. - # Or, if using a KV cache, this flop count will also be accurate. if args.infer: iter_factor = 1 From 05c40662afa6dde5b78e09e9db4cd9df730b50c7 Mon Sep 17 00:00:00 2001 From: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com> Date: Mon, 19 Feb 2024 20:12:44 -0500 Subject: [PATCH 4/4] Update calc_transformer_flops.py --- calc/calc_transformer_flops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/calc/calc_transformer_flops.py b/calc/calc_transformer_flops.py index 0f63a38..c61ef02 100644 --- a/calc/calc_transformer_flops.py +++ b/calc/calc_transformer_flops.py @@ -82,7 +82,7 @@ def calc_params(args): if args.checkpoint_activations: iter_factor += 1 # If inference-only, no bwd pass or activation ckpting necessary - # This assumes simply running a single forward pass ('prefill' stage of decoding) and no generated tokens. + # This assumes simply running a single forward pass ('prefill' stage of decoding) and no subsequent autoregressively generated tokens. if args.infer: iter_factor = 1