[not for land yet] example of float8 with rowwise scaling
Summary:

This is an example of how to call float8 training with rowwise scaling
from torchao.

TODO: finalize the API in torchao, decide how we want to expose it in
torchtitan, and optimize performance.

```
// baseline (bf16 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.compile
...
step: 20  loss:  8.4931  memory: 47.65GiB(50.16%)  tps: 5,760  mfu: 33.73%

// experiment (rowwise float8 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile
...
// torchao main branch
step: 40  loss:  7.3818  memory: 66.81GiB(70.33%)  tps: 6,412  mfu: 37.55%
// torchao with pytorch/ao#1629
step: 20  loss:  8.3823  memory: 58.55GiB(61.63%)  tps: 6,424  mfu: 37.62%

// for comparison, tensorwise float8 with float8 all-gather (on main branch)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
...
step: 20  loss:  8.4258  memory: 47.32GiB(49.81%)  tps: 7,186  mfu: 42.08%

```
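
For reference, here is a minimal sketch of how the rowwise (axiswise) recipe could be applied to a plain model via torchao. Note that `Float8LinearRecipeName` and `recipe_name_to_linear_config` are still private torchao APIs at the time of this commit (see the comment in the diff below), so the import path may change; the toy model and sizes are illustrative only.

```
# sketch only: assumes torchao's private recipe-lookup API referenced in this diff
import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training
from torchao.float8.config import (
    Float8LinearRecipeName,
    recipe_name_to_linear_config,
)

# toy model; float8 rowwise scaling targets the nn.Linear layers
model = nn.Sequential(
    nn.Linear(4096, 4096),
    nn.ReLU(),
    nn.Linear(4096, 4096),
).cuda().to(torch.bfloat16)

# build a rowwise (axiswise) scaling config from the recipe name
config = recipe_name_to_linear_config(Float8LinearRecipeName("all_axiswise"))

# swap nn.Linear modules for Float8Linear in place
convert_to_float8_training(model, config=config)

# compile is expected for performance, matching the commands above
model = torch.compile(model)
```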

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo committed Feb 7, 2025
1 parent 49c6d6f commit b08786a
Showing 2 changed files with 32 additions and 16 deletions.
6 changes: 6 additions & 0 deletions torchtitan/config_manager.py
@@ -577,6 +577,12 @@ def __init__(self):
action="store_true",
help="Whether precompute float8 scales dynamically for FSDP",
)
self.parser.add_argument(
"--float8.recipe_name",
type=str,
default=None,
help="If specified, creates float8 config from recipe name, choices are `all_axiswise` and `lw_axiswise_with_gw_hp`",
)

# communications library settings
self.parser.add_argument(
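Assuming the flag is wired as defined above, a rowwise run would presumably be launched as follows (hypothetical invocation, mirroring the commands in the summary; the experiment command shown there does not pass the flag explicitly):

```
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.recipe_name all_axiswise --training.compile
```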
42 changes: 26 additions & 16 deletions torchtitan/float8.py
@@ -42,29 +42,39 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
return
try:
from torchao.float8 import Float8LinearConfig
# we should update this code after torchao exposes this publicly
from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config
except ImportError as e:
raise ImportError(
"torchao is not installed. Please install it to use float8 linear layers."
) from e

# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
enable_fsdp_float8_all_gather = (
parallel_dims.dp_shard_enabled
and float8_config.enable_fsdp_float8_all_gather
)
self.config = Float8LinearConfig(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
)

self.enabled = True

# for precompute_float8_dynamic_scale_for_fsdp
self.precompute_scale = (
enable_fsdp_float8_all_gather
and float8_config.precompute_float8_dynamic_scale_for_fsdp
)

logger.info("Float8 training active")
if float8_config.recipe_name is not None:
# TODO(future PR): the recipe lookup by name is currently a private API; we'll need
# to expose it publicly in torchao before a PR similar to this one can be
# landed in torchtitan
recipe = Float8LinearRecipeName(float8_config.recipe_name)
self.config = recipe_name_to_linear_config(recipe)
self.precompute_scale = False
logger.info(f"Float8 training active with recipe {recipe}")

else:
# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
enable_fsdp_float8_all_gather = (
parallel_dims.dp_shard_enabled
and float8_config.enable_fsdp_float8_all_gather
)
self.config = Float8LinearConfig(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
)
# for precompute_float8_dynamic_scale_for_fsdp
self.precompute_scale = (
enable_fsdp_float8_all_gather
and float8_config.precompute_float8_dynamic_scale_for_fsdp
)
logger.info("Float8 tensorwise scaled training active")

def convert_to_float8_training(self, model: nn.Module):
"""
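The body of `convert_to_float8_training` is truncated in this diff. As a hedged sketch (not the exact torchtitan implementation), it would typically hand the stored config to torchao's converter:

```
def convert_to_float8_training(self, model: nn.Module):
    # sketch: swap nn.Linear modules for Float8Linear using the config built in __init__
    if not self.enabled:
        return
    from torchao.float8 import convert_to_float8_training as _convert
    _convert(model, config=self.config)
    logger.info("Swapped nn.Linear layers to Float8Linear")
```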
