diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 3cc630c2..b8e94fd9 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -577,6 +577,12 @@ def __init__(self):
             action="store_true",
             help="Whether precompute float8 scales dynamically for FSDP",
         )
+        self.parser.add_argument(
+            "--float8.recipe_name",
+            type=str,
+            default=None,
+            help="If specified, creates float8 config from recipe name, choices are `all_axiswise` and `lw_axiswise_with_gw_hp`",
+        )
 
         # communications library settings
         self.parser.add_argument(
diff --git a/torchtitan/float8.py b/torchtitan/float8.py
index 849ac378..1daf6dd8 100644
--- a/torchtitan/float8.py
+++ b/torchtitan/float8.py
@@ -42,29 +42,39 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             return
         try:
             from torchao.float8 import Float8LinearConfig
+            # we should update this code after torchao exposes this publicly
+            from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config
         except ImportError as e:
             raise ImportError(
                 "torchao is not installed. Please install it to use float8 linear layers."
             ) from e
 
-        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
-        enable_fsdp_float8_all_gather = (
-            parallel_dims.dp_shard_enabled
-            and float8_config.enable_fsdp_float8_all_gather
-        )
-        self.config = Float8LinearConfig(
-            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
-        )
-
-        self.enabled = True
-
-        # for precompute_float8_dynamic_scale_for_fsdp
-        self.precompute_scale = (
-            enable_fsdp_float8_all_gather
-            and float8_config.precompute_float8_dynamic_scale_for_fsdp
-        )
-
-        logger.info("Float8 training active")
+        if float8_config.recipe_name is not None:
+            # TODO(future PR): the recipe lookup by name is currently a private API, we'll need
+            # to expose it publicly in torchao before a PR similar to this one can be
+            # landed in torchtitan
+            recipe = Float8LinearRecipeName(float8_config.recipe_name)
+            self.config = recipe_name_to_linear_config(recipe)
+            self.precompute_scale = False
+            logger.info(f"Float8 training active with recipe {recipe}")
+
+        else:
+            # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+            enable_fsdp_float8_all_gather = (
+                parallel_dims.dp_shard_enabled
+                and float8_config.enable_fsdp_float8_all_gather
+            )
+            self.config = Float8LinearConfig(
+                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+            )
+            # for precompute_float8_dynamic_scale_for_fsdp
+            self.precompute_scale = (
+                enable_fsdp_float8_all_gather
+                and float8_config.precompute_float8_dynamic_scale_for_fsdp
+            )
+            logger.info("Float8 tensorwise scaled training active")
+
+        self.enabled = True
 
     def convert_to_float8_training(self, model: nn.Module):
         """
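
For reference, the recipe path added to Float8Handler.__init__ above boils down to the standalone sketch below. It only uses names that appear in the diff; as the TODO notes, the recipe helpers are still private torchao APIs, so the import path and signatures may change before this can land.

    # Minimal sketch of the recipe-based config construction, assuming a torchao
    # build that still exposes the private helpers imported in the diff above.
    from torchao.float8.config import (
        Float8LinearRecipeName,
        recipe_name_to_linear_config,
    )

    # "--float8.recipe_name all_axiswise" resolves to this enum member; the other
    # value accepted by the new flag is "lw_axiswise_with_gw_hp".
    recipe = Float8LinearRecipeName("all_axiswise")
    config = recipe_name_to_linear_config(recipe)

    # `config` is the Float8LinearConfig consumed by convert_to_float8_training,
    # the same way as the tensorwise config built in the else branch.

When --float8.recipe_name is omitted (the default of None), the handler keeps the existing tensorwise Float8LinearConfig path, including the FSDP all-gather and scale-precompute options.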