Add simple inference FLOP counter to calc_transformer_flops.py #31

Merged: 4 commits merged into main on Feb 20, 2024

Conversation

@haileyschoelkopf (Contributor)

Output without --infer (same as before this PR):

python calc/calc_transformer_flops.py 

Example with Fairseq-MoE 15B: python calc_transformer_flops.py -l 12 -hs 768 --moe -e 512
Example with GPT-3 175B: python calc_transformer_flops.py -l 96 -hs 12288
Namespace(vocab_size=51200, hidden_size=6144, sequence_length=2048, num_layers=44, kv_size_ratio=1.0, moe=False, num_experts=128, expert_interval=2, topk=1, batch_size=1, tokens=300000000000.0, checkpoint_activations=True, infer=False)
Calculating number of FLOPs with training configuration: {'vocab_size': 51200, 'hidden_size': 6144, 'sequence_length': 2048, 'num_layers': 44, 'kv_size_ratio': 1.0, 'moe': False, 'num_experts': 128, 'expert_interval': 2, 'topk': 1, 'batch_size': 1, 'tokens': 300000000000.0, 'checkpoint_activations': True, 'infer': False}

QKV FLOPs: 11.96 ZFLOPs
Attention Matrix FLOPs: 1.33 ZFLOPs
Attention Over Values FLOPs: 1.33 ZFLOPs
Linear Projection FLOPs: 3.99 ZFLOPs
FFN FLOPs: 31.89 ZFLOPs
Embedding FLOPs: 566.23 EFLOPs
Total FLOPs for the Model: 51.06 ZFLOPs

Output with --infer:

python calc/calc_transformer_flops.py --infer

Example with Fairseq-MoE 15B: python calc_transformer_flops.py -l 12 -hs 768 --moe -e 512
Example with GPT-3 175B: python calc_transformer_flops.py -l 96 -hs 12288
Namespace(vocab_size=51200, hidden_size=6144, sequence_length=2048, num_layers=44, kv_size_ratio=1.0, moe=False, num_experts=128, expert_interval=2, topk=1, batch_size=1, tokens=300000000000.0, checkpoint_activations=True, infer=True)
Calculating number of FLOPs with training configuration: {'vocab_size': 51200, 'hidden_size': 6144, 'sequence_length': 2048, 'num_layers': 44, 'kv_size_ratio': 1.0, 'moe': False, 'num_experts': 128, 'expert_interval': 2, 'topk': 1, 'batch_size': 1, 'tokens': 300000000000.0, 'checkpoint_activations': True, 'infer': True}

QKV FLOPs: 2.99 ZFLOPs
Attention Matrix FLOPs: 332.19 EFLOPs
Attention Over Values FLOPs: 332.19 EFLOPs
Linear Projection FLOPs: 996.57 EFLOPs
FFN FLOPs: 7.97 ZFLOPs
Embedding FLOPs: 566.23 EFLOPs
Total FLOPs for the Model: 13.19 ZFLOPs

Inference cuts FLOP counts to 1/3 (if no activation checkpointing) or 1/4 (with activation checkpointing) of the training count, as expected.

This is a very naive way of calculating "true" inference costs, but it may still be useful, especially if one only wants to run a forward pass, e.g. to get perplexity / log-likelihoods / embeddings on a dataset of X tokens.
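
As a sanity check on the numbers above (an editorial sketch, not part of the PR), the QKV line items follow from standard FLOP counting at 2 FLOPs per multiply-accumulate, with training under activation checkpointing counted as 4x a forward pass (forward + recompute + 2x backward) and inference as the forward pass alone:

    # Sketch reproducing the QKV FLOP lines from the printed config above.
    # Convention: 2 FLOPs per multiply-accumulate (MAC).
    hidden_size, num_layers, tokens = 6144, 44, 300e9

    # QKV projections: three hidden x hidden matmuls per layer per token.
    qkv_forward = 2 * 3 * hidden_size**2 * num_layers * tokens
    print(f"inference (forward only): {qkv_forward / 1e21:.2f} ZFLOPs")      # ~2.99
    # Training with activation checkpointing: forward + recompute + 2x backward.
    print(f"training (w/ act. ckpt): {4 * qkv_forward / 1e21:.2f} ZFLOPs")   # ~11.96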

@haileyschoelkopf linked an issue on Feb 19, 2024 that may be closed by this pull request
@Quentin-Anthony (Member)

This doesn't account for kv-caching (see https://kipp.ly/transformer-inference-arithmetic/), yes? We should either add a comment to that effect, or add kv-caching flops.

@haileyschoelkopf (Contributor, Author)

Yep, this is prefill-only with no generated tokens for now; it should match what's used for calculations by https://arxiv.org/pdf/2401.00448.pdf and https://arxiv.org/abs/2304.03208 for inference-FLOP-budget-adjusted scaling laws.

Happy to add KV-caching FLOPs if there's a good UX for it! Not sure whether it would clutter this mostly-training script (we'd need to specify a prefill sequence length and a total number of generated tokens too; if there are no objections to that extra generated-tokens arg, adding it isn't too bad!). Would it make sense to start thinking about expanding to other useful inference napkin-math scripts / helpful tooling in that scenario?
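
For reference (an editorial sketch, not code from this PR), napkin math for decode FLOPs with a KV cache, in the spirit of kipp.ly's transformer-inference-arithmetic, could look like the following; the function and argument names are hypothetical:

    def decode_flops(hidden_size, num_layers, vocab_size, prefill_len, gen_tokens):
        """Rough decode-phase FLOPs for a standard GPT-style transformer.
        Prefill FLOPs for the prompt are not included; they would be counted
        separately, as in this PR's --infer path."""
        h, L = hidden_size, num_layers
        # Weight matmuls per token per layer (2 FLOPs per MAC):
        # QKV (3 h^2) + output projection (h^2) + FFN with 4x expansion (8 h^2)
        # = 12 h^2 params -> 24 h^2 FLOPs, plus the final logit projection.
        weight_flops_per_token = 24 * h**2 * L + 2 * vocab_size * h
        total = 0.0
        for t in range(gen_tokens):
            kv_len = prefill_len + t  # cached context grows by one token per step
            # Attention scores (2*kv_len*h) + attention over values (2*kv_len*h), per layer.
            total += weight_flops_per_token + 4 * kv_len * h * L
        return total

    # e.g. decode_flops(6144, 44, 51200, prefill_len=2048, gen_tokens=128)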

@Quentin-Anthony (Member)

I think this comment clears it up sufficiently. Thanks!

@Quentin-Anthony merged commit 56aeee1 into main on Feb 20, 2024
1 check passed
@stas00 (Collaborator) commented Feb 20, 2024

Have you tried the new torch.utils.flop_counter, which does it automatically and should do the right thing? https://gist.github.com/Chillee/07b36672a0ca2d1280e42b8d10f23174
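
For anyone trying it, a minimal usage sketch (assuming a recent PyTorch, 2.1 or later; the toy model here is illustrative, not taken from the gist):

    import torch
    from torch.utils.flop_counter import FlopCounterMode

    model = torch.nn.Transformer(d_model=512, num_encoder_layers=2, num_decoder_layers=2)
    src = torch.randn(10, 1, 512)  # (seq, batch, d_model)
    tgt = torch.randn(10, 1, 512)

    # Counts FLOPs of whatever runs inside the context; display=True prints
    # a per-module breakdown on exit.
    with FlopCounterMode(display=True) as flop_counter:
        model(src, tgt)
    print(f"total forward FLOPs: {flop_counter.get_total_flops():,}")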

@haileyschoelkopf (Contributor, Author)

Haven't tried it, but I've been meaning to! This might make sense if #1 were addressed, although I think we wouldn't want to require a forward pass to compute FLOPs for very large models.

@stas00 (Collaborator) commented Feb 20, 2024

Right, from that perspective, yes, the estimate is better.

But the estimated FLOPs could be quite off when torch.compile or fusion is used, no?

Successfully merging this pull request may close these issues: Inference FLOPs