30 changes: 16 additions & 14 deletions tests/pytorch/distributed/run_fsdp2_model.py
@@ -18,6 +18,7 @@
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from transformer_engine.pytorch import QuantizedTensor
from contextlib import nullcontext


@@ -36,8 +37,13 @@ def forward(self, x):
def save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
if isinstance(param, QuantizedTensor):
# Ignore FP8 metadata attributes
ignore_keys = ["_" + k for k in param.get_metadata().keys()]
else:
ignore_keys = []
attrs = vars(param)
custom_attrs[name] = {k: v for k, v in attrs.items()}
custom_attrs[name] = {k: v for k, v in attrs.items() if k not in ignore_keys}
return custom_attrs
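
For context, the attributes saved here are reattached after FSDP2 shards the parameters. A minimal sketch of that counterpart step (a `restore_custom_attrs` helper is assumed; it is not shown in this diff):

```python
def restore_custom_attrs(module, custom_attrs):
    # Reattach the saved attributes to the (now sharded) parameters.
    for name, param in module.named_parameters():
        if name in custom_attrs:
            for attr_name, attr_val in custom_attrs[name].items():
                setattr(param, attr_name, attr_val)
```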


@@ -104,24 +110,20 @@ def _train(args):
# FP8 Configuration
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")

build_model_context_args = {}
if not args.fp8_init:
# Build model context (FP8 init)
build_model_context = nullcontext
build_model_context_args = {}

else:
from transformer_engine.pytorch import fp8_model_init

build_model_context = fp8_model_init
build_model_context_args["enabled"] = True

# Build the model with the specified context
with build_model_context(**build_model_context_args):
model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
else:
model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
build_model_context_args["recipe"] = fp8_recipe
# Move the model to the correct device

# Build the model with the specified context
with build_model_context(**build_model_context_args):
model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
model.to(device)

if LOCAL_RANK == 0:
@@ -146,7 +148,6 @@ def _train(args):
)
else:
assert False

# Apply FSDP/HSDP
custom_attrs = save_custom_attrs(model)
for sub_module in model.modules():
@@ -163,13 +164,14 @@
# Zero the parameter gradients
optimizer.zero_grad()
input_data = torch.randn(args.batch_size, args.input_size).to(device)
output = model(input_data)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
output = model(input_data)
target = torch.randn(args.batch_size, args.output_size).to(device)
loss = F.mse_loss(output, target)
loss.backward()
optimizer.step()
if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")
print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed with loss {loss.item()}")

dist.destroy_process_group()
if LOCAL_RANK == 0:
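Taken together, the updated test builds the model under `fp8_model_init` (when `args.fp8_init` is set), shards it with FSDP2, and runs the forward pass under `fp8_autocast`. A condensed sketch of that flow, assuming `SimpleNet` and the sizes from the script above (the per-module sharding condition is also an assumption here):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format
from torch.distributed._composable.fsdp import fully_shard

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")

# Allocate parameters directly as quantized FP8 tensors.
with te.fp8_model_init(enabled=True, recipe=fp8_recipe):
    model = SimpleNet(input_size, hidden_size, output_size)
model.to("cuda")

# Shard TE sub-modules first, then the root module, with FSDP2.
for sub_module in model.modules():
    if isinstance(sub_module, te.Linear):
        fully_shard(sub_module)
fully_shard(model)

# The forward pass now runs inside fp8_autocast with the same recipe.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output = model(torch.randn(batch_size, input_size, device="cuda"))
```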
71 changes: 68 additions & 3 deletions transformer_engine/pytorch/tensor/float8_tensor.py
@@ -4,7 +4,8 @@

"""Tensor class with FP8 data"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable, Union
import os
from typing import Any, Optional, Tuple, Iterable, Union
import warnings

import torch
@@ -538,7 +539,6 @@ def remove_caches(self) -> None:

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):

# View op
if func == aten.view.default:
tensor = args[0]
@@ -603,7 +603,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
[data] + list(args[1:]),
kwargs,
)
# This must be a deep copy, since this is not a view op
return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape)

if func == torch.ops.aten.as_strided.default:
tensor = args[0]
data = tensor._data
@@ -635,9 +637,72 @@
)
else:
pass

return super().__torch_dispatch__(func, types, args, kwargs)

def fsdp_pre_all_gather(self, mesh):
"""Functions FSDP2 calls before all-gather of the
weights for both forward and backward passes.
Args:
mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2
to shard the weights.

Returns:
shareded_tensors: Tuple[torch.Tensor, ...]: Tuple of tensors
that need to be all-gathered.(In this case uint8 data tensor)
metadata: Tuple[Any]: Metadata needed for reconstructing the
Float8Tensor after all-gather.
"""
quantizer = self._quantizer
if isinstance(quantizer, Float8CurrentScalingQuantizer) and mesh is not None:
# When the sharded weights are updated after FSDP2 reduce-scatters the
# gradients, we need an amax reduction across the mesh so that every weight
# shard is updated with the same scale inverse. Setting the state below on the
# quantizer ensures the updated quantized weight tensor has the same scale inverse across all shards.
quantizer.amax_reduction_group = mesh.get_group()
quantizer.with_amax_reduction = True
sharded_tensors = (self._data,)
metadata = (self._scale_inv, self._fp8_dtype, self.dtype, quantizer)
return sharded_tensors, metadata

def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[torch.Tensor] = None,
):
"""Functions FSDP2 calls after all-gather of the
weights for both forward and backward passes.
Args:
all_gather_outputs (Tuple[torch.Tensor, ...]): sharded_tensors sent out in fsdp_pre_all_gather from each rank
are all-gathered and received here as a tuple.
metadata (Any): metadata sent out in fsdp_pre_all_gather used for reconstructing the Float8Tensor.
param_dtype (torch.dtype):
out (Optional[torch.Tensor], optional): _description_. Defaults to None.

Returns:
Tuple[Float8Tensor, Tuple[torch.Tensor, ...]]: Allgathered Float8Tensor and tuple of internal tensors
used by the Float8Tensor that was actually allgathered
"""
if out is not None:
# The Float8Tensor object returned by fsdp_post_all_gather is reused
# across calls, so there is no need to create a new one.
# In torchao's Float8Tensor implementation, this branch sets the scale
# inverse, since it is only known after the all-gather. Here that is
# handled during quantization itself by passing the amax reduction
# group to the quantizer.
return
(data,) = all_gather_outputs
(fp8_scale_inv, fp8_dtype, fake_dtype, quantizer) = metadata
return Float8Tensor(
data=data,
fp8_scale_inv=fp8_scale_inv,
fp8_dtype=fp8_dtype,
shape=data.shape,
dtype=fake_dtype,
quantizer=quantizer,
), (data,)
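
Conceptually, FSDP2 drives the two hooks above as follows; a simplified, single-parameter illustration of the contract (`all_gather_bytes` stands in for FSDP2's internal all-gather and is hypothetical):

```python
import torch

def illustrate_fsdp2_allgather(fp8_param, mesh, all_gather_bytes):
    # 1) Ask the tensor what to all-gather and how to rebuild itself.
    sharded_tensors, metadata = fp8_param.fsdp_pre_all_gather(mesh)
    # 2) FSDP2 all-gathers the raw uint8 shards across the mesh.
    gathered = tuple(all_gather_bytes(t, mesh) for t in sharded_tensors)
    # 3) Rebuild the full Float8Tensor from gathered data plus metadata.
    full_param, inner_tensors = fp8_param.fsdp_post_all_gather(
        gathered, metadata, param_dtype=torch.bfloat16
    )
    return full_param
```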

@classmethod
def _make_in_reduce_ex(
cls,
4 changes: 4 additions & 0 deletions transformer_engine/pytorch/tensor/quantized_tensor.py
@@ -495,6 +495,10 @@ def maybe_update_inplace(arg, new_arg, schema_arg):
and schema_arg.alias_info.is_write
):
arg.quantize_(new_arg)
elif isinstance(arg, list) and isinstance(new_arg, list):
# Recursively handle update for lists of tensors
for a, na in zip(arg, new_arg):
maybe_update_inplace(a, na, schema_arg)

# In-place op: dequantize, perform op, and quantize
if func._schema.is_mutable:
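The new list branch matters for in-place ops whose schema takes tensor lists, e.g. the `aten._foreach_*` family used by foreach/fused optimizers. A minimal sketch of the same recursion in isolation (helper names hypothetical):

```python
def update_inplace_nested(arg, new_arg, update_leaf):
    # Descend into parallel lists and apply the leaf update to each pair,
    # mirroring maybe_update_inplace's handling of tensor-list arguments.
    if isinstance(arg, list) and isinstance(new_arg, list):
        for a, na in zip(arg, new_arg):
            update_inplace_nested(a, na, update_leaf)
    else:
        update_leaf(arg, new_arg)
```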