[not for land yet] example of float8 with rowwise scaling
Summary: This is an example of how to call float8 training with rowwise scaling from torchao. TODO: finalize the API in torchao, decide how we want to expose it in torchtitan, and optimize performance.

```
// baseline (bf16 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.compile
...
step: 20  loss:  8.4931  memory: 47.65GiB(50.16%)  tps: 5,760  mfu: 33.73%

// experiment (rowwise float8 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile
...
// torchao main branch
step: 40  loss:  7.3818  memory: 66.81GiB(70.33%)  tps: 6,412  mfu: 37.55%
// torchao with pytorch/ao#1629
step: 20  loss:  8.3823  memory: 58.55GiB(61.63%)  tps: 6,424  mfu: 37.62%

// for comparison, tensorwise float8 with float8 all-gather (on main branch)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
...
step: 20  loss:  8.4258  memory: 47.32GiB(49.81%)  tps: 7,186  mfu: 42.08%
```

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
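For context, a rough numpy sketch of the rowwise-scaling idea: instead of one scale for the whole tensor, each row gets its own scale derived from that row's absolute maximum, so an outlier in one row does not crush the precision of the others. The constant and helper name below are illustrative only, not the torchao implementation.

```python
import numpy as np

F8_E4M3_MAX = 448.0  # max representable magnitude in float8 e4m3

def rowwise_scales(x: np.ndarray) -> np.ndarray:
    # one scale per row, computed from that row's absolute maximum
    amax = np.abs(x).max(axis=-1, keepdims=True)
    return F8_E4M3_MAX / np.clip(amax, 1e-12, None)

x = np.array([[0.5, -2.0, 1.0],
              [100.0, -300.0, 50.0]])
s = rowwise_scales(x)
x_scaled = x * s  # each row independently spans up to +/-448

print(s.shape)                        # (2, 1): one scale per row
print(np.abs(x_scaled).max(axis=-1))  # [448. 448.]
```

With tensorwise scaling, the single scale would be 448/300, and the first row would be squeezed into a tiny fraction of the float8 range; rowwise scaling keeps each row's values near full range at the cost of storing one scale per row.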