diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py
index e32f64cf1c..7b8e1a86dd 100644
--- a/tests/pytorch/distributed/run_fsdp2_model.py
+++ b/tests/pytorch/distributed/run_fsdp2_model.py
@@ -18,6 +18,7 @@ from torch.distributed import DeviceMesh
 from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed.device_mesh import init_device_mesh
+from transformer_engine.pytorch import QuantizedTensor
 
 from contextlib import nullcontext
@@ -36,8 +37,14 @@ def forward(self, x):
 
 
 def save_custom_attrs(module):
     custom_attrs = {}
     for name, param in module.named_parameters():
+        if isinstance(param, QuantizedTensor):
+            # Ignore FP8 metadata attributes. Otherwise we will save duplicate copies
+            # for data/transpose FP8 tensors on top of FP8 tensors that FSDP2 will save.
+            ignore_keys = [key for key in param.__dict__.keys() if key.startswith("_")]
+        else:
+            ignore_keys = []
         attrs = vars(param)
-        custom_attrs[name] = {k: v for k, v in attrs.items()}
+        custom_attrs[name] = {k: v for k, v in attrs.items() if k not in ignore_keys}
     return custom_attrs
@@ -104,24 +111,20 @@ def _train(args):
     # FP8 Configuration
     fp8_format = Format.HYBRID
     fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
-
+    build_model_context_args = {}
     if not args.fp8_init:
         # Build model context (FP8 init)
         build_model_context = nullcontext
-        build_model_context_args = {}
-
+    else:
         from transformer_engine.pytorch import fp8_model_init
         build_model_context = fp8_model_init
         build_model_context_args["enabled"] = True
-
-        # Build the model with the specified context
-        with build_model_context(**build_model_context_args):
-            model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
-    else:
-        model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
+        build_model_context_args["recipe"] = fp8_recipe
 
     # Move the model to the correct device
-
+    # Build the model with the specified context
+    with build_model_context(**build_model_context_args):
+        model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
     model.to(device)
 
     if LOCAL_RANK == 0:
@@ -146,7 +149,6 @@ def _train(args):
         )
     else:
         assert False
-
     # Apply FSDP/HSDP
     custom_attrs = save_custom_attrs(model)
     for sub_module in model.modules():
@@ -163,13 +165,14 @@ def _train(args):
         # Zero the parameter gradients
         optimizer.zero_grad()
         input_data = torch.randn(args.batch_size, args.input_size).to(device)
-        output = model(input_data)
+        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+            output = model(input_data)
         target = torch.randn(args.batch_size, args.output_size).to(device)
         loss = F.mse_loss(output, target)
         loss.backward()
         optimizer.step()
         if LOCAL_RANK == 0:
-            print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")
+            print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed with loss {loss.item()}")
 
     dist.destroy_process_group()
     if LOCAL_RANK == 0:
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index a4e68e53b0..4651b55d86 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -4,7 +4,8 @@
 
 """Tensor class with FP8 data"""
 from __future__ import annotations
-from typing import Optional, Tuple, Iterable, Union
+import os
+from typing import Any, Optional, Tuple, Iterable, Union
 import warnings
 
 import torch
@@ -537,8 +538,34 @@ def remove_caches(self) -> None:
         self._transpose = None
 
     @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs=None):
+    def make_like(
+        cls,
+        tensor: QuantizedTensor,
+        *,
+        shape: Optional[Iterable[int]] = None,
+        dtype: Optional[torch.dtype] = None,
+        requires_grad: bool = False,
+        data: Optional[torch.Tensor] = None,
+        transpose: Optional[torch.Tensor] = None,
+    ) -> QuantizedTensor:
+        """Create new quantized tensor
+
+        By default, new tensor has the same attributes and underlying
+        data.
+
+        """
+        new_tensor = super().make_like(
+            tensor, shape=shape, dtype=dtype, requires_grad=requires_grad
+        )
+        if data is not None:
+            new_tensor._data = data
+        if transpose is not None and not tensor._transpose_invalid:
+            new_tensor._transpose = transpose
+            new_tensor._transpose_invalid = False
+        return new_tensor
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs=None):
         # View op
         if func == aten.view.default:
             tensor = args[0]
@@ -590,11 +617,37 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
                 [data] + list(args[1:]),
                 kwargs,
             )
-            return [
-                Float8Tensor.make_like(tensor, data=split_tensor, shape=split_tensor.shape)
-                for split_tensor in func_out
+            t_func_out = [None] * len(func_out)
+            # Compute the corresponding split of the transpose cache if available
+            if tensor._transpose is not None and not tensor._transpose_invalid:
+                transpose = tensor._transpose
+                nd = data.dim()
+                # Figure out the original split dim
+                if "dim" in kwargs:
+                    dim_to_split = kwargs["dim"]
+                else:
+                    dim_to_split = args[2] if len(args) > 2 else 0
+                # The transpose dim is reversed
+                t_dim = nd - 1 - dim_to_split
+                t_func_out = transpose.__torch_dispatch__(
+                    func,
+                    types,
+                    [transpose, args[1], t_dim],
+                    kwargs,
+                )
+            outs = [
+                Float8Tensor.make_like(
+                    tensor,
+                    data=split_tensor,
+                    transpose=split_transpose_tensor,
+                    shape=split_tensor.shape,
+                )
+                for split_tensor, split_transpose_tensor in zip(func_out, t_func_out)
             ]
+            return outs
+
         if func == aten.new_zeros.default:
+            # Create a fresh tensor filled with zeros.
            tensor = args[0]
             data = tensor._data
             func_out = data.__torch_dispatch__(
@@ -603,17 +656,50 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
                 [data] + list(args[1:]),
                 kwargs,
             )
-            return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape)
+            # Deep copy the scale-inverse tensor and the quantizer as well.
+            scale_inv = tensor._scale_inv.detach().clone()
+            quantizer = tensor._quantizer.copy()
+            out_tensor = Float8Tensor(
+                data=func_out,
+                shape=data.shape,
+                dtype=tensor.dtype,
+                fp8_dtype=tensor._fp8_dtype,
+                fp8_scale_inv=scale_inv,
+                quantizer=quantizer,
+            )
+            return out_tensor
+
         if func == torch.ops.aten.as_strided.default:
             tensor = args[0]
             data = tensor._data
+            # Apply as_strided to the primary uint8 data
             func_out = data.__torch_dispatch__(
                 func,
                 types,
                 [data] + list(args[1:]),
                 kwargs,
             )
-            return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape)
+            func_transposed_out = None
+            if tensor._transpose is not None and not tensor._transpose_invalid:
+                transpose = tensor._transpose
+                size = args[1]
+                stride = args[2]
+                if "storage_offset" in kwargs:
+                    storage_offset = kwargs["storage_offset"]
+                else:
+                    storage_offset = args[3] if len(args) > 3 else 0
+                t_size = list(reversed(size)) if len(size) > 0 else size
+                t_stride = list(reversed(stride)) if len(stride) > 0 else stride
+                func_transposed_out = transpose.__torch_dispatch__(
+                    func,
+                    types,
+                    [transpose, t_size, t_stride, storage_offset] + list(args[4:]),
+                    kwargs,
+                )
+            return Float8Tensor.make_like(
+                tensor, data=func_out, transpose=func_transposed_out, shape=func_out.shape
+            )
+
         if func == torch.ops.aten.detach.default:
             return cls.detach(args[0])
         if func == torch.ops.aten.clone.default:
@@ -635,9 +721,74 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
             )
         else:
             pass
-
         return super().__torch_dispatch__(func, types, args, kwargs)
 
+    def fsdp_pre_all_gather(self, mesh):
+        """Method FSDP2 calls before the all-gather of the weights
+        for both the forward and backward passes.
+
+        Args:
+            mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2
+                to shard the weights.
+
+        Returns:
+            sharded_tensors (Tuple[torch.Tensor, ...]): Tuple of tensors that
+                need to be all-gathered (in this case the uint8 data tensor).
+            metadata (Tuple[Any]): Metadata needed for reconstructing the
+                Float8Tensor after the all-gather.
+        """
+        quantizer = self._quantizer
+        if isinstance(quantizer, Float8CurrentScalingQuantizer) and mesh is not None:
+            # When the sharded weight is updated after reduce-scattering the gradients in FSDP2,
+            # we need to do an amax reduction across the mesh to make sure all weight shards are
+            # updated with the same scale inverse. Setting the state below in the quantizer makes
+            # sure that the updated quantized weight tensor has the same scale inverse on all shards.
+            quantizer.amax_reduction_group = mesh.get_group()
+            quantizer.with_amax_reduction = True
+        sharded_tensors = (self._data,)
+        metadata = (self._scale_inv, self._fp8_dtype, self.dtype, quantizer)
+        return sharded_tensors, metadata
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: Optional[torch.Tensor] = None,
+    ):
+        """Method FSDP2 calls after the all-gather of the weights
+        for both the forward and backward passes.
+
+        Args:
+            all_gather_outputs (Tuple[torch.Tensor, ...]): The sharded tensors sent out in
+                fsdp_pre_all_gather from each rank, all-gathered and received here as a tuple.
+            metadata (Any): Metadata sent out in fsdp_pre_all_gather, used for reconstructing
+                the Float8Tensor.
+            param_dtype (torch.dtype): Parameter dtype requested by FSDP2; the dtype used here
+                is taken from metadata instead.
+            out (Optional[torch.Tensor], optional): Previously reconstructed Float8Tensor that
+                can be reused. Defaults to None.
+
+        Returns:
+            Tuple[Float8Tensor, Tuple[torch.Tensor, ...]]: The all-gathered Float8Tensor and a
+                tuple of the internal tensors used by the Float8Tensor that was all-gathered.
+        """
+        if out is not None:
+            # The Float8Tensor object returned from fsdp_post_all_gather is reused over and over
+            # again, so there is no need to create a new one.
+            # In the torchao implementation of Float8Tensor, this condition is used to set the
+            # scale inverse, since the scale inverse is set after the all-gather. However, we
+            # take care of it during quantization itself by passing the amax reduction group
+            # to the quantizer.
+            return
+        (data,) = all_gather_outputs
+        (fp8_scale_inv, fp8_dtype, fake_dtype, quantizer) = metadata
+        out = Float8Tensor(
+            data=data,
+            fp8_scale_inv=fp8_scale_inv,
+            fp8_dtype=fp8_dtype,
+            shape=data.shape,
+            dtype=fake_dtype,
+            quantizer=quantizer,
+        )
+        out._create_transpose()
+        return out, (data,)
+
     @classmethod
     def _make_in_reduce_ex(
         cls,
diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py
index a524d5c8de..ba7e3932db 100644
--- a/transformer_engine/pytorch/tensor/quantized_tensor.py
+++ b/transformer_engine/pytorch/tensor/quantized_tensor.py
@@ -495,6 +495,10 @@ def maybe_update_inplace(arg, new_arg, schema_arg):
                 and schema_arg.alias_info.is_write
             ):
                 arg.quantize_(new_arg)
+            elif isinstance(arg, list) and isinstance(new_arg, list):
+                # Recursively handle updates for lists of tensors
+                for a, na in zip(arg, new_arg):
+                    maybe_update_inplace(a, na, schema_arg)
 
         # In-place op: dequantize, perform op, and quantize
         if func._schema.is_mutable:
@@ -563,8 +567,6 @@ def make_like(
         shape = data.shape if data is not None else tensor.shape
         dtype = dtype if dtype is not None else tensor.dtype
         kwargs = tensor.get_metadata()
-        if data is not None:
-            kwargs["data"] = data
         return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs)
 
     def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor:
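For reference, a minimal sketch (not part of the patch) of how this code path is exercised end to end, mirroring tests/pytorch/distributed/run_fsdp2_model.py: FP8-initialized TE modules are sharded with FSDP2's fully_shard, so the Float8Tensor weights go through the fsdp_pre_all_gather / fsdp_post_all_gather hooks above. The layer sizes, batch size, learning rate, and 1-D mesh are illustrative only, and the snippet assumes a torchrun launch plus the fp8_model_init(recipe=...) argument used by the updated test.

import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Assumes a torchrun launch: one process per GPU with LOCAL_RANK set.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")

# Build the model with FP8 parameters so that FSDP2 shards Float8Tensor weights directly.
with te.fp8_model_init(enabled=True, recipe=fp8_recipe):
    model = torch.nn.Sequential(te.Linear(128, 128), te.Linear(128, 128)).cuda()

# 1-D mesh over all ranks; fully_shard triggers fsdp_pre/post_all_gather on Float8Tensor weights.
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
for layer in model:
    fully_shard(layer, mesh=mesh)
fully_shard(model, mesh=mesh)

# One FP8 training iteration, as in the test script.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.randn(32, 128, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(x)
loss = F.mse_loss(out, torch.randn_like(out))
loss.backward()
optimizer.step()
dist.destroy_process_group()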