From b08786a7263f8a9f1292ea21f84707b55152d118 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Mon, 27 Jan 2025 13:22:47 -0800
Subject: [PATCH] [not for land yet] example of float8 with rowwise scaling

Summary:

This is an example of how to call float8 training with rowwise scaling
from torchao.

TODO: finalize the API in torchao, decide how we want to expose it in
torchtitan, and optimize performance.

```
// baseline (bf16 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.compile
...
step: 20 loss: 8.4931 memory: 47.65GiB(50.16%) tps: 5,760 mfu: 33.73%

// experiment (rowwise float8 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile
...
// torchao main branch
step: 40 loss: 7.3818 memory: 66.81GiB(70.33%) tps: 6,412 mfu: 37.55%
// torchao with https://github.com/pytorch/ao/pull/1629
step: 20 loss: 8.3823 memory: 58.55GiB(61.63%) tps: 6,424 mfu: 37.62%

// for comparison, tensorwise float8 with float8 all-gather (on main branch)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
...
step: 20 loss: 8.4258 memory: 47.32GiB(49.81%) tps: 7,186 mfu: 42.08%
```

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchtitan/config_manager.py |  6 ++++++
 torchtitan/float8.py         | 42 ++++++++++++++++++++++++++++----------------
 2 files changed, 32 insertions(+), 16 deletions(-)

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 3cc630c2e..b8e94fd91 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -577,6 +577,12 @@ def __init__(self):
             action="store_true",
             help="Whether precompute float8 scales dynamically for FSDP",
         )
+        self.parser.add_argument(
+            "--float8.recipe_name",
+            type=str,
+            default=None,
+            help="If specified, creates float8 config from recipe name, choices are `all_axiswise` and `lw_axiswise_with_gw_hp`",
+        )
 
         # communications library settings
         self.parser.add_argument(
diff --git a/torchtitan/float8.py b/torchtitan/float8.py
index 849ac378f..1daf6dd8a 100644
--- a/torchtitan/float8.py
+++ b/torchtitan/float8.py
@@ -42,29 +42,39 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             return
         try:
             from torchao.float8 import Float8LinearConfig
+            # we should update this code after torchao exposes this publicly
+            from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config
         except ImportError as e:
             raise ImportError(
                 "torchao is not installed. Please install it to use float8 linear layers."
             ) from e
 
-        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
-        enable_fsdp_float8_all_gather = (
-            parallel_dims.dp_shard_enabled
-            and float8_config.enable_fsdp_float8_all_gather
-        )
-        self.config = Float8LinearConfig(
-            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
-        )
-
         self.enabled = True
 
-        # for precompute_float8_dynamic_scale_for_fsdp
-        self.precompute_scale = (
-            enable_fsdp_float8_all_gather
-            and float8_config.precompute_float8_dynamic_scale_for_fsdp
-        )
-
-        logger.info("Float8 training active")
+        if float8_config.recipe_name is not None:
+            # TODO(future PR): the recipe lookup by name is currently a private API, we'll need
+            # to expose it publicly in torchao before a PR similar to this one can be
+            # landed in torchtitan
+            recipe = Float8LinearRecipeName(float8_config.recipe_name)
+            self.config = recipe_name_to_linear_config(recipe)
+            self.precompute_scale = False
+            logger.info(f"Float8 training active with recipe {recipe}")
+
+        else:
+            # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+            enable_fsdp_float8_all_gather = (
+                parallel_dims.dp_shard_enabled
+                and float8_config.enable_fsdp_float8_all_gather
+            )
+            self.config = Float8LinearConfig(
+                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+            )
+            # for precompute_float8_dynamic_scale_for_fsdp
+            self.precompute_scale = (
+                enable_fsdp_float8_all_gather
+                and float8_config.precompute_float8_dynamic_scale_for_fsdp
+            )
+            logger.info("Float8 tensorwise scaled training active")
 
     def convert_to_float8_training(self, model: nn.Module):
         """
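
Note (not part of the patch): for readers who want to try the recipe path outside of torchtitan, the sketch below shows roughly what the `recipe_name` branch above does when applied to a plain model. It assumes the private torchao helpers keep the names used in this diff (`Float8LinearRecipeName`, `recipe_name_to_linear_config`) and that the public `torchao.float8.convert_to_float8_training` entry point is available; the toy `model` is a placeholder.

```python
import torch.nn as nn

# public torchao entry point for swapping nn.Linear -> Float8Linear
from torchao.float8 import convert_to_float8_training

# private helpers used by this patch; names may change once torchao
# exposes a public recipe API
from torchao.float8.config import (
    Float8LinearRecipeName,
    recipe_name_to_linear_config,
)

# toy stand-in for a real transformer; any module containing nn.Linear works
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

# look up the rowwise (axiswise) scaling config by recipe name, mirroring
# what the new --float8.recipe_name flag does
config = recipe_name_to_linear_config(Float8LinearRecipeName("all_axiswise"))

# mutate the model in place, replacing nn.Linear instances with Float8Linear
convert_to_float8_training(model, config=config)
```

In torchtitan itself, the same path is exercised via `--float8.enable_float8_linear --float8.recipe_name all_axiswise --training.compile`.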