From 0b93e9b06944afe6350d955c3e76954c17ad2378 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 11 Mar 2025 17:31:06 -0400 Subject: [PATCH 01/16] wip --- fast_llm/engine/multi_stage/config.py | 7 + fast_llm/engine/multi_stage/fsdp.py | 323 ++++++++++++++ fast_llm/engine/multi_stage/stage_base.py | 501 +++++++--------------- 3 files changed, 494 insertions(+), 337 deletions(-) create mode 100644 fast_llm/engine/multi_stage/fsdp.py diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index d6997105..8bd191cc 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -70,6 +70,13 @@ class StageConfig(Config): desc="Reduce and accumulate gradients in fp32 to improve numerical stability.", hint=FieldHint.optional, ) + store_frozen_weights_in_optimization_precision: bool = Field( + default=False, + desc="Store frozen weights in full precision even if not not needed." + "Allows preserving the precision for saved checkpoints," + " at the cost of memory and compute (copy) overheads.", + hint=FieldHint.optional, + ) debug_layer_outputs: int = Field( default=0, desc="Log the output of each layer.", diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py new file mode 100644 index 00000000..e4eab333 --- /dev/null +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -0,0 +1,323 @@ +import typing + +import torch + +from fast_llm.core.distributed import ProcessGroup +from fast_llm.core.ops import gather_op +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.distributed.config import DistributedDim +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, StageMode +from fast_llm.logging import log_distributed_tensor +from fast_llm.tensor import ParameterMeta, SafeTensorSlice, TensorMeta +from fast_llm.utils import Assert, clamp, padded_cumsum + + +class FSDP: + _is_setup: bool = False + _fsdp_group: ProcessGroup + + def __init__( + self, + name: str, + parameter_metas: list[ParameterMeta], + fsdp_dim: DistributedDim, + training_dtype: DataType, + gradient_buffer_dtype: DataType, + optimization_dtype: DataType, + ): + self._name = name + self._parameter_metas = {parameter_meta.tensor_name: parameter_meta for parameter_meta in parameter_metas} + self._fsdp_dim = fsdp_dim + self._training_dtype = training_dtype + self._gradient_buffer_dtype = gradient_buffer_dtype + self._optimization_dtype = optimization_dtype + + parameter_sizes = [meta.numel() for meta in self._parameter_metas.values()] + self._parameter_count = sum(parameter_sizes) + parameter_offsets = padded_cumsum(parameter_sizes).tolist() + + # The index range of the parameters in the buffer. 
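+        # Each parameter owns a [begin, end) slice of the flat, concatenated buffer:
+        # offsets[:-1] are the slice begins and offsets[1:] the slice ends, keyed by parameter name.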
+ self._parameter_begins_in_buffer = { + parameter_meta.tensor_name: offset + for parameter_meta, offset in zip(parameter_metas, parameter_offsets[:-1]) + } + self._parameter_ends_in_buffer = { + parameter_meta.tensor_name: offset + for parameter_meta, offset in zip(parameter_metas, parameter_offsets[1:]) + } + + # Shard properties + # We pad the stage so that each shard has the same size + # and is a multiple of SHARD_PAD_TO_MULTIPLE (for data alignment) + self._global_pad = -self._parameter_count % (self._fsdp_dim.size * SHARD_PAD_TO_MULTIPLE) + self._shard_size = (self._parameter_count + self._global_pad) // self._fsdp_dim.size + # Typically the padding is all on the last shard, but in some cases it can overflow to other shards. + self._shard_pad = min( + max(self._global_pad - self._shard_size * (self._fsdp_dim.size - self._fsdp_dim.rank - 1), 0), + self._shard_size, + ) + + # TODO: Use parallel_dim property instead? + shard_dim = TensorDim("flat_shard", (self._parameter_count + self._global_pad) // self._fsdp_dim.size) + buffer_dim = TensorDim("flat_buffer", shard_dim.size * self._fsdp_dim.size) + + self._weight_shard_meta = TensorMeta.from_dims( + (shard_dim,), + tensor_name=f"{self._name}_weight_shard", + dtype=self._optimization_dtype.torch, + ) + self._grad_shard_meta = TensorMeta.from_dims( + (shard_dim,), + tensor_name=f"{self._name}_grad_shard", + dtype=self._optimization_dtype.torch, + ) + self._weight_buffer_meta = TensorMeta.from_dims( + (buffer_dim,), + tensor_name=f"{self._name}_weight_buffer", + dtype=self._training_dtype.torch, + ) + self._grad_buffer_meta = TensorMeta.from_dims( + (buffer_dim,), + tensor_name=f"{self._name}_weight_buffer", + dtype=self._gradient_buffer_dtype.torch, + ) + + @property + def parameter_names(self) -> list[str]: + return list(self._parameter_metas) + + @property + def parameter_count(self) -> int: + return self._parameter_count + + def __contains__(self, parameter_name: str) -> bool: + return parameter_name in self.parameter_names + + def get_parameter_meta(self, parameter_name: str) -> ParameterMeta: + return self._parameter_metas[parameter_name] + + def get_parameter_buffer(self, parameter_name: str) -> torch.nn.Parameter: + assert self._is_setup + assert self._mode.support_forward + return self._parameter_buffers[parameter_name] + + def setup( + self, + mode: StageMode, + fsdp_group: ProcessGroup, + weight_shard: torch.Tensor | None, + grad_shard: torch.Tensor | None, + weight_buffer: torch.Tensor | None, + grad_buffer: torch.Tensor | None, + sequence_tensor_parallel: bool = False, + ) -> None: + assert not self._is_setup + self._is_setup = True + self._fsdp_group = fsdp_group + self._mode = mode + + # Validate and set the shards and buffers + if self._mode.on_device: + self._weight_shard = self._weight_shard_meta.validate(weight_shard) + else: + Assert.none(weight_shard) + if self._mode.support_forward: + self._weight_buffer = self._weight_buffer_meta.validate(weight_buffer) + # Pre-compute the local shard for restore ops. + self._weight_buffer_local_shard = self._weight_buffer[ + self._fsdp_dim.rank * self._shard_size : (self._fsdp_dim.rank + 1) * self._shard_size + ] + else: + Assert.none(weight_buffer) + + if self._mode.support_backward: + self._grad_shard = self._grad_shard_meta.validate(grad_shard) + self._grad_buffer = self._grad_buffer_meta.validate(grad_buffer) + # Pre-compute the local shard for reduce ops. 
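+            # View of this rank's slice of the flat grad buffer
+            # (the buffer covers the whole stage, the shard only this rank's part).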
+ self._grad_buffer_local_shard = self._grad_buffer[ + self._fsdp_dim.rank * self._shard_size : (self._fsdp_dim.rank + 1) * self._shard_size + ] + # Pre-compute the sequence-parallel grads. + sp_indices = [i for i, meta in enumerate(self._parameter_metas.values()) if meta.sequence_tensor_parallel] + if sp_indices and sequence_tensor_parallel: + Assert.eq(sp_indices, list(range(sp_indices[0], sp_indices[-1] + 1))) + sp_indices = [ + i for i, meta in enumerate(self._parameter_metas.values()) if meta.sequence_tensor_parallel + ] + sp_begin, sp_end = ( + list(self._parameter_begins_in_buffer.values())[sp_indices[0]], + list(self._parameter_ends_in_buffer.values())[sp_indices[-1]], + ) + else: + sp_begin, sp_end = 0, 0 + self._sequence_parallel_grads = self._grad_buffer[sp_begin:sp_end] if sp_end > sp_begin else None + + else: + Assert.none(grad_shard) + Assert.none(grad_buffer) + + if self._mode.support_forward: + # Precompute the buffer slice for each parameter. + # Use `.data` to hide the restore ops from autograd. + self._parameter_buffers = {} + for weight_buffer, grad_buffer, parameter_name in zip( + self.split_buffer(self._weight_buffer.data).values(), + self.split_buffer( + self._grad_buffer if self._mode.support_backward else self._grad_buffer_meta + ).values(), + self._parameter_metas, + ): + parameter_buffer = torch.nn.Parameter(weight_buffer, requires_grad=self._mode.support_backward) + if self._mode.support_backward: + parameter_buffer.grad_buffer = grad_buffer + # TODO: This is only needed for Megatron initialization + self._parameter_buffers[parameter_name] = parameter_buffer + + def reset_shard_pad(self, shard: torch.Tensor) -> int: + assert self._is_setup + assert self._mode.on_device + # TODO: Needed? + # Prevent nans with the padded values + # Also ensures a correct parameter count in loading context. + self._weight_shard_meta.validate(shard) + if self._shard_pad > 0: + shard[-self._shard_pad :].zero_() + return self._shard_pad + return 0 + + def split_buffer(self, buffer: torch.Tensor) -> dict[str, torch.Tensor]: + # Split a buffer into appropriately shaped parameters. + return { + name: buffer[self._parameter_begins_in_buffer[name] : self._parameter_ends_in_buffer[name]].view( + meta.shape + ) + for name, meta in self._parameter_metas.items() + } + + def split_shard(self, shard: torch.Tensor) -> dict[str, torch.Tensor]: + # Split a shard into flat (possibly empty) parameter slices. + return { + name: shard[ + self._index_buffer_to_shard(self._parameter_begins_in_buffer[name]) : self._parameter_ends_in_buffer[ + name + ] + ] + for name in self._parameter_metas + } + + def _index_buffer_to_shard(self, index: int, rank: int | None = None) -> int: + shard_begin = (self._fsdp_dim.rank if rank is None else rank) * self._shard_size + return clamp(index - shard_begin, 0, self._shard_size - self._shard_pad) + + def _index_buffer_to_param(self, index: int, parameter_name: str) -> int: + return clamp( + index - self._parameter_begins_in_buffer[parameter_name], 0, self._parameter_metas[parameter_name].numel() + ) + + def reconstruct_from_shard(self, local_shard: torch.Tensor, out: torch.Tensor | None = None) -> torch.Tensor: + return gather_op(local_shard, group=self._fsdp_group, dim=0, out=out) + + def import_state_tensor( + self, parameter_name: str, shard: torch.Tensor, tensor: torch.Tensor | SafeTensorSlice + ) -> int: + """ + Given a global parameter tensor, set the associated slice of a local parameter shard. + Return the size of the local slice. 
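+        Only the part of the global tensor owned by this rank (after tensor-parallel and FSDP sharding) is copied; the rest is ignored.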
+ """ + Assert.eq(shard.shape, (self._shard_size,)) + tensor_shard = self._parameter_global_to_shard(tensor, parameter_name) + begin, end = self._parameter_range_in_shard(parameter_name) + Assert.eq(tensor_shard.numel(), end - begin) + shard[begin:end].copy_(tensor_shard) + return end - begin + + def export_shard( + self, shard: torch.Tensor, distributed: Distributed, data_type: DataType | None = None + ) -> typing.Generator[tuple[str, torch.Tensor], None, None]: + if data_type is not None: + shard = shard.to(dtype=data_type.torch) + tensors = self.split_buffer(self.reconstruct_from_shard(shard)) + for name, meta in self._parameter_metas.items(): + yield name, meta.local_to_global(tensors[name], distributed=distributed)[0] + + def log_shard(self, name, shard, *, distributed: Distributed, level, global_: bool) -> None: + # if global_ is None: + # global_ = self._config.debug_global_tensors + parameters = self.split_buffer(self.reconstruct_from_shard(shard)) if global_ else self.split_shard(shard) + for parameter_name, parameter in parameters.items(): + log_distributed_tensor( + name, + parameter, + level=level, + distributed=distributed, + global_=global_, + duplicate_groups=(distributed.data_group,), + meta=self.get_parameter_meta(parameter_name), + ) + + def _parameter_range_in_shard(self, parameter_name: str) -> tuple[int, int]: + begin = self._index_buffer_to_shard(self._parameter_begins_in_buffer[parameter_name]) + end = self._index_buffer_to_shard(self._parameter_ends_in_buffer[parameter_name]) + return begin, end + + def _parameter_global_to_shard( + self, global_param: torch.Tensor | SafeTensorSlice, parameter_name: str + ) -> torch.Tensor: + shard_param = self.get_parameter_meta(parameter_name).global_to_local(global_param).flatten() + if self._fsdp_dim.size > 1: + shard_param = shard_param[ + self._index_buffer_to_param( + self._fsdp_dim.rank * self._shard_size, parameter_name + ) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name) + ] + return shard_param + + def _get_parameter_shard_indices_in_full_weight(self, parameter_name: str, device: torch.device) -> torch.Tensor: + """ + Create an index array for the global parameter, where each entry corresponds to the index + where it is located in the shard if it exists, or -1 if it's not in the shard. + Used to determine the location of each entry in a different distributed configuration. + """ + parameter_meta = self.get_parameter_meta(parameter_name) + + # Create an empty index for the global parameter. + index = torch.full( + parameter_meta.global_shape, + -1, + dtype=torch.int64, + device=device, + ) + # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard + begin, end = self._parameter_range_in_shard(parameter_name) + self._parameter_global_to_shard(index, parameter_name).copy_( + torch.arange(begin, end, dtype=torch.int64, device=device) + ) + return index + + def _copy_shard_overlaps( + self, + loaded_fsdp: "FSDP", + shards: list[torch.Tensor], + loaded_shards: list[torch.Tensor], + counter: torch.Tensor, + device: torch.device, + ) -> None: + """ + See MultiStage._load_partial. + TODO: Not intended to work with frozen weights, need to enforce. 
+ """ + index_overlap = [name for name in loaded_fsdp._parameter_metas if name in self._parameter_metas] + for name in index_overlap: + overlap_index_map = self._parameter_global_to_shard( + loaded_fsdp._get_parameter_shard_indices_in_full_weight(name, device), name + ) + overlap_mask = overlap_index_map >= 0 + overlap_index_map_masked = overlap_index_map[overlap_mask] + overlap_count = overlap_mask.sum() + begin, end = self._parameter_range_in_shard(name) + + for shard, loaded_shard in zip(shards, loaded_shards): + shard[begin:end][overlap_mask] = loaded_shard[overlap_index_map_masked] + counter += overlap_count diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 9ef784fb..0131b2a1 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -6,17 +6,16 @@ from fast_llm.config import Configurable from fast_llm.core.distributed import ProcessGroup, check_parallel_match -from fast_llm.core.ops import gather_op from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.tensor_space import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, StageConfig, StageMode +from fast_llm.engine.multi_stage.config import StageConfig, StageMode +from fast_llm.engine.multi_stage.fsdp import FSDP from fast_llm.engine.optimizer.config import ParamGroup -from fast_llm.logging import log_distributed_tensor, log_generator +from fast_llm.logging import log_generator from fast_llm.tensor import ParameterMeta, SafeTensorSlice, TensorMeta -from fast_llm.utils import Assert, clamp, div, padded_cumsum +from fast_llm.utils import Assert, div logger = logging.getLogger(__name__) @@ -29,14 +28,14 @@ class StageBase(Configurable[StageConfig]): _fsdp_group: ProcessGroup _mode: StageMode - _weight_shard: torch.Tensor - _grad_shard: torch.Tensor - _weight_buffer: torch.Tensor - _grad_buffer: torch.Tensor + # _weight_shard: torch.Tensor + # _grad_shard: torch.Tensor + # _weight_buffer: torch.Tensor + # _grad_buffer: torch.Tensor _sequence_parallel_grads: torch.Tensor - _weight_buffer_local_shard: torch.Tensor - _grad_buffer_local_shard: torch.Tensor - _parameter_buffers: list[torch.nn.Parameter] + # _weight_buffer_local_shard: torch.Tensor + # _grad_buffer_local_shard: torch.Tensor + # _parameter_buffers: list[torch.nn.Parameter] def __init__( self, @@ -50,7 +49,6 @@ def __init__( ): super().__init__(config) self._distributed_config = distributed_config.validate() - Assert.in_range(begin, 0, end) Assert.leq(end, len(base_model)) @@ -62,56 +60,41 @@ def __init__( self._layers = [torch.compile(layer) if self._config.compile_all else layer for layer in base_model[begin:end]] self._layer_range = list(range(begin, end)) - self._parameter_metas: list[ParameterMeta] = self._get_parameter_metas() - - self._parameter_index = {param.tensor_name: i for i, param in enumerate(self._parameter_metas)} - - parameter_sizes = [meta.numel() for meta in self._parameter_metas] - self._parameter_count = sum(parameter_sizes) - parameter_offsets = padded_cumsum(parameter_sizes).tolist() - - # The index range of the parameters in the buffer. 
- self._parameter_begins_in_buffer = parameter_offsets[:-1] - self._parameter_ends_in_buffer = parameter_offsets[1:] - - # Shard properties - # We pad the stage so that each shard has the same size - # and is a multiple of SHARD_PAD_TO_MULTIPLE (for data alignment) - self._global_pad = -self._parameter_count % (self._fsdp_size * SHARD_PAD_TO_MULTIPLE) - self._shard_size = (self._parameter_count + self._global_pad) // self._fsdp_size - # Typically the padding is all on the last shard, but in some cases it can overflow to other shards. - self._shard_pad = min( - max(self._global_pad - self._shard_size * (self._fsdp_size - self._fsdp_rank - 1), 0), self._shard_size - ) - - # TODO: Use parallel_dim property instead? - shard_dim = TensorDim("flat_shard", (self._parameter_count + self._global_pad) // self._fsdp_size) - buffer_dim = TensorDim("flat_buffer", shard_dim.size * self._fsdp_size) - - self._weight_shard_meta = TensorMeta.from_dims( - (shard_dim,), - tensor_name=f"stage_{self._index}_weight_shard", - dtype=self._distributed_config.optimization_dtype.torch, - ) - self._grad_shard_meta = TensorMeta.from_dims( - (shard_dim,), - tensor_name=f"stage_{self._index}_grad_shard", - dtype=self._distributed_config.optimization_dtype.torch, - ) - self._weight_buffer_meta = TensorMeta.from_dims( - (buffer_dim,), - tensor_name=f"stage_{self._index}_weight_buffer", - dtype=self._distributed_config.training_dtype.torch, - ) - self._grad_buffer_meta = TensorMeta.from_dims( - (buffer_dim,), - tensor_name=f"stage_{self._index}_weight_buffer", - dtype=( - self._distributed_config.optimization_dtype - if self._config.full_precision_gradients - else self._distributed_config.training_dtype - ).torch, - ) + parameter_metas, frozen_metas = self._get_parameter_metas() + self._parameter_metas = parameter_metas + frozen_metas + self._fsdps = [] + if parameter_metas: + self._fsdps.append( + FSDP( + f"stage_{self._index}", + parameter_metas, + self._distributed_config.get_distributed_dim(DistributedDimNames.data), + training_dtype=self._distributed_config.training_dtype, + gradient_buffer_dtype=( + self._distributed_config.optimization_dtype + if self._config.full_precision_gradients + else self._distributed_config.training_dtype + ), + optimization_dtype=self._distributed_config.optimization_dtype, + ) + ) + if frozen_metas: + self._fsdps.append( + FSDP( + f"stage_{self._index}_frozen", + frozen_metas, + self._distributed_config.get_distributed_dim(DistributedDimNames.data), + training_dtype=self._distributed_config.training_dtype, + gradient_buffer_dtype=self._distributed_config.training_dtype, + optimization_dtype=( + self._distributed_config.optimization_dtype + if self._config.store_frozen_weights_in_optimization_precision + else self._distributed_config.training_dtype.torch + ), + ) + ) + # TODO: Separate fsdp for tied weights? 
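+        # Map each parameter name to the index of the FSDP instance that owns it
+        # (trainable parameters first, then frozen).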
+ self._fsdp_index = {name: i for i, fsdp in enumerate(self._fsdps) for name in fsdp.parameter_names} @property def mode(self) -> StageMode: @@ -122,51 +105,50 @@ def mode(self) -> StageMode: def index(self) -> int: return self._index - @property - def weight_shard_meta(self) -> TensorMeta: - return self._weight_shard_meta + # @property + # def weight_shard_meta(self) -> TensorMeta: + # return self._weight_shard_meta - @property - def grad_shard_meta(self) -> TensorMeta: - return self._grad_shard_meta + # @property + # def grad_shard_meta(self) -> TensorMeta: + # return self._grad_shard_meta - @property - def weight_buffer_meta(self) -> TensorMeta: - return self._weight_buffer_meta + # @property + # def weight_buffer_meta(self) -> TensorMeta: + # return self._weight_buffer_meta - @property - def grad_buffer_meta(self) -> TensorMeta: - return self._grad_buffer_meta + # @property + # def grad_buffer_meta(self) -> TensorMeta: + # return self._grad_buffer_meta - @property - def weight_shard(self) -> torch.Tensor: - # TODO: Avoid this method (needed for tied weights broadcast) - assert self._is_setup - assert self._mode.support_forward - return self._weight_shard + # @property + # def weight_shard(self) -> torch.Tensor: + # # TODO: Avoid this method (needed for tied weights broadcast) + # assert self._is_setup + # assert self._mode.support_forward + # return self._weight_shard - @property - def grad_shard(self) -> torch.Tensor: - # TODO: Avoid this method (needed for tied weights reduce) - assert self._is_setup - assert self._mode.support_backward - return self._grad_shard + # @property + # def grad_shard(self) -> torch.Tensor: + # # TODO: Avoid this method (needed for tied weights reduce) + # assert self._is_setup + # assert self._mode.support_backward + # return self._grad_shard @property def parameter_count(self) -> int: - return self._parameter_count + return sum(fsdp.parameter_count for fsdp in self._fsdps) @property def parameter_names(self) -> list[str]: - return list(self._parameter_index) + return sum((fsdp.parameter_names for fsdp in self._fsdps), []) def get_parameter_meta(self, parameter_name: str) -> ParameterMeta: - return self._parameter_metas[self._parameter_index[parameter_name]] + return self._fsdps[self._fsdp_index[parameter_name]].get_parameter_meta(parameter_name) - def get_parameter_buffer(self, meta: ParameterMeta) -> torch.nn.Parameter: + def get_parameter_buffer(self, parameter_name: str) -> torch.nn.Parameter: assert self._is_setup - assert self._mode.support_forward - return self._parameter_buffers[self._parameter_index[meta.tensor_name]] + return self._fsdps[self._fsdp_index[parameter_name]].get_parameter_buffer(parameter_name) def setup( self, @@ -183,66 +165,26 @@ def setup( self._is_setup = True self._distributed = distributed self._fsdp_group = self._distributed.data_group - self._mode = mode - # Validate and set the shards and buffers - if self._mode.on_device: - self._weight_shard = self._weight_shard_meta.validate(weight_shard) - else: - Assert.none(weight_shard) - if self._mode.support_forward: - self._weight_buffer = self._weight_buffer_meta.validate(weight_buffer) - # Pre-compute the local shard for restore ops. 
- self._weight_buffer_local_shard = self._weight_buffer[ - self._fsdp_rank * self._shard_size : (self._fsdp_rank + 1) * self._shard_size - ] - else: - Assert.none(weight_buffer) - - if self._mode.support_backward: - self._grad_shard = self._grad_shard_meta.validate(grad_shard) - self._grad_buffer = self._grad_buffer_meta.validate(grad_buffer) - # Pre-compute the local shard for reduce ops. - self._grad_buffer_local_shard = self._grad_buffer[ - self._fsdp_rank * self._shard_size : (self._fsdp_rank + 1) * self._shard_size - ] - # Pre-compute the sequence-parallel grads. - sp_indices = [i for i, meta in enumerate(self._parameter_metas) if meta.sequence_tensor_parallel] - if sp_indices and self._distributed_config.sequence_tensor_parallel: - Assert.eq(sp_indices, list(range(sp_indices[0], sp_indices[-1] + 1))) - sp_begin, sp_end = ( - self._parameter_begins_in_buffer[sp_indices[0]], - self._parameter_ends_in_buffer[sp_indices[-1]], - ) - else: - sp_begin, sp_end = 0, 0 - self._sequence_parallel_grads = self._grad_buffer[sp_begin:sp_end] if sp_end > sp_begin else None - - else: - Assert.none(grad_shard) - Assert.none(grad_buffer) + for fsdp in self._fsdps: + # TODO: Adjust + fsdp.setup( + mode=mode, + fsdp_group=self._distributed.data_group, + weight_shard=weight_shard, + grad_shard=grad_shard, + weight_buffer=weight_buffer, + grad_buffer=grad_buffer, + sequence_tensor_parallel=self._distributed_config.sequence_tensor_parallel, + ) if self._mode.support_forward: - # Precompute the buffer slice for each parameter. - # Use `.data` to hide the restore ops from autograd. - self._parameter_buffers = [] - for weight_buffer, grad_buffer, meta in zip( - self._split_buffer(self._weight_buffer.data), - self._split_buffer(self._grad_buffer if self._mode.support_backward else self._grad_buffer_meta), - self._parameter_metas, - ): - parameter_buffer = torch.nn.Parameter(weight_buffer, requires_grad=self._mode.support_backward) - if self._mode.support_backward: - parameter_buffer.grad_buffer = grad_buffer - # TODO: This is only needed for Megatron initialization - self._parameter_buffers.append(parameter_buffer) - # Replace the parameter definitions in each module with the actual parameter buffers. def _replace(module: torch.nn.Module): nonlocal i - for key in module._parameters: # noqa - meta = typing.cast(ParameterMeta, module._parameters[key]) # noqa - module._parameters[key] = self._parameter_buffers[self._parameter_index[meta.tensor_name]] # noqa + for key in module._parameters: + meta = typing.cast(ParameterMeta, module._parameters[key]) + module._parameters[key] = self.get_parameter_buffer(meta.tensor_name) i += 1 i = 0 @@ -260,17 +202,17 @@ def initialize_weights(self) -> None: log_generator("PP init generator before reset", self._distributed.pp_init_generator) log_generator("TP init generator before reset", self._distributed.tp_init_generator) - index = range(len(self._parameter_metas)) if self._distributed_config.reproducible_init: # Ensure a reproducible ordering. 
- index = sorted(index, key=lambda j: self._parameter_metas[j].tensor_name) - weight_shard_split = self._split_shard( - self._weight_shard if self._mode.on_device else self._weight_shard_meta - ) + sorted_metas = sorted(self._parameter_metas, key=lambda parameter_meta: parameter_meta.tensor_name) + weight_shards_split = [ + fsdp.split_shard(fsdp.weight_shard if fsdp.mode.on_device else fsdp.weight_shard_meta) + for fsdp in self._fsdps + ] - for i in index: - parameter = weight_shard_split[i] - meta = self._parameter_metas[i] + for meta in sorted_metas: + fsdp = self._fsdps[fsdp_index := self._fsdp_index[meta.tensor_name]] + parameter = weight_shards_split[fsdp_index][meta.tensor_name] # Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device) global_shape = meta.global_shape @@ -281,50 +223,30 @@ def initialize_weights(self) -> None: global_param = parameter.new_empty(global_shape, device=self._distributed.device) meta.init_parameter(global_param, distributed=self._distributed) if self._mode.on_device: - parameter.copy_(self._parameter_global_to_shard(global_param, i)) + parameter.copy_(fsdp.parameter_global_to_shard(global_param, meta.tensor_name)) elif self._mode.on_device: meta.init_parameter(parameter, self._distributed) if self.mode.on_device: - self.reset_shard_pad(self._weight_shard) + fsdp.reset_shard_pad(fsdp.weight_shard) if self._config.debug_param_init: log_generator("CPU generator after reset", torch.random.default_generator) log_generator("PP init generator after reset", self._distributed.pp_init_generator) log_generator("TP init generator after reset", self._distributed.tp_init_generator) if self._mode.on_device: - self.log_shard( + fsdp.log_shard( name="param", - shard=self._weight_shard, + shard=fsdp.weight_shard, + distributed=self._distributed, level=self._config.debug_param_init, + global_=self._config.debug_global_tensors, ) - def reset_shard_pad(self, shard: torch.Tensor) -> int: - assert self._is_setup - assert self._mode.on_device - # TODO: Needed? - # Prevent nans with the padded values - # Also ensures a correct parameter count in loading context. - self._weight_shard_meta.validate(shard) - if self._shard_pad > 0: - shard[-self._shard_pad :].zero_() - return self._shard_pad - return 0 - - def log_shard(self, name, shard, *, level, global_=None) -> None: - if global_ is None: - global_ = self._config.debug_global_tensors - parameters = self._split_buffer(self._reconstruct_from_shard(shard)) if global_ else self._split_shard(shard) - for parameter, meta in zip(parameters, self._parameter_metas): - log_distributed_tensor( - name, - parameter, - level=level, - distributed=self._distributed, - global_=global_, - duplicate_groups=(self._distributed.data_group,), - meta=meta, - ) + # def reset_shard_pad(self, shard: torch.Tensor) -> int: + # assert self._is_setup + # assert self._mode.on_device + # return sum(fsdp.reset_shard_pad(shard) for fsdp in self._fsdps) def get_param_groups( self, optimizer_state_shards: dict[str, torch.Tensor], param_group_cls: type[ParamGroup] @@ -336,25 +258,26 @@ def get_param_groups( # Get the weight slices and group by optimizer parameters, merging consecutive slices. grouped_parameter_slices = {} - for meta in self._parameter_metas: - # If needed, chunk the parameter on the first dimension. 
- chunk_size = div(meta.numel(), len(meta.lr_scale)) - param_index = self._parameter_index[meta.tensor_name] - buffer_begin = self._parameter_begins_in_buffer[param_index] - for i, lr_scale in enumerate(meta.lr_scale): - begin = self._index_buffer_to_shard(buffer_begin + i * chunk_size) - end = self._index_buffer_to_shard(buffer_begin + (i + 1) * chunk_size) - if lr_scale == 0 or begin == end: - continue - optimizer_params = (meta.param_weight_decay, lr_scale) - if optimizer_params in grouped_parameter_slices: - last_slice = grouped_parameter_slices[optimizer_params][-1] - if begin == last_slice.stop: - grouped_parameter_slices[optimizer_params][-1] = slice(last_slice.start, end) + for fsdp in self._fsdps: + for parameter_name in fsdp.parameter_names: + # If needed, chunk the parameter on the first dimension. + parameter_meta = fsdp.get_parameter_meta(parameter_name) + chunk_size = div(parameter_meta.numel(), len(parameter_meta.lr_scale)) + buffer_begin = fsdp.parameter_begins_in_buffer[parameter_meta] + for i, lr_scale in enumerate(parameter_meta.lr_scale): + begin = fsdp.index_buffer_to_shard(buffer_begin + i * chunk_size) + end = fsdp.index_buffer_to_shard(buffer_begin + (i + 1) * chunk_size) + if lr_scale == 0 or begin == end: continue - else: - grouped_parameter_slices[optimizer_params] = [] - grouped_parameter_slices[optimizer_params].append(slice(begin, end)) + optimizer_params = (parameter_meta.param_weight_decay, lr_scale) + if optimizer_params in grouped_parameter_slices: + last_slice = grouped_parameter_slices[optimizer_params][-1] + if begin == last_slice.stop: + grouped_parameter_slices[optimizer_params][-1] = slice(last_slice.start, end) + continue + else: + grouped_parameter_slices[optimizer_params] = [] + grouped_parameter_slices[optimizer_params].append(slice(begin, end)) param_groups = [ param_group_cls( @@ -372,77 +295,33 @@ def get_param_groups( ] # Get the weight slices to use for grad norm computation, merging consecutive slices. - grad_norm_indices = ( - list(range(len(self._parameter_metas))) - if self._distributed_config.tensor_rank == 0 - else [i for i, meta in enumerate(self._parameter_metas) if meta.is_tensor_parallel] - ) - grads_norm_slices = [] - for i in grad_norm_indices: - begin, end = self._parameter_range_in_shard(i) - if len(grads_norm_slices) < 0 and begin == grads_norm_slices[-1].stop: - grads_norm_slices[-1] = slice(grads_norm_slices[-1].start, end) - else: - grads_norm_slices.append(slice(begin, end)) - grads_for_norm = [self._grad_shard[slice_] for slice_ in grads_norm_slices] + grads_for_norm = [] + for fsdp in self._fsdps: + grad_norm_names = ( + fsdp.parameter_names + if self._distributed_config.tensor_rank == 0 + else [name for name in fsdp.parameter_names if fsdp.get_parameter_meta(name).is_tensor_parallel] + ) + grads_norm_slices = [] + for name in grad_norm_names: + begin, end = fsdp._parameter_range_in_shard(name) + if len(grads_norm_slices) < 0 and begin == grads_norm_slices[-1].stop: + grads_norm_slices[-1] = slice(grads_norm_slices[-1].start, end) + else: + grads_norm_slices.append(slice(begin, end)) + grads_for_norm += [fsdp.grad_shard[slice_] for slice_ in grads_norm_slices] return param_groups, grads_for_norm def check_tensor_parallel_synchronization(self) -> None: # TODO: Option to check the optimizer state. 
- for name, shard in zip(("grad", "weight"), (self.grad_shard, self.weight_shard)): - for meta, shard_slice in zip(self._parameter_metas, self._split_shard(shard)): - if shard_slice.numel() > 0 and not meta.is_tensor_parallel: - check_parallel_match(shard_slice, self._distributed.tensor_group, f"{name} {meta.tensor_name}") - - def _get_parameter_shard_indices_in_full_weight(self, name: str, device: torch.device) -> torch.Tensor: - """ - Create an index array for the global parameter, where each entry corresponds to the index - where it is located in the shard if it exists, or -1 if it's not in the shard. - Used to determine the location of each entry in a different distributed configuration. - """ - Assert.incl(name, self._parameter_index) - parameter_index = self._parameter_index[name] - parameter_meta = self._parameter_metas[parameter_index] - - # Create an empty index for the global parameter. - index = torch.full( - parameter_meta.global_shape, - -1, - dtype=torch.int64, - device=device, - ) - # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard - begin, end = self._parameter_range_in_shard(parameter_index) - self._parameter_global_to_shard(index, parameter_index).copy_( - torch.arange(begin, end, dtype=torch.int64, device=device) - ) - return index - - def _copy_shard_overlaps( - self, - loaded_stage: "StageBase", - shards: list[torch.Tensor], - loaded_shards: list[torch.Tensor], - counter: torch.Tensor, - ) -> None: - """ - See MultiStage._load_partial. - """ - index_overlap = [name for name in loaded_stage._parameter_index if name in self._parameter_index] - for name in index_overlap: - self_index = self._parameter_index[name] - overlap_index_map = self._parameter_global_to_shard( - loaded_stage._get_parameter_shard_indices_in_full_weight(name, self._distributed.device), self_index - ) - overlap_mask = overlap_index_map >= 0 - overlap_index_map_masked = overlap_index_map[overlap_mask] - overlap_count = overlap_mask.sum() - begin, end = self._parameter_range_in_shard(self_index) - - for shard, loaded_shard in zip(shards, loaded_shards): - shard[begin:end][overlap_mask] = loaded_shard[overlap_index_map_masked] - counter += overlap_count + for fsdp in self._fsdps: + for shard_name, shard in zip(("grad", "weight"), (fsdp.grad_shard, fsdp.weight_shard)): + for parameter_name, shard_slice in fsdp.split_shard(shard).items(): + if shard_slice.numel() > 0 and not fsdp.get_parameter_meta(parameter_name).is_tensor_parallel: + check_parallel_match( + shard_slice, self._distributed.tensor_group, f"{shard_name} {parameter_name}" + ) def import_state_tensor( self, parameter_name: str, shard: torch.Tensor, tensor: torch.Tensor | SafeTensorSlice @@ -450,55 +329,37 @@ def import_state_tensor( """ Given a global parameter tensor, set the associated slice of a local parameter shard. Return the size of the local slice. 
+ TODO: Doesn't work """ - Assert.eq(shard.shape, (self._shard_size,)) - parameter_index = self._parameter_index[parameter_name] - tensor_shard = self._parameter_global_to_shard(tensor, parameter_index) - begin, end = self._parameter_range_in_shard(parameter_index) - Assert.eq(tensor_shard.numel(), end - begin) - shard[begin:end].copy_(tensor_shard) - return end - begin + return self._fsdps[self._fsdp_index[parameter_name]].import_state_tensor(parameter_name, shard, tensor) def _export_shard( self, shard: torch.Tensor, data_type: DataType | None = None ) -> typing.Generator[tuple[str, torch.Tensor], None, None]: - if data_type is not None: - shard = shard.to(dtype=data_type.torch) - tensors = self._split_buffer(self._reconstruct_from_shard(shard)) - for name, param_index in self._parameter_index.items(): - yield name, self._parameter_metas[param_index].local_to_global( - tensors[param_index], distributed=self._distributed - )[0] - - def _parameter_range_in_shard(self, param_index: int) -> tuple[int, int]: - begin = self._index_buffer_to_shard(self._parameter_begins_in_buffer[param_index]) - end = self._index_buffer_to_shard(self._parameter_ends_in_buffer[param_index]) - return begin, end - - def _parameter_global_to_shard( - self, global_param: torch.Tensor | SafeTensorSlice, param_index: int - ) -> torch.Tensor: - shard_param = self._parameter_metas[param_index].global_to_local(global_param).flatten() - if self._fsdp_size > 1: - shard_param = shard_param[ - self._index_buffer_to_param( - self._fsdp_rank * self._shard_size, param_index - ) : self._index_buffer_to_param((self._fsdp_rank + 1) * self._shard_size, param_index) - ] - return shard_param + # TODO: Doesn't work + yield from self._fsdps[i].export_shard(shard, self._distributed, data_type) - def _get_parameter_metas(self) -> list[ParameterMeta]: + def _get_parameter_metas(self) -> tuple[list[ParameterMeta], list[ParameterMeta]]: # Get all the stage parameters, # then separate the parameters with and without weight decay, # and squeeze the non-tensor parallel and sequence parallel ones in the middle. # This allows running the optimizer, grad norm and sequence_parallel reduction on contiguous buffers. 
parameter_metas: list[ParameterMeta] = [] + frozen_metas: list[ParameterMeta] = [] + meta: ParameterMeta for layer in self._layers: for name, meta in layer.named_parameters(): Assert.custom(isinstance, meta, ParameterMeta) Assert.eq(meta.dtype, self._distributed_config.optimization_dtype.torch) - parameter_metas.append(meta) + if meta.lr_scale == 0 or not meta.requires_grad: + frozen_metas.append(meta) + else: + parameter_metas.append(meta) + + return self._reorder_parameter_metas(parameter_metas), self._reorder_parameter_metas(frozen_metas) + @classmethod + def _reorder_parameter_metas(cls, parameter_metas): reorder_index = sorted( range(len(parameter_metas)), key=lambda i: ( @@ -510,37 +371,3 @@ def _get_parameter_metas(self) -> list[ParameterMeta]: reordered_metas = [parameter_metas[i] for i in reorder_index] return reordered_metas - - def _index_buffer_to_shard(self, index: int, rank: int | None = None) -> int: - shard_begin = (self._fsdp_rank if rank is None else rank) * self._shard_size - return clamp(index - shard_begin, 0, self._shard_size - self._shard_pad) - - def _index_buffer_to_param(self, index: int, param_index: int) -> int: - return clamp( - index - self._parameter_begins_in_buffer[param_index], 0, self._parameter_metas[param_index].numel() - ) - - def _reconstruct_from_shard(self, local_shard: torch.Tensor, out: torch.Tensor | None = None) -> torch.Tensor: - return gather_op(local_shard, group=self._fsdp_group, dim=0, out=out) - - def _split_buffer(self, buffer: torch.Tensor) -> list[torch.Tensor]: - # Split a buffer into appropriately shaped parameters. - return [ - buffer[begin:end].view(meta.shape) - for begin, end, meta in zip( - self._parameter_begins_in_buffer, - self._parameter_ends_in_buffer, - self._parameter_metas, - ) - ] - - def _split_shard(self, shard: torch.Tensor) -> list[torch.Tensor]: - # Split a shard into flat (possibly empty) parameter slices. - return [ - shard[self._index_buffer_to_shard(begin) : self._index_buffer_to_shard(end)] - for begin, end, meta in zip( - self._parameter_begins_in_buffer, - self._parameter_ends_in_buffer, - self._parameter_metas, - ) - ] From 71210677f2ba3a50f1179fe4167787cd6c434b8d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 12 Mar 2025 17:50:41 -0400 Subject: [PATCH 02/16] wip --- fast_llm/engine/multi_stage/config.py | 3 +- fast_llm/engine/multi_stage/fsdp.py | 140 ++++++++++++++++++--- fast_llm/engine/multi_stage/multi_stage.py | 84 +++++++++---- fast_llm/engine/multi_stage/stage.py | 94 +++++--------- fast_llm/engine/multi_stage/stage_base.py | 106 ++++++---------- fast_llm/engine/schedule/runner.py | 34 ++--- fast_llm/tensor.py | 2 +- 7 files changed, 278 insertions(+), 185 deletions(-) diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 8bd191cc..fd7ba645 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -71,7 +71,8 @@ class StageConfig(Config): hint=FieldHint.optional, ) store_frozen_weights_in_optimization_precision: bool = Field( - default=False, + # TODO: Implement and set default to False + default=True, desc="Store frozen weights in full precision even if not not needed." 
"Allows preserving the precision for saved checkpoints," " at the cost of memory and compute (copy) overheads.", diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index e4eab333..6cc3269b 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -1,6 +1,8 @@ import typing import torch +from torch._C._distributed_c10d import ReduceOp +from torch.distributed import all_reduce, reduce_scatter_tensor from fast_llm.core.distributed import ProcessGroup from fast_llm.core.ops import gather_op @@ -9,6 +11,7 @@ from fast_llm.engine.distributed.config import DistributedDim from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, StageMode +from fast_llm.functional.triton.pointwise import triton_add, triton_copy from fast_llm.logging import log_distributed_tensor from fast_llm.tensor import ParameterMeta, SafeTensorSlice, TensorMeta from fast_llm.utils import Assert, clamp, padded_cumsum @@ -16,8 +19,18 @@ class FSDP: _is_setup: bool = False + _is_restored: bool = False _fsdp_group: ProcessGroup + _weight_shard: torch.Tensor + _grad_shard: torch.Tensor + _weight_buffer: torch.Tensor + _grad_buffer: torch.Tensor + _sequence_parallel_grads: torch.Tensor + _weight_buffer_local_shard: torch.Tensor + _grad_buffer_local_shard: torch.Tensor + _parameter_buffers: dict[str, torch.nn.Parameter] + def __init__( self, name: str, @@ -34,6 +47,8 @@ def __init__( self._gradient_buffer_dtype = gradient_buffer_dtype self._optimization_dtype = optimization_dtype + self._requires_grad = any(parameter_meta.requires_grad for parameter_meta in self._parameter_metas.values()) + parameter_sizes = [meta.numel() for meta in self._parameter_metas.values()] self._parameter_count = sum(parameter_sizes) parameter_offsets = padded_cumsum(parameter_sizes).tolist() @@ -61,7 +76,6 @@ def __init__( # TODO: Use parallel_dim property instead? 
shard_dim = TensorDim("flat_shard", (self._parameter_count + self._global_pad) // self._fsdp_dim.size) - buffer_dim = TensorDim("flat_buffer", shard_dim.size * self._fsdp_dim.size) self._weight_shard_meta = TensorMeta.from_dims( (shard_dim,), @@ -74,13 +88,13 @@ def __init__( dtype=self._optimization_dtype.torch, ) self._weight_buffer_meta = TensorMeta.from_dims( - (buffer_dim,), + (TensorDim("weight_buffer", shard_dim.size * self._fsdp_dim.size),), tensor_name=f"{self._name}_weight_buffer", dtype=self._training_dtype.torch, ) self._grad_buffer_meta = TensorMeta.from_dims( - (buffer_dim,), - tensor_name=f"{self._name}_weight_buffer", + (TensorDim("grad_buffer", shard_dim.size * self._fsdp_dim.size if self._requires_grad else 0),), + tensor_name=f"{self._name}_grad_buffer", dtype=self._gradient_buffer_dtype.torch, ) @@ -92,6 +106,40 @@ def parameter_names(self) -> list[str]: def parameter_count(self) -> int: return self._parameter_count + @property + def requires_grad(self) -> bool: + return self._requires_grad + + @property + def weight_shard_meta(self) -> TensorMeta: + return self._weight_shard_meta + + # @property + # def grad_shard_meta(self) -> TensorMeta: + # return self._grad_shard_meta + + @property + def weight_buffer_meta(self) -> TensorMeta: + return self._weight_buffer_meta + + @property + def grad_buffer_meta(self) -> TensorMeta: + return self._grad_buffer_meta + + @property + def weight_shard(self) -> torch.Tensor: + # TODO: Avoid this method (needed for tied weights broadcast) + assert self._is_setup + assert self._mode.support_forward + return self._weight_shard + + @property + def grad_shard(self) -> torch.Tensor: + # TODO: Avoid this method (needed for tied weights reduce) + assert self._is_setup + assert self._mode.support_backward + return self._grad_shard + def __contains__(self, parameter_name: str) -> bool: return parameter_name in self.parameter_names @@ -103,6 +151,12 @@ def get_parameter_buffer(self, parameter_name: str) -> torch.nn.Parameter: assert self._mode.support_forward return self._parameter_buffers[parameter_name] + def get_parameter_begin_in_buffer(self, parameter_name: str) -> int: + return self._parameter_begins_in_buffer[parameter_name] + + def get_parameter_end_in_buffer(self, parameter_name: str) -> int: + return self._parameter_ends_in_buffer[parameter_name] + def setup( self, mode: StageMode, @@ -190,7 +244,7 @@ def reset_shard_pad(self, shard: torch.Tensor) -> int: def split_buffer(self, buffer: torch.Tensor) -> dict[str, torch.Tensor]: # Split a buffer into appropriately shaped parameters. return { - name: buffer[self._parameter_begins_in_buffer[name] : self._parameter_ends_in_buffer[name]].view( + name: buffer[self.get_parameter_begin_in_buffer(name) : self.get_parameter_end_in_buffer(name)].view( meta.shape ) for name, meta in self._parameter_metas.items() @@ -200,20 +254,22 @@ def split_shard(self, shard: torch.Tensor) -> dict[str, torch.Tensor]: # Split a shard into flat (possibly empty) parameter slices. 
return { name: shard[ - self._index_buffer_to_shard(self._parameter_begins_in_buffer[name]) : self._parameter_ends_in_buffer[ - name - ] + self.index_buffer_to_shard( + self.get_parameter_begin_in_buffer(name) + ) : self.get_parameter_end_in_buffer(name) ] for name in self._parameter_metas } - def _index_buffer_to_shard(self, index: int, rank: int | None = None) -> int: + def index_buffer_to_shard(self, index: int, rank: int | None = None) -> int: shard_begin = (self._fsdp_dim.rank if rank is None else rank) * self._shard_size return clamp(index - shard_begin, 0, self._shard_size - self._shard_pad) def _index_buffer_to_param(self, index: int, parameter_name: str) -> int: return clamp( - index - self._parameter_begins_in_buffer[parameter_name], 0, self._parameter_metas[parameter_name].numel() + index - self.get_parameter_begin_in_buffer(parameter_name), + 0, + self._parameter_metas[parameter_name].numel(), ) def reconstruct_from_shard(self, local_shard: torch.Tensor, out: torch.Tensor | None = None) -> torch.Tensor: @@ -227,7 +283,7 @@ def import_state_tensor( Return the size of the local slice. """ Assert.eq(shard.shape, (self._shard_size,)) - tensor_shard = self._parameter_global_to_shard(tensor, parameter_name) + tensor_shard = self.parameter_global_to_shard(tensor, parameter_name) begin, end = self._parameter_range_in_shard(parameter_name) Assert.eq(tensor_shard.numel(), end - begin) shard[begin:end].copy_(tensor_shard) @@ -257,12 +313,64 @@ def log_shard(self, name, shard, *, distributed: Distributed, level, global_: bo meta=self.get_parameter_meta(parameter_name), ) + def restore_parameters(self) -> None: + assert self._is_setup + assert self._mode.support_forward + # TODO: Allow partial FSDP + if not self._is_restored: + triton_copy(self._weight_shard, self._weight_buffer_local_shard) + if self._fsdp_dim.size > 1: + self.reconstruct_from_shard(self._weight_buffer_local_shard, self._weight_buffer) + self._is_restored = True + + def reset_gradients(self) -> None: + # TODO: Allow re-allocating the gradient every time. + assert self._is_setup + assert self._mode.support_backward + if not self._requires_grad: + return + for buffer in self._parameter_buffers.values(): + assert buffer.grad is None + buffer.param_grad_is_zero = True + + def reduce_gradients(self, accumulate=False) -> None: + # Reduce the buffer, then copy (add) to actual grads. + # Run in a separate cuda stream to allow communication overlap. 
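+        # Grad buffers of parameters that never received a gradient are zero-filled, sequence-parallel grads are
+        # all-reduced over the tensor group, then the buffer is reduce-scattered (averaging) across the FSDP group
+        # and copied or added into the grad shard depending on `accumulate` and the gradient precision.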
+ # TODO: Allow partial FSDP + assert self._is_restored + assert self._mode.support_backward + if not self._requires_grad: + return + for buffer, meta in zip(self._parameter_buffers, self._parameter_metas.values()): + if buffer.param_grad_is_zero: # noqa + assert self.is_tied_weight_copy or meta.allow_no_grad, meta + triton_fill(buffer.grad_buffer, 0) # noqa + if self._sequence_parallel_grads is not None and self._distributed.tensor_group: + all_reduce(self._sequence_parallel_grads, group=self._distributed.tensor_group) + if self._fsdp_dim.size > 1: + full_precision_gradients = self._grad_buffer_local_shard.dtype == self._grad_shard.dtype + out = self._grad_shard if full_precision_gradients else self._grad_buffer_local_shard + if accumulate: + out = torch.empty_like(out) + reduce_scatter_tensor( + out, + self._grad_buffer, + group=self._fsdp_group, + op=ReduceOp.AVG, + ) + if accumulate: + triton_add(self._grad_shard, out, self._grad_shard) + elif not full_precision_gradients: + triton_copy(self._grad_buffer_local_shard, self._grad_shard) + else: + triton_copy(self._grad_buffer_local_shard, self._grad_shard) + def _parameter_range_in_shard(self, parameter_name: str) -> tuple[int, int]: - begin = self._index_buffer_to_shard(self._parameter_begins_in_buffer[parameter_name]) - end = self._index_buffer_to_shard(self._parameter_ends_in_buffer[parameter_name]) + begin = self.index_buffer_to_shard(self.get_parameter_begin_in_buffer(parameter_name)) + end = self.index_buffer_to_shard(self.get_parameter_end_in_buffer(parameter_name)) return begin, end - def _parameter_global_to_shard( + def parameter_global_to_shard( self, global_param: torch.Tensor | SafeTensorSlice, parameter_name: str ) -> torch.Tensor: shard_param = self.get_parameter_meta(parameter_name).global_to_local(global_param).flatten() @@ -291,7 +399,7 @@ def _get_parameter_shard_indices_in_full_weight(self, parameter_name: str, devic ) # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard begin, end = self._parameter_range_in_shard(parameter_name) - self._parameter_global_to_shard(index, parameter_name).copy_( + self.parameter_global_to_shard(index, parameter_name).copy_( torch.arange(begin, end, dtype=torch.int64, device=device) ) return index @@ -310,7 +418,7 @@ def _copy_shard_overlaps( """ index_overlap = [name for name in loaded_fsdp._parameter_metas if name in self._parameter_metas] for name in index_overlap: - overlap_index_map = self._parameter_global_to_shard( + overlap_index_map = self.parameter_global_to_shard( loaded_fsdp._get_parameter_shard_indices_in_full_weight(name, device), name ) overlap_mask = overlap_index_map >= 0 diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index ce5a14ac..c215b21f 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -119,17 +119,25 @@ def __init__( self._stage_shard_indices = { stage_index: shard_index for shard_index, stage_index in enumerate(self._stages_on_device) } - self._stage_shard_sizes = [stage.weight_shard_meta.numel() for stage in self._stages_on_device.values()] - stage_shard_dtype = get_unique([stage.weight_shard_meta.dtype for stage in self._stages_on_device.values()]) + self._fsdp_shard_sizes = [ + [fsdp.weight_shard_meta.numel() for fsdp in stage.fsdps] for stage in self._stages_on_device.values() + ] + self._stage_shard_sizes = [sum(shard_sizes) for shard_sizes in self._fsdp_shard_sizes] + # TODO: Support non-unique 
data type. + stage_shard_dtype = get_unique( + [fsdp.weight_shard_meta.dtype for stage in self._stages_on_device.values() for fsdp in stage.fsdps] + ) self._state_shard_names = ("weights",) + optimizer_state_names shard_dim = TensorDim("flat_shard", sum(self._stage_shard_sizes)) + # TODO: Avoid unnecessary shards (frozen weights or shard identical to buffer) self._weight_shard_meta = TensorMeta.from_dims( (shard_dim,), tensor_name=f"multi_stage_weight_shard", dtype=stage_shard_dtype, ) + # TODO !!!!!!: Remove for frozen weights. self._state_shard_meta = TensorMeta.from_dims( (TensorDim("state_shards", self.num_state_shards), shard_dim), tensor_name=f"multi_stage_state_shard", @@ -151,11 +159,17 @@ def __init__( ) if self._verbose: log_model_parallel_main_rank(f"Weight buffer placement:\n{self._weight_buffer_indices}") + self._fsdp_weight_buffer_sizes = [ + [fsdp.weight_buffer_meta.numel() for fsdp in stage.fsdps] for stage in self._stages + ] + self._stage_weight_buffer_sizes = [sum(buffer_sizes) for buffer_sizes in self._fsdp_weight_buffer_sizes] self._weight_buffer_sizes = [ - max(self._stages[stage_index].weight_buffer_meta.numel() for stage_index in contents) + max(self._stage_weight_buffer_sizes[stage_index] for stage_index in contents) for contents in self._weight_buffer_contents ] - weight_buffer_dtype = get_unique([stage.weight_buffer_meta.dtype for stage in self._stages]) + weight_buffer_dtype = get_unique( + [fsdp.weight_buffer_meta.dtype for stage in self._stages for fsdp in stage.fsdps] + ) self._weight_buffer_meta = TensorMeta.from_dims( (TensorDim("weight_buffer", sum(self._weight_buffer_sizes)),), tensor_name=f"multi_stage_weight_buffer", @@ -167,11 +181,16 @@ def __init__( ) if self._verbose: log_model_parallel_main_rank(f"Grad buffer placement:\n{self._grad_buffer_indices}") + self._fsdp_grad_buffer_sizes = [ + [fsdp.grad_buffer_meta.numel() for fsdp in stage.fsdps] for stage in self._stages + ] + self._stage_grad_buffer_sizes = [sum(buffer_sizes) for buffer_sizes in self._fsdp_grad_buffer_sizes] self._grad_buffer_sizes = [ - max(self._stages[stage_index].grad_buffer_meta.numel() for stage_index in contents) + max(self._stage_grad_buffer_sizes[stage_index] for stage_index in contents) for contents in self._grad_buffer_contents ] - grad_buffer_dtype = get_unique([stage.grad_buffer_meta.dtype for stage in self._stages]) + + grad_buffer_dtype = get_unique([fsdp.grad_buffer_meta.dtype for stage in self._stages for fsdp in stage.fsdps]) self._grad_buffer_meta = TensorMeta.from_dims( (TensorDim("grad_buffer", sum(self._grad_buffer_sizes)),), tensor_name=f"multi_stage_grad_buffer", @@ -251,21 +270,27 @@ def setup(self, distributed: Distributed, mode: StageMode = StageMode.training) shard_index = self._stage_shard_indices.get(stage_index) weight_buffer_index = self._weight_buffer_indices.get(stage_index) grad_buffer_index = self._grad_buffer_indices.get(stage_index) - weight_buffer = ( - weight_buffers[weight_buffer_index][: stage.weight_buffer_meta.numel()] # noqa + weight_buffers = ( + weight_buffers[weight_buffer_index][: self._stage_weight_buffer_sizes[weight_buffer_index]].split( + self._fsdp_weight_buffer_sizes[weight_buffer_index] + ) if self._mode.support_forward and weight_buffer_index is not None else None ) - grad_buffer = ( - grad_buffers[grad_buffer_index][: stage.grad_buffer_meta.numel()] # noqa + grad_buffers = ( + grad_buffers[grad_buffer_index][: self._stage_grad_buffer_sizes[grad_buffer_index]].split( + self._fsdp_grad_buffer_sizes[grad_buffer_index] + ) if 
self._mode.support_backward and grad_buffer_index is not None else None ) - weight_shard = ( - weight_shard_split[shard_index] if self._mode.on_device and shard_index is not None else None # noqa + weight_shards = ( + weight_shard_split[shard_index].split(self._fsdp_shard_sizes[shard_index]) + if self._mode.on_device and shard_index is not None + else None # noqa ) - grad_shard = ( - grad_shard_split[shard_index] # noqa + grad_shards = ( + grad_shard_split[shard_index].split(self._fsdp_shard_sizes[shard_index]) # noqa if self._mode.support_backward and shard_index is not None else None ) @@ -276,10 +301,10 @@ def setup(self, distributed: Distributed, mode: StageMode = StageMode.training) ) stage.setup( distributed=distributed, - weight_shard=weight_shard, - grad_shard=grad_shard, - weight_buffer=weight_buffer, - grad_buffer=grad_buffer, + weight_shards=weight_shards, + grad_shards=grad_shards, + weight_buffers=weight_buffers, + grad_buffers=grad_buffers, mode=self._mode if stage_index in self._stages_on_device else StageMode.off_device, is_tied_weight_copy=stage_index in self._stages_on_device and stage_index not in self._stages_owned, weight_buffer_shared_with=weight_buffer_shared_with, @@ -296,8 +321,9 @@ def get_param_groups( optimizer_shards_split = [shard.split(self._stage_shard_sizes) for shard in self._optimizer_shard.unbind()] param_groups, grads_for_norm = [], [] for stage_index, stage in self._stages_on_device.items(): + shard_index = self._stage_shard_indices.get(stage_index) stage_optimizer_shards = { - name: shard_split[self._stage_shard_indices[stage_index]] + name: shard_split[shard_index].split(self._fsdp_shard_sizes[shard_index]) for name, shard_split in zip(self._state_shard_names[1:], optimizer_shards_split) } stage_param_groups, stage_grads_for_norm = stage.get_param_groups( @@ -403,8 +429,12 @@ def get_state_tensor_iterator( ) -> typing.Generator[tuple[str, str, torch.Tensor], None, None]: for i, shard_name in enumerate(shard_names): shard_split = self._state_shard[i].split(self._stage_shard_sizes, 0) - for stage, shard in zip(self._stages_on_device.values(), shard_split): - for name, tensor in stage._export_shard(shard, data_type=data_type): # noqa + for shard_index, (stage, shard) in enumerate( + zip(self._stages_on_device.values(), shard_split, strict=True) + ): + for name, tensor in stage._export_shard( + shard.split(self._fsdp_shard_sizes[shard_index]), data_type=data_type + ): # noqa yield name, shard_name, tensor def import_state_tensor(self, parameter_name: str, shard_name: str, tensor: torch.Tensor | SafeTensorSlice): @@ -415,11 +445,13 @@ def import_state_tensor(self, parameter_name: str, shard_name: str, tensor: torc if not self.is_parameter_on_device(parameter_name): # Parameter is not on device, nothing to do. 
return 0 - stage_index = self._stage_shard_indices[self._parameter_stages[parameter_name]] - stage_shard = self._state_shard[self._state_shard_names.index(shard_name)].split(self._stage_shard_sizes, 0)[ - stage_index - ] - return self.get_parameter_stage(parameter_name).import_state_tensor(parameter_name, stage_shard, tensor) + shard_index = self._stage_shard_indices[self._parameter_stages[parameter_name]] + stage_shards = ( + self._state_shard[self._state_shard_names.index(shard_name)] + .split(self._stage_shard_sizes, 0)[shard_index] + .split(self._fsdp_shard_sizes[shard_index]) + ) + return self.get_parameter_stage(parameter_name).import_state_tensor(parameter_name, stage_shards, tensor) def _split_into_stages(self) -> list[int]: # Create stages (greedy split, could do better). diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index f0ad904f..5bc42c3f 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -2,14 +2,12 @@ import typing import torch -from torch.distributed import all_reduce, reduce_scatter_tensor -from fast_llm.core.distributed import ReduceOp, check_parallel_match +from fast_llm.core.distributed import check_parallel_match from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import StageMode from fast_llm.engine.multi_stage.stage_base import StageBase -from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage, log_tensor from fast_llm.tensor import ParameterMeta, TensorMeta, accumulate_gradient from fast_llm.utils import Assert @@ -40,20 +38,20 @@ def setup( # noqa self, *, distributed: Distributed, - weight_shard: torch.Tensor | None, - grad_shard: torch.Tensor | None, - weight_buffer: torch.Tensor | None, - grad_buffer: torch.Tensor | None, + weight_shards: torch.Tensor | None, + grad_shards: torch.Tensor | None, + weight_buffers: torch.Tensor | None, + grad_buffers: torch.Tensor | None, mode: StageMode = StageMode.training, is_tied_weight_copy: bool = False, weight_buffer_shared_with: list["Stage"], ) -> None: super().setup( distributed=distributed, - weight_shard=weight_shard, - grad_shard=grad_shard, - weight_buffer=weight_buffer, - grad_buffer=grad_buffer, + weight_shards=weight_shards, + grad_shards=grad_shards, + weight_buffers=weight_buffers, + grad_buffers=grad_buffers, mode=mode, ) self._is_tied_weight_copy = is_tied_weight_copy @@ -66,7 +64,8 @@ def setup( # noqa if self._mode.support_backward: self._accumulators = [] with torch.enable_grad(): - for buffer, meta in zip(self._parameter_buffers, self._parameter_metas): + for meta in self._parameter_metas: + buffer = self.get_parameter_buffer(meta.tensor_name) # We want to replace the grad accumulation function with ours, but pytorch won't let us do that. # Instead, we let a trivial accumulation run its course (sets .grad), # then run the actual accumulation. 
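The comment above describes letting autograd's own (trivial) accumulation populate .grad and then running the real accumulation afterwards. The patch does not show that wiring in this hunk, so what follows is only an illustrative sketch of one common way to do it, namely hooking the parameter's AccumulateGrad node as several training frameworks do; it is not taken from Fast-LLM, and the names param, main_grad and make_accumulate_hook are invented for the example.

import torch

# A stand-in parameter and a separate "real" gradient accumulator (both hypothetical).
param = torch.nn.Parameter(torch.zeros(4))
main_grad = torch.zeros_like(param)

def make_accumulate_hook(p: torch.nn.Parameter, accumulator: torch.Tensor):
    def hook(*unused):
        # Fires after autograd's trivial accumulation has populated p.grad.
        accumulator.add_(p.grad)
        p.grad = None  # drop the temporary .grad so it does not carry over between steps
    return hook

# expand_as is a no-op view; its grad_fn points back at the parameter's
# AccumulateGrad node, which is where the hook needs to be attached.
grad_accumulation_node = param.expand_as(param).grad_fn.next_functions[0][0]
grad_accumulation_node.register_hook(make_accumulate_hook(param, main_grad))

(param * torch.arange(4.0)).sum().backward()
assert torch.equal(main_grad, torch.arange(4.0))

Recent PyTorch releases also expose register_post_accumulate_grad_hook on tensors, which covers the same use case without reaching into the autograd graph; whether this patch uses that API or the node hook above is not visible from the diff.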
@@ -122,69 +121,40 @@ def restore_parameters(self) -> None: assert self._mode.support_forward # TODO: Allow partial FSDP if not self._is_restored: - triton_copy(self._weight_shard, self._weight_buffer_local_shard) - if self._fsdp_size > 1: - self._reconstruct_from_shard(self._weight_buffer_local_shard, self._weight_buffer) + for fsdp in self._fsdps: + fsdp.restore_parameters() self._is_restored = True for stage in self._weight_buffer_shared_with: stage.invalidate_buffer() def reset_gradients(self) -> None: # TODO: Allow re-allocating the gradient every time. - # TODO: Autograd will always increment gradient instead of setting the value (less efficient) - # Can this (and op below) be avoided? (Probably needs messing with autograd) - # Solution: set a zero_grad flag on parameter, then adjust backward fn to set or accumulate depending on flag. - # Then we can also avoid explicitly setting to zero. - # Logic implemented for linear and ln, missing embedding. assert self._is_setup assert self._mode.support_backward - # assert self._is_restored - for buffer in self._parameter_buffers: - assert buffer.grad is None - buffer.param_grad_is_zero = True + for fsdp in self._fsdps: + fsdp.reset_gradients() def reduce_gradients(self, accumulate=False) -> None: - # Just need to reduce the buffer, then copy (add) to actual grads. - # Works fine as is but does not allow communication overlap by itself. - # Reduction should only be done once per step, after the full backward pass is done for the stage. + # Reduce the buffer, then copy (add) to actual grads. + # Run in a separate cuda stream to allow communication overlap. # TODO: Allow partial FSDP assert self._is_restored assert self._mode.support_backward - for buffer, meta in zip(self._parameter_buffers, self._parameter_metas): - if buffer.param_grad_is_zero: # noqa - assert self.is_tied_weight_copy or meta.allow_no_grad, meta - triton_fill(buffer.grad_buffer, 0) # noqa - if self._sequence_parallel_grads is not None and self._distributed.tensor_group: - all_reduce(self._sequence_parallel_grads, group=self._distributed.tensor_group) - if self._fsdp_size > 1: - out = self._grad_shard if self._config.full_precision_gradients else self._grad_buffer_local_shard - if accumulate: - out = torch.empty_like(out) - reduce_scatter_tensor( - out, - self._grad_buffer, - group=self._fsdp_group, - op=ReduceOp.AVG, - ) - if accumulate: - triton_add(self._grad_shard, out, self._grad_shard) - elif not self._config.full_precision_gradients: - triton_copy(self._grad_buffer_local_shard, self._grad_shard) - else: - triton_copy(self._grad_buffer_local_shard, self._grad_shard) - if self._config.debug_param_gradients: - log_tensor( - "Reduced gradient shard", - self._grad_shard, - level=self._config.debug_param_gradients, - global_=False, - ) - if self._config.debug_all_param_gradients: - self.log_shard( - name="gradient", - shard=self._grad_shard, - level=self._config.debug_all_param_gradients, - ) + for fsdp in self._fsdps: + fsdp.reduce_gradients(accumulate) + if self._config.debug_param_gradients: + log_tensor( + "Reduced gradient shard", + fsdp.grad_shard, + level=self._config.debug_param_gradients, + global_=False, + ) + if self._config.debug_all_param_gradients: + fsdp.log_shard( + name="gradient", + shard=fsdp.grad_shard, + level=self._config.debug_all_param_gradients, + ) @property def is_tied_weight_copy(self) -> bool: diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 0131b2a1..35b2a224 100644 --- 
a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -5,7 +5,7 @@ import torch._dynamo # noqa from fast_llm.config import Configurable -from fast_llm.core.distributed import ProcessGroup, check_parallel_match +from fast_llm.core.distributed import check_parallel_match from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames @@ -14,7 +14,7 @@ from fast_llm.engine.multi_stage.fsdp import FSDP from fast_llm.engine.optimizer.config import ParamGroup from fast_llm.logging import log_generator -from fast_llm.tensor import ParameterMeta, SafeTensorSlice, TensorMeta +from fast_llm.tensor import ParameterMeta, SafeTensorSlice from fast_llm.utils import Assert, div logger = logging.getLogger(__name__) @@ -22,17 +22,14 @@ class StageBase(Configurable[StageConfig]): config_class: typing.ClassVar[type[StageConfig]] = StageConfig - _meta_inputs: list[TensorMeta] - _meta_outputs: list[TensorMeta] _distributed: Distributed - _fsdp_group: ProcessGroup _mode: StageMode # _weight_shard: torch.Tensor # _grad_shard: torch.Tensor # _weight_buffer: torch.Tensor # _grad_buffer: torch.Tensor - _sequence_parallel_grads: torch.Tensor + # _sequence_parallel_grads: torch.Tensor # _weight_buffer_local_shard: torch.Tensor # _grad_buffer_local_shard: torch.Tensor # _parameter_buffers: list[torch.nn.Parameter] @@ -105,35 +102,9 @@ def mode(self) -> StageMode: def index(self) -> int: return self._index - # @property - # def weight_shard_meta(self) -> TensorMeta: - # return self._weight_shard_meta - - # @property - # def grad_shard_meta(self) -> TensorMeta: - # return self._grad_shard_meta - - # @property - # def weight_buffer_meta(self) -> TensorMeta: - # return self._weight_buffer_meta - - # @property - # def grad_buffer_meta(self) -> TensorMeta: - # return self._grad_buffer_meta - - # @property - # def weight_shard(self) -> torch.Tensor: - # # TODO: Avoid this method (needed for tied weights broadcast) - # assert self._is_setup - # assert self._mode.support_forward - # return self._weight_shard - - # @property - # def grad_shard(self) -> torch.Tensor: - # # TODO: Avoid this method (needed for tied weights reduce) - # assert self._is_setup - # assert self._mode.support_backward - # return self._grad_shard + @property + def fsdps(self) -> list[FSDP]: + return self._fsdps @property def parameter_count(self) -> int: @@ -154,20 +125,20 @@ def setup( self, *, distributed: Distributed, - weight_shard: torch.Tensor | None, - grad_shard: torch.Tensor | None, - weight_buffer: torch.Tensor | None, - grad_buffer: torch.Tensor | None, + weight_shards: list[torch.Tensor | None], + grad_shards: list[torch.Tensor | None], + weight_buffers: list[torch.Tensor | None], + grad_buffers: list[torch.Tensor | None], mode: StageMode = StageMode.training, ) -> None: assert not self._is_setup assert distributed.config is self._distributed_config self._is_setup = True self._distributed = distributed - self._fsdp_group = self._distributed.data_group - for fsdp in self._fsdps: - # TODO: Adjust + for fsdp, weight_shard, grad_shard, weight_buffer, grad_buffer in zip( + self._fsdps, weight_shards, grad_shards, weight_buffers, grad_buffers, strict=True + ): fsdp.setup( mode=mode, fsdp_group=self._distributed.data_group, @@ -206,7 +177,7 @@ def initialize_weights(self) -> None: # Ensure a reproducible ordering. 
sorted_metas = sorted(self._parameter_metas, key=lambda parameter_meta: parameter_meta.tensor_name) weight_shards_split = [ - fsdp.split_shard(fsdp.weight_shard if fsdp.mode.on_device else fsdp.weight_shard_meta) + fsdp.split_shard(fsdp.weight_shard if self._mode.on_device else fsdp.weight_shard_meta) for fsdp in self._fsdps ] @@ -249,21 +220,25 @@ def initialize_weights(self) -> None: # return sum(fsdp.reset_shard_pad(shard) for fsdp in self._fsdps) def get_param_groups( - self, optimizer_state_shards: dict[str, torch.Tensor], param_group_cls: type[ParamGroup] + self, optimizer_state_shards: dict[str, tuple[torch.Tensor]], param_group_cls: type[ParamGroup] ) -> tuple[list[ParamGroup], list[torch.Tensor]]: # TODO: Separate model-specific code. # TODO: verify optimizer states assert self._is_setup assert self._mode.support_training + assert all(len(state_shards) == len(self._fsdps) for state_shards in optimizer_state_shards.values()) # Get the weight slices and group by optimizer parameters, merging consecutive slices. grouped_parameter_slices = {} - for fsdp in self._fsdps: + param_groups = [] + for i, fsdp in enumerate(self._fsdps): for parameter_name in fsdp.parameter_names: # If needed, chunk the parameter on the first dimension. parameter_meta = fsdp.get_parameter_meta(parameter_name) + if not parameter_meta.requires_grad: + continue chunk_size = div(parameter_meta.numel(), len(parameter_meta.lr_scale)) - buffer_begin = fsdp.parameter_begins_in_buffer[parameter_meta] + buffer_begin = fsdp.get_parameter_begin_in_buffer(parameter_meta.tensor_name) for i, lr_scale in enumerate(parameter_meta.lr_scale): begin = fsdp.index_buffer_to_shard(buffer_begin + i * chunk_size) end = fsdp.index_buffer_to_shard(buffer_begin + (i + 1) * chunk_size) @@ -279,20 +254,20 @@ def get_param_groups( grouped_parameter_slices[optimizer_params] = [] grouped_parameter_slices[optimizer_params].append(slice(begin, end)) - param_groups = [ - param_group_cls( - name=f"wd_{weight_decay}_lr_scale_{lr_scale}", # noqa - params=[self._weight_shard[slice_] for slice_ in slices], # noqa - grads=[self._grad_shard[slice_] for slice_ in slices], # noqa - **{ # noqa - name: [optimizer_state[slice_] for slice_ in slices] - for name, optimizer_state in optimizer_state_shards.items() - }, - weight_decay=None if weight_decay else 0.0, # noqa - lr_scale=lr_scale, # noqa - ) - for (weight_decay, lr_scale), slices in grouped_parameter_slices.items() - ] + param_groups += [ + param_group_cls( + name=f"wd_{weight_decay}_lr_scale_{lr_scale}", # noqa + params=[self._weight_shard[slice_] for slice_ in slices], # noqa + grads=[self._grad_shard[slice_] for slice_ in slices], # noqa + **{ # noqa + name: [optimizer_state[i][slice_] for slice_ in slices] + for name, optimizer_state in optimizer_state_shards.items() + }, + weight_decay=None if weight_decay else 0.0, # noqa + lr_scale=lr_scale, # noqa + ) + for (weight_decay, lr_scale), slices in grouped_parameter_slices.items() + ] # Get the weight slices to use for grad norm computation, merging consecutive slices. 
grads_for_norm = [] @@ -334,10 +309,11 @@ def import_state_tensor( return self._fsdps[self._fsdp_index[parameter_name]].import_state_tensor(parameter_name, shard, tensor) def _export_shard( - self, shard: torch.Tensor, data_type: DataType | None = None + self, shards: tuple[torch.Tensor], data_type: DataType | None = None ) -> typing.Generator[tuple[str, torch.Tensor], None, None]: # TODO: Doesn't work - yield from self._fsdps[i].export_shard(shard, self._distributed, data_type) + for fsdp, shard in zip(self._fsdps, shards, strict=True): + yield from fsdp.export_shard(shard, self._distributed, data_type) def _get_parameter_metas(self) -> tuple[list[ParameterMeta], list[ParameterMeta]]: # Get all the stage parameters, @@ -351,10 +327,10 @@ def _get_parameter_metas(self) -> tuple[list[ParameterMeta], list[ParameterMeta] for name, meta in layer.named_parameters(): Assert.custom(isinstance, meta, ParameterMeta) Assert.eq(meta.dtype, self._distributed_config.optimization_dtype.torch) - if meta.lr_scale == 0 or not meta.requires_grad: - frozen_metas.append(meta) - else: + if meta.requires_grad: parameter_metas.append(meta) + else: + frozen_metas.append(meta) return self._reorder_parameter_metas(parameter_metas), self._reorder_parameter_metas(frozen_metas) diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index d4b1da10..0399116a 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -223,14 +223,17 @@ def run_step( # Stage hasn't been reduced yet. # TODO: Overlap this? (reduce with last local layer that uses it) main_stage.reduce_gradients() - # TODO: Overlap this? (not really useful for gpt) - all_reduce(main_stage.grad_shard, group=tied_parameter.group) - if self._multi_stage.config.multi_stage.debug_all_param_gradients: - main_stage.log_shard( - name="gradient", - shard=main_stage.grad_shard, - level=self._multi_stage.config.multi_stage.debug_all_param_gradients, - ) + for fsdp in main_stage.fsdps: + # TODO: Overlap this? 
(not really useful for gpt) + all_reduce(fsdp.grad_shard, group=tied_parameter.group) + if self._multi_stage.config.multi_stage.debug_all_param_gradients: + fsdp.log_shard( + name="gradient", + shard=fsdp.grad_shard, + distributed=self._distributed, + level=self._multi_stage.config.multi_stage.debug_all_param_gradients, + global_=self._multi_stage.config.multi_stage.debug_global_tensors, + ) self._record_event(context, EventType.post_reduce, None) # Update weights @@ -248,11 +251,14 @@ def run_step( stage.invalidate_buffer() if self._multi_stage.config.multi_stage.debug_param_update: for stage in self._stages_on_device: - stage.log_shard( - name="param", - shard=stage.weight_shard, - level=self._multi_stage.config.multi_stage.debug_param_update, - ) + for fsdp in stage.fsdps: + fsdp.log_shard( + name="param", + shard=fsdp.weight_shard, + distributed=self._distributed, + level=self._multi_stage.config.multi_stage.debug_param_update, + global_=self._multi_stage.config.multi_stage.debug_global_tensors, + ) self._record_event(context, EventType.optimizer, None) self._record_event(context, EventType.batch_end, None) @@ -335,7 +341,7 @@ def _preprocess_data( for name, tied_parameter in self._tied_parameters.items(): if tied_parameter.on_device: kwargs[name] = self._stages[tied_parameter.main_stage].get_parameter_buffer( - tied_parameter.meta + tied_parameter.meta.tensor_name ) data_index = context.schedule.get_data_index(micro_batch, micro_sequence) if self._stages_owned[0]: diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 608eb622..f59927b6 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -222,7 +222,6 @@ def __init__( self.param_weight_decay = weight_decay self._is_param = True self.param_grad_is_zero = False - self.requires_grad = requires_grad # Almost all parameters are either tensor-parallel or process tensor-sequence-parallel inputs. # Except for position embedding weights self.sequence_tensor_parallel = allow_sequence_tensor_parallel and not self.is_tensor_parallel @@ -234,6 +233,7 @@ def __init__( self.allow_no_grad = allow_no_grad self.lr_scale = lr_scale if isinstance(lr_scale, tuple) else (lr_scale,) + self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) # Ensure the parameter is split in chunks of equal size. 
Assert.multiple(self.dims[0].size, len(self.lr_scale)) From a06d678b6c3acfec7cf175dae30f0a3de9026de6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 13 Mar 2025 17:53:52 -0400 Subject: [PATCH 03/16] fixes --- fast_llm/engine/checkpoint/distributed.py | 30 +++++++---- fast_llm/engine/checkpoint/safe_load.py | 58 +++++++++++++--------- fast_llm/engine/multi_stage/fsdp.py | 2 +- fast_llm/engine/multi_stage/multi_stage.py | 16 +++--- fast_llm/engine/multi_stage/stage.py | 10 ++-- fast_llm/engine/multi_stage/stage_base.py | 27 +++++++--- 6 files changed, 89 insertions(+), 54 deletions(-) diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index 9c171bef..7c0b9a26 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -93,14 +93,26 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No counter = torch.zeros(1, dtype=torch.int64, device=self._model.distributed.device) for loaded_shard_index, loaded_stage in enumerate(loaded_model.stages_on_device.values()): - loaded_shards = ( - loaded_shard_split[loaded_shard_index].to(self._model.distributed.device).unbind(0) + loaded_stage_shards = loaded_shard_split[loaded_shard_index].to( + self._model.distributed.device ) - for self_shard_index, self_stage in enumerate(self._model.stages_on_device.values()): - self_stage._copy_shard_overlaps( # noqa - loaded_stage, - self_shard_split[self_shard_index].unbind(0), - loaded_shards, - counter, - ) + for loaded_fsdp, loaded_fsdp_shards in zip( + loaded_stage.fsdps, + loaded_stage_shards.split(loaded_model._fsdp_shard_sizes[loaded_shard_index], 1), + strict=True, + ): + for self_shard_index, self_stage in enumerate(self._model.stages_on_device.values()): + self_stage_shards = self_shard_split[self_shard_index] + for self_fsdp, self_fsdp_shards in zip( + self_stage.fsdps, + self_stage_shards.split(self._model._fsdp_shard_sizes[self_shard_index], 1), + strict=True, + ): + self_fsdp._copy_shard_overlaps( # noqa + loaded_fsdp, + self_fsdp_shards.unbind(0), + loaded_fsdp_shards, + counter, + self._model.distributed.device, + ) context.mark_as_loaded(counter.item()) diff --git a/fast_llm/engine/checkpoint/safe_load.py b/fast_llm/engine/checkpoint/safe_load.py index 4cf9263b..8a7701a7 100644 --- a/fast_llm/engine/checkpoint/safe_load.py +++ b/fast_llm/engine/checkpoint/safe_load.py @@ -40,8 +40,13 @@ def __enter__(self) -> "SafeLoad": # Reset and count shard pads for shard in self._model.state_shard[: self._num_shards]: shard_split = shard.split(self._model.stage_shard_sizes, 0) - for stage, stage_shard in zip(self._model.stages_on_device.values(), shard_split): - self._loaded += stage.reset_shard_pad(stage_shard) + for shard_index, (stage, stage_shard) in enumerate( + zip(self._model.stages_on_device.values(), shard_split) + ): + for fsdp, fsdp_shard in zip( + stage.fsdps, stage_shard.split(self._model._fsdp_shard_sizes[shard_index]), strict=True + ): + self._loaded += fsdp.reset_shard_pad(fsdp_shard) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -92,30 +97,35 @@ def _check_missing(self, errors: list[str]) -> None: global_total, local_total = 0, 0 for shard_name, shard_ in zip(self._model.state_shard_names[: self._num_shards], self._self_shard): shard_split = shard_.split(self._model.stage_shard_sizes, 0) - for stage, shard in zip(self._model.stages_on_device.values(), shard_split): - buffer = stage._reconstruct_from_shard(shard) - for i, parameter in 
enumerate(stage._split_buffer(buffer)): - missing_for_param = parameter.isnan().sum().item() - if missing_for_param > 0: - global_total += missing_for_param - local_values = stage._split_shard(shard)[i] - local_missing_for_param = local_values.isnan().sum().item() - local_total += local_missing_for_param + for shard_index, (stage, stage_shard) in enumerate( + zip(self._model.stages_on_device.values(), shard_split) + ): + for fsdp, fsdp_shard in zip( + stage.fsdps, stage_shard.split(self._model._fsdp_shard_sizes[shard_index]), strict=True + ): + buffer = fsdp.reconstruct_from_shard(fsdp_shard) + for parameter_name, parameter in fsdp.split_buffer(buffer).items(): + missing_for_param = parameter.isnan().sum().item() + if missing_for_param > 0: + global_total += missing_for_param + local_values = fsdp.split_shard(fsdp_shard)[parameter_name] + local_missing_for_param = local_values.isnan().sum().item() + local_total += local_missing_for_param + errors.append( + f"{missing_for_param:,} values missing out of {parameter.numel():,} for parameter {parameter_name} in stage {stage.index}, shard {shard_name}" + f" (locally {local_missing_for_param:,} out of {local_values.numel():,})" + ) + missing_for_pad = buffer[-fsdp._global_pad :].isnan().sum().item() + if missing_for_pad > 0: + global_total += missing_for_pad + local_missing_for_pad = ( + fsdp_shard[-fsdp._shard_pad :].isnan().sum().item() if fsdp._shard_pad > 0 else 0 + ) + local_total += local_missing_for_pad errors.append( - f"{missing_for_param:,} values missing out of {parameter.numel():,} for parameter {stage.parameter_names[i]} in stage {stage.index}, shard {shard_name}" - f" (locally {local_missing_for_param:,} out of {local_values.numel():,})" + f"{missing_for_pad:,} values missing out of {fsdp._global_pad:,} for padding in stage {stage.index}, shard {shard_name}" + f" (locally {local_missing_for_pad:,} out of {fsdp._shard_pad:,})" ) - missing_for_pad = buffer[-stage._global_pad :].isnan().sum().item() - if missing_for_pad > 0: - global_total += missing_for_pad - local_missing_for_pad = ( - shard[-stage._shard_pad :].isnan().sum().item() if stage._shard_pad > 0 else 0 - ) - local_total += local_missing_for_pad - errors.append( - f"{missing_for_pad:,} values missing out of {stage._global_pad:,} for padding in stage {stage.index}, shard {shard_name}" - f" (locally {local_missing_for_pad:,} out of {stage._shard_pad:,})" - ) if global_total != global_missing: errors.append( f"Incorrect global breakdown of missing state entries (expected {global_missing:,}, got {global_total:,})" diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 6cc3269b..d26dd77a 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -341,7 +341,7 @@ def reduce_gradients(self, accumulate=False) -> None: assert self._mode.support_backward if not self._requires_grad: return - for buffer, meta in zip(self._parameter_buffers, self._parameter_metas.values()): + for buffer, meta in zip(self._parameter_buffers.values(), self._parameter_metas.values()): if buffer.param_grad_is_zero: # noqa assert self.is_tied_weight_copy or meta.allow_no_grad, meta triton_fill(buffer.grad_buffer, 0) # noqa diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index c215b21f..2782444a 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -270,26 +270,26 @@ def setup(self, distributed: Distributed, mode: StageMode = 
StageMode.training) shard_index = self._stage_shard_indices.get(stage_index) weight_buffer_index = self._weight_buffer_indices.get(stage_index) grad_buffer_index = self._grad_buffer_indices.get(stage_index) - weight_buffers = ( + stage_weight_buffers = ( weight_buffers[weight_buffer_index][: self._stage_weight_buffer_sizes[weight_buffer_index]].split( self._fsdp_weight_buffer_sizes[weight_buffer_index] ) if self._mode.support_forward and weight_buffer_index is not None else None ) - grad_buffers = ( + stage_grad_buffers = ( grad_buffers[grad_buffer_index][: self._stage_grad_buffer_sizes[grad_buffer_index]].split( self._fsdp_grad_buffer_sizes[grad_buffer_index] ) if self._mode.support_backward and grad_buffer_index is not None else None ) - weight_shards = ( + stage_weight_shards = ( weight_shard_split[shard_index].split(self._fsdp_shard_sizes[shard_index]) if self._mode.on_device and shard_index is not None else None # noqa ) - grad_shards = ( + stage_grad_shards = ( grad_shard_split[shard_index].split(self._fsdp_shard_sizes[shard_index]) # noqa if self._mode.support_backward and shard_index is not None else None @@ -301,10 +301,10 @@ def setup(self, distributed: Distributed, mode: StageMode = StageMode.training) ) stage.setup( distributed=distributed, - weight_shards=weight_shards, - grad_shards=grad_shards, - weight_buffers=weight_buffers, - grad_buffers=grad_buffers, + weight_shards=stage_weight_shards, + grad_shards=stage_grad_shards, + weight_buffers=stage_weight_buffers, + grad_buffers=stage_grad_buffers, mode=self._mode if stage_index in self._stages_on_device else StageMode.off_device, is_tied_weight_copy=stage_index in self._stages_on_device and stage_index not in self._stages_owned, weight_buffer_shared_with=weight_buffer_shared_with, diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 5bc42c3f..78b336c9 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -38,10 +38,10 @@ def setup( # noqa self, *, distributed: Distributed, - weight_shards: torch.Tensor | None, - grad_shards: torch.Tensor | None, - weight_buffers: torch.Tensor | None, - grad_buffers: torch.Tensor | None, + weight_shards: list[torch.Tensor | None] | None, + grad_shards: list[torch.Tensor | None] | None, + weight_buffers: list[torch.Tensor | None] | None, + grad_buffers: list[torch.Tensor | None] | None, mode: StageMode = StageMode.training, is_tied_weight_copy: bool = False, weight_buffer_shared_with: list["Stage"], @@ -153,7 +153,9 @@ def reduce_gradients(self, accumulate=False) -> None: fsdp.log_shard( name="gradient", shard=fsdp.grad_shard, + distributed=self._distributed, level=self._config.debug_all_param_gradients, + global_=self._config.debug_global_tensors, ) @property diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 35b2a224..277c3623 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -125,17 +125,27 @@ def setup( self, *, distributed: Distributed, - weight_shards: list[torch.Tensor | None], - grad_shards: list[torch.Tensor | None], - weight_buffers: list[torch.Tensor | None], - grad_buffers: list[torch.Tensor | None], + weight_shards: list[torch.Tensor | None] | None, + grad_shards: list[torch.Tensor | None] | None, + weight_buffers: list[torch.Tensor | None] | None, + grad_buffers: list[torch.Tensor | None] | None, mode: StageMode = StageMode.training, ) -> None: assert not self._is_setup assert 
distributed.config is self._distributed_config + self._mode = mode self._is_setup = True self._distributed = distributed + if weight_shards is None: + weight_shards = [None for _ in self._fsdps] + if grad_shards is None: + grad_shards = [None for _ in self._fsdps] + if weight_buffers is None: + weight_buffers = [None for _ in self._fsdps] + if grad_buffers is None: + grad_buffers = [None for _ in self._fsdps] + for fsdp, weight_shard, grad_shard, weight_buffer, grad_buffer in zip( self._fsdps, weight_shards, grad_shards, weight_buffers, grad_buffers, strict=True ): @@ -257,8 +267,8 @@ def get_param_groups( param_groups += [ param_group_cls( name=f"wd_{weight_decay}_lr_scale_{lr_scale}", # noqa - params=[self._weight_shard[slice_] for slice_ in slices], # noqa - grads=[self._grad_shard[slice_] for slice_ in slices], # noqa + params=[fsdp.weight_shard[slice_] for slice_ in slices], # noqa + grads=[fsdp.grad_shard[slice_] for slice_ in slices], # noqa **{ # noqa name: [optimizer_state[i][slice_] for slice_ in slices] for name, optimizer_state in optimizer_state_shards.items() @@ -299,14 +309,15 @@ def check_tensor_parallel_synchronization(self) -> None: ) def import_state_tensor( - self, parameter_name: str, shard: torch.Tensor, tensor: torch.Tensor | SafeTensorSlice + self, parameter_name: str, shards: tuple[torch.Tensor], tensor: torch.Tensor | SafeTensorSlice ) -> int: """ Given a global parameter tensor, set the associated slice of a local parameter shard. Return the size of the local slice. TODO: Doesn't work """ - return self._fsdps[self._fsdp_index[parameter_name]].import_state_tensor(parameter_name, shard, tensor) + fsdp_index = self._fsdp_index[parameter_name] + return self._fsdps[fsdp_index].import_state_tensor(parameter_name, shards[fsdp_index], tensor) def _export_shard( self, shards: tuple[torch.Tensor], data_type: DataType | None = None From 863bcf7062e1635feb86dfad7afcce953914ca40 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 13 Mar 2025 19:16:39 -0400 Subject: [PATCH 04/16] fixes --- fast_llm/engine/multi_stage/multi_stage.py | 8 ++++---- fast_llm/engine/multi_stage/stage_base.py | 11 +++++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 2782444a..8e56c5bb 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -271,15 +271,15 @@ def setup(self, distributed: Distributed, mode: StageMode = StageMode.training) weight_buffer_index = self._weight_buffer_indices.get(stage_index) grad_buffer_index = self._grad_buffer_indices.get(stage_index) stage_weight_buffers = ( - weight_buffers[weight_buffer_index][: self._stage_weight_buffer_sizes[weight_buffer_index]].split( - self._fsdp_weight_buffer_sizes[weight_buffer_index] + weight_buffers[weight_buffer_index][: self._stage_weight_buffer_sizes[stage_index]].split( + self._fsdp_weight_buffer_sizes[stage_index] ) if self._mode.support_forward and weight_buffer_index is not None else None ) stage_grad_buffers = ( - grad_buffers[grad_buffer_index][: self._stage_grad_buffer_sizes[grad_buffer_index]].split( - self._fsdp_grad_buffer_sizes[grad_buffer_index] + grad_buffers[grad_buffer_index][: self._stage_grad_buffer_sizes[stage_index]].split( + self._fsdp_grad_buffer_sizes[stage_index] ) if self._mode.support_backward and grad_buffer_index is not None else None diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 
277c3623..43d01eec 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -183,15 +183,18 @@ def initialize_weights(self) -> None: log_generator("PP init generator before reset", self._distributed.pp_init_generator) log_generator("TP init generator before reset", self._distributed.tp_init_generator) - if self._distributed_config.reproducible_init: - # Ensure a reproducible ordering. - sorted_metas = sorted(self._parameter_metas, key=lambda parameter_meta: parameter_meta.tensor_name) + # Ensure a reproducible ordering. + metas = ( + sorted(self._parameter_metas, key=lambda parameter_meta: parameter_meta.tensor_name) + if self._distributed_config.reproducible_init + else self._parameter_metas + ) weight_shards_split = [ fsdp.split_shard(fsdp.weight_shard if self._mode.on_device else fsdp.weight_shard_meta) for fsdp in self._fsdps ] - for meta in sorted_metas: + for meta in metas: fsdp = self._fsdps[fsdp_index := self._fsdp_index[meta.tensor_name]] parameter = weight_shards_split[fsdp_index][meta.tensor_name] # Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device) From 1926da1a78486c216b0e61a767708bbfe5ee938a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 13 Mar 2025 20:01:49 -0400 Subject: [PATCH 05/16] fix --- fast_llm/engine/multi_stage/fsdp.py | 6 +++--- tests/common.py | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index d26dd77a..b3cd4e2f 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -254,9 +254,9 @@ def split_shard(self, shard: torch.Tensor) -> dict[str, torch.Tensor]: # Split a shard into flat (possibly empty) parameter slices. 
return { name: shard[ - self.index_buffer_to_shard( - self.get_parameter_begin_in_buffer(name) - ) : self.get_parameter_end_in_buffer(name) + self.index_buffer_to_shard(self.get_parameter_begin_in_buffer(name)) : self.index_buffer_to_shard( + self.get_parameter_end_in_buffer(name) + ) ] for name in self._parameter_metas } diff --git a/tests/common.py b/tests/common.py index 6cec64e1..8b8e57c3 100644 --- a/tests/common.py +++ b/tests/common.py @@ -15,9 +15,9 @@ from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.models.gpt.config import ( LlamaGPTHuggingfaceCheckpointFormat, - Qwen2GPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, + Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) from fast_llm.tools.train import CliTrainingConfig @@ -62,8 +62,10 @@ f"model.multi_stage.debug_all_param_gradients={_LOG_LEVEL}", "model.multi_stage.debug_tensor_parallel=True", "model.distributed.reproducible_init=True", + "model.distributed.timeout=10", "training.train_iters=2", "training.num_workers=0", + "training.timeout=30", "batch.batch_size=8", "batch.sequence_length=512", "data.datasets.Training.type=slice", From 2e416b188e637d67f0c7acb4102c15275c19b4d0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 14 Mar 2025 14:21:28 -0400 Subject: [PATCH 06/16] fix --- fast_llm/engine/multi_stage/fsdp.py | 10 ++++++---- tests/test_checkpoint.py | 1 + 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index b3cd4e2f..00f76d67 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -333,7 +333,9 @@ def reset_gradients(self) -> None: assert buffer.grad is None buffer.param_grad_is_zero = True - def reduce_gradients(self, accumulate=False) -> None: + def reduce_gradients( + self, distributed: Distributed, accumulate: bool = False, allow_no_grad: bool = False + ) -> None: # Reduce the buffer, then copy (add) to actual grads. # Run in a separate cuda stream to allow communication overlap. 
# TODO: Allow partial FSDP @@ -343,10 +345,10 @@ def reduce_gradients(self, accumulate=False) -> None: return for buffer, meta in zip(self._parameter_buffers.values(), self._parameter_metas.values()): if buffer.param_grad_is_zero: # noqa - assert self.is_tied_weight_copy or meta.allow_no_grad, meta + assert allow_no_grad or meta.allow_no_grad, meta triton_fill(buffer.grad_buffer, 0) # noqa - if self._sequence_parallel_grads is not None and self._distributed.tensor_group: - all_reduce(self._sequence_parallel_grads, group=self._distributed.tensor_group) + if self._sequence_parallel_grads is not None and distributed.tensor_group: + all_reduce(self._sequence_parallel_grads, group=distributed.tensor_group) if self._fsdp_dim.size > 1: full_precision_gradients = self._grad_buffer_local_shard.dtype == self._grad_shard.dtype out = self._grad_shard if full_precision_gradients else self._grad_buffer_local_shard diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index ae3447f0..74b60c39 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -438,6 +438,7 @@ def test_load_pretrained_in_dp2_match_checkpoint(): assert (stage_shard_test[stage_shard_ref.numel() :] == 0).all() # noqa +@pytest.mark.skip(reason="Fails because of incorrect init config.") @pytest.mark.slow @pytest.mark.depends(on=["test_load_pretrained_in_dp2_match_checkpoint"]) def test_load_distributed_checkpoint_dp2(): From 811739ab23d0fff0895856685507d44596619a9e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 14 Mar 2025 14:59:29 -0400 Subject: [PATCH 07/16] fix --- fast_llm/engine/multi_stage/fsdp.py | 5 +++++ fast_llm/engine/multi_stage/stage.py | 5 ++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 00f76d67..692e2dbd 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -372,6 +372,11 @@ def _parameter_range_in_shard(self, parameter_name: str) -> tuple[int, int]: end = self.index_buffer_to_shard(self.get_parameter_end_in_buffer(parameter_name)) return begin, end + def invalidate_buffer(self) -> None: + # Buffer is no longer valid (Updated weights or overwritten by other stage) + assert self._mode.support_forward + self._is_restored = False + def parameter_global_to_shard( self, global_param: torch.Tensor | SafeTensorSlice, parameter_name: str ) -> torch.Tensor: diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 78b336c9..8d343125 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -141,7 +141,7 @@ def reduce_gradients(self, accumulate=False) -> None: assert self._is_restored assert self._mode.support_backward for fsdp in self._fsdps: - fsdp.reduce_gradients(accumulate) + fsdp.reduce_gradients(self._distributed, accumulate, self._is_tied_weight_copy) if self._config.debug_param_gradients: log_tensor( "Reduced gradient shard", @@ -174,6 +174,9 @@ def invalidate_buffer(self) -> None: # Buffer is no longer valid (Updated weights or overwritten by other stage) assert self._mode.support_forward self._is_restored = False + # TODO: Frozen weights fsdps may not be invalidated on weight update. 
+ for fsdp in self._fsdps: + fsdp.invalidate_buffer() def _log_layer_forward(self, output: torch.Tensor, kwargs: dict[str, typing.Any], i: int) -> None: if ( From 420bedce853333b497235cc1ac7bfbe1ee26bb9d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 14 Mar 2025 21:08:52 -0400 Subject: [PATCH 08/16] separate shard wip --- fast_llm/engine/checkpoint/distributed.py | 14 +- fast_llm/engine/checkpoint/safe_load.py | 4 +- fast_llm/engine/multi_stage/fast_llm_model.py | 2 +- fast_llm/engine/multi_stage/fsdp.py | 18 +- fast_llm/engine/multi_stage/multi_stage.py | 196 ++++++++++++------ fast_llm/engine/multi_stage/stage_base.py | 9 - tests/test_checkpoint.py | 12 +- 7 files changed, 157 insertions(+), 98 deletions(-) diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index 7c0b9a26..2d334cbd 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -45,9 +45,9 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No # TODO: More safety checks loaded_config_dict = config.to_copy({"load_config": ModelConfigType.fast_llm}) loaded_config = self._model.config_class.from_metadata(loaded_config_dict, metadata) - num_shards = self.get_num_shards(config) + # num_shards = self.get_num_shards(config) shard_names = self.get_shard_names(config) - Assert.eq(metadata.shards[:num_shards], list(shard_names)) + Assert.eq(metadata.shards[: len(shard_names)], list(shard_names)) same_format = ( loaded_config.to_serialized(verbose=None) == self._model.config.to_serialized(verbose=None) @@ -67,6 +67,8 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No ) as f: # TODO: Does this copy twice? self._model.state_shard[:num_shards].copy_(f.get_slice("state_shard")[:num_shards]) + # self._model.state_shard_sizes + else: log_main_rank("Checkpoint format doesn't match, using safe load") self._model.config.base_model.compare_architecture(loaded_config.base_model, config.compare_log_fn) @@ -98,14 +100,18 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No ) for loaded_fsdp, loaded_fsdp_shards in zip( loaded_stage.fsdps, - loaded_stage_shards.split(loaded_model._fsdp_shard_sizes[loaded_shard_index], 1), + loaded_stage_shards.split( + loaded_model._fsdp_weight_shard_sizes[loaded_shard_index], 1 + ), strict=True, ): for self_shard_index, self_stage in enumerate(self._model.stages_on_device.values()): self_stage_shards = self_shard_split[self_shard_index] for self_fsdp, self_fsdp_shards in zip( self_stage.fsdps, - self_stage_shards.split(self._model._fsdp_shard_sizes[self_shard_index], 1), + self_stage_shards.split( + self._model._fsdp_weight_shard_sizes[self_shard_index], 1 + ), strict=True, ): self_fsdp._copy_shard_overlaps( # noqa diff --git a/fast_llm/engine/checkpoint/safe_load.py b/fast_llm/engine/checkpoint/safe_load.py index 8a7701a7..10f25c6c 100644 --- a/fast_llm/engine/checkpoint/safe_load.py +++ b/fast_llm/engine/checkpoint/safe_load.py @@ -44,7 +44,7 @@ def __enter__(self) -> "SafeLoad": zip(self._model.stages_on_device.values(), shard_split) ): for fsdp, fsdp_shard in zip( - stage.fsdps, stage_shard.split(self._model._fsdp_shard_sizes[shard_index]), strict=True + stage.fsdps, stage_shard.split(self._model._fsdp_weight_shard_sizes[shard_index]), strict=True ): self._loaded += fsdp.reset_shard_pad(fsdp_shard) return self @@ -101,7 +101,7 @@ def _check_missing(self, errors: list[str]) -> None: 
zip(self._model.stages_on_device.values(), shard_split) ): for fsdp, fsdp_shard in zip( - stage.fsdps, stage_shard.split(self._model._fsdp_shard_sizes[shard_index]), strict=True + stage.fsdps, stage_shard.split(self._model._fsdp_weight_shard_sizes[shard_index]), strict=True ): buffer = fsdp.reconstruct_from_shard(fsdp_shard) for parameter_name, parameter in fsdp.split_buffer(buffer).items(): diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index b268ec29..e1af5ab8 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -98,7 +98,7 @@ def initialize_weights(self, timeout: float | None = None) -> None: def _finalize_load(self, reset_optimizer: bool = True) -> None: if reset_optimizer: - triton_fill(self._state_shard[1:], 0.0) + triton_fill(self._flat_shard[self._weight_shard_size :], 0.0) if self._mode.support_forward: self.invalidate_buffers() self._is_loaded = True diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 692e2dbd..ffd2d8b0 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -75,25 +75,27 @@ def __init__( ) # TODO: Use parallel_dim property instead? - shard_dim = TensorDim("flat_shard", (self._parameter_count + self._global_pad) // self._fsdp_dim.size) + weight_shard_dim = TensorDim("weight_shard", (self._parameter_count + self._global_pad) // self._fsdp_dim.size) + grad_shard_dim = TensorDim("grad_shard", weight_shard_dim.size if self._requires_grad else 0) self._weight_shard_meta = TensorMeta.from_dims( - (shard_dim,), + (weight_shard_dim,), tensor_name=f"{self._name}_weight_shard", dtype=self._optimization_dtype.torch, ) + # TODO: Distinguish grad and optimizer shard? 
self._grad_shard_meta = TensorMeta.from_dims( - (shard_dim,), + (grad_shard_dim,), tensor_name=f"{self._name}_grad_shard", dtype=self._optimization_dtype.torch, ) self._weight_buffer_meta = TensorMeta.from_dims( - (TensorDim("weight_buffer", shard_dim.size * self._fsdp_dim.size),), + (TensorDim("weight_buffer", weight_shard_dim.size * self._fsdp_dim.size),), tensor_name=f"{self._name}_weight_buffer", dtype=self._training_dtype.torch, ) self._grad_buffer_meta = TensorMeta.from_dims( - (TensorDim("grad_buffer", shard_dim.size * self._fsdp_dim.size if self._requires_grad else 0),), + (TensorDim("grad_buffer", weight_shard_dim.size * self._fsdp_dim.size if self._requires_grad else 0),), tensor_name=f"{self._name}_grad_buffer", dtype=self._gradient_buffer_dtype.torch, ) @@ -114,9 +116,9 @@ def requires_grad(self) -> bool: def weight_shard_meta(self) -> TensorMeta: return self._weight_shard_meta - # @property - # def grad_shard_meta(self) -> TensorMeta: - # return self._grad_shard_meta + @property + def grad_shard_meta(self) -> TensorMeta: + return self._grad_shard_meta @property def weight_buffer_meta(self) -> TensorMeta: diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 8e56c5bb..343cba45 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -23,14 +23,22 @@ logger = logging.getLogger(__name__) +class ShardName: + weights = "weights" + grads = "grads" + + class MultiStageModel[ConfigType: FastLLMModelConfig](Configurable[ConfigType]): config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig base_model_class: typing.ClassVar[type[BaseModel]] = BaseModel _is_setup: bool = False - _state_shard: torch.Tensor - _weight_shard: torch.Tensor - _grad_shard: torch.Tensor - _optimizer_shard: torch.Tensor + _flat_shard: torch.Tensor + _shards: dict[str, torch.Tensor] + _shards_names: tuple[str, ...] + # _state_shard: torch.Tensor + # _weight_shard: torch.Tensor + # _grad_shard: torch.Tensor + # _optimizer_shards: list[torch.Tensor] _distributed: Distributed _mode: StageMode @@ -38,6 +46,7 @@ def __init__( self, config: FastLLMModelConfig, *, + # TODO: No longer needed in __init__, move to setup? optimizer_state_names: tuple[str, ...] = (), verbose: bool = True, # A filter to create only a subset of the stages. Used for model conversion. 
@@ -48,6 +57,7 @@ def __init__( self._training = None self._verbose = verbose self._stage_filter = stage_filter + self._optimizer_state_names = optimizer_state_names stage_splits = self._split_into_stages() self._num_stages = len(stage_splits) - 1 @@ -119,36 +129,57 @@ def __init__( self._stage_shard_indices = { stage_index: shard_index for shard_index, stage_index in enumerate(self._stages_on_device) } - self._fsdp_shard_sizes = [ + + self._fsdp_weight_shard_sizes = [ [fsdp.weight_shard_meta.numel() for fsdp in stage.fsdps] for stage in self._stages_on_device.values() ] - self._stage_shard_sizes = [sum(shard_sizes) for shard_sizes in self._fsdp_shard_sizes] + self._stage_weight_shard_sizes = [sum(shard_sizes) for shard_sizes in self._fsdp_weight_shard_sizes] + self._weight_shard_size = sum(self._stage_weight_shard_sizes) + + self._fsdp_grad_shard_sizes = [ + [fsdp.grad_shard_meta.numel() for fsdp in stage.fsdps] for stage in self._stages_on_device.values() + ] + self._stage_grad_shard_sizes = [sum(shard_sizes) for shard_sizes in self._fsdp_grad_shard_sizes] + self._grad_shard_size = sum(self._stage_grad_shard_sizes) + # TODO: Support non-unique data type. - stage_shard_dtype = get_unique( + self._shard_dtype = get_unique( [fsdp.weight_shard_meta.dtype for stage in self._stages_on_device.values() for fsdp in stage.fsdps] ) - - self._state_shard_names = ("weights",) + optimizer_state_names - - shard_dim = TensorDim("flat_shard", sum(self._stage_shard_sizes)) - # TODO: Avoid unnecessary shards (frozen weights or shard identical to buffer) + # self._shard_names = (ShardName.weights,) + self._optimizer_state_names + (ShardName.grads,) self._weight_shard_meta = TensorMeta.from_dims( - (shard_dim,), + (TensorDim("weight_shard", self._weight_shard_size),), tensor_name=f"multi_stage_weight_shard", - dtype=stage_shard_dtype, - ) - # TODO !!!!!!: Remove for frozen weights. 
- self._state_shard_meta = TensorMeta.from_dims( - (TensorDim("state_shards", self.num_state_shards), shard_dim), - tensor_name=f"multi_stage_state_shard", - dtype=stage_shard_dtype, + dtype=self._shard_dtype, ) - self._full_shards_meta = TensorMeta.from_dims( - (TensorDim("shards", self.num_shards), shard_dim), - tensor_name=f"multi_stage_state_shard", - dtype=stage_shard_dtype, + self._grad_shard_meta = TensorMeta.from_dims( + (TensorDim("grad_shard", self._grad_shard_size),), + tensor_name=f"multi_stage_grad_shard", + dtype=self._shard_dtype, ) + # state_shard_sizes={"weights":self._weight_shard_size, **{optimizer_state_name:self._grad_shard_size for optimizer_state_name in optimizer_state_names}, "grads":self._grad_shard_size} + + # self._shard_sizes=(weight_shard_size,)+(grad_shard_size,)*self.num_state_shards + # self._state_shard_sizes=self._shard_sizes[:-1] + + # TODO: Avoid unnecessary shards (frozen weights or shard identical to buffer) + # self._weight_shard_meta = TensorMeta.from_dims( + # (TensorDim("weight_shard", weight_shard_size),), + # tensor_name=f"multi_stage_weight_shard", + # dtype=self._shard_dtype, + # ) + # self._state_shard_meta = TensorMeta.from_dims( + # (TensorDim("state_shards", sum(self._state_shard_sizes)),), + # tensor_name=f"multi_stage_state_shard", + # dtype=self._shard_dtype, + # ) + # self._full_shards_meta = TensorMeta.from_dims( + # (TensorDim("shards", sum(self._shard_sizes)),), + # tensor_name=f"multi_stage_state_shard", + # dtype=self._shard_dtype, + # ) + # contents: buffer_index -> set[stage_index] # indices: stage_index -> buffer_index @@ -235,31 +266,37 @@ def setup(self, distributed: Distributed, mode: StageMode = StageMode.training) self._grad_buffer_sizes ) + self._shards_names = () + if self._mode.on_device: + self._shards_names += (ShardName.weights,) + if self._mode.support_training: + self._shards_names += self._optimizer_state_names + if self._mode.support_backward: + self._shards_names += (ShardName.grads,) + if self._mode.on_device: - num_shards = ( - self._full_shards_meta.size(0) - if self._mode.support_training - else 2 if self._mode.support_backward else 1 + shard_sizes = [ + self._weight_shard_size if shard_name == ShardName.weights else self._grad_shard_size + for shard_name in self._shards_names + ] + full_shards_meta = TensorMeta.from_dims( + (TensorDim("", sum(shard_sizes)),), + tensor_name=f"", + dtype=self._shard_dtype, ) - allocated += (mem := num_shards * self._full_shards_meta.memory_usage // self._full_shards_meta.size(0)) + allocated += (mem := full_shards_meta.memory_usage) if self._verbose: log_model_parallel_main_rank( - f">>> Allocating {self.num_shards} x {len(self._stage_shard_sizes)}" - f" shards ({mem / 2 ** 20:,.2f} MiB)" + f">>> Allocating {len(self._shards_names)} shards ({mem / 2 ** 20:,.2f} MiB)" ) - shards = torch.empty_like(self._full_shards_meta[:num_shards], device=self._distributed.device) + self._flat_shard = torch.empty_like(full_shards_meta, device=self._distributed.device) if self._verbose: log_model_parallel_main_rank(f"Total allocated: {allocated / 2 ** 20:,.2f} MiB") - self._weight_shard = shards[0] - weight_shard_split = self._weight_shard.split(self._stage_shard_sizes) - if self._mode.support_backward: - self._state_shard = shards[:-1] - if self._mode.support_training: - self._optimizer_shard = shards[1:-1] - self._grad_shard = shards[-1] - grad_shard_split = self._grad_shard.split(self._stage_shard_sizes) - else: - self._state_shard = shards + + self._shards = { + shard_name: shard + for 
shard_name, shard in zip(self._shards_names, self._flat_shard.split(shard_sizes), strict=True) + } # Setup the tied parameter process groups for tied_parameter in self._tied_parameters.values(): @@ -271,26 +308,30 @@ def setup(self, distributed: Distributed, mode: StageMode = StageMode.training) weight_buffer_index = self._weight_buffer_indices.get(stage_index) grad_buffer_index = self._grad_buffer_indices.get(stage_index) stage_weight_buffers = ( - weight_buffers[weight_buffer_index][: self._stage_weight_buffer_sizes[stage_index]].split( + weight_buffers[weight_buffer_index][: self._stage_weight_buffer_sizes[stage_index]].split( # noqa self._fsdp_weight_buffer_sizes[stage_index] ) if self._mode.support_forward and weight_buffer_index is not None else None ) stage_grad_buffers = ( - grad_buffers[grad_buffer_index][: self._stage_grad_buffer_sizes[stage_index]].split( + grad_buffers[grad_buffer_index][: self._stage_grad_buffer_sizes[stage_index]].split( # noqa self._fsdp_grad_buffer_sizes[stage_index] ) if self._mode.support_backward and grad_buffer_index is not None else None ) stage_weight_shards = ( - weight_shard_split[shard_index].split(self._fsdp_shard_sizes[shard_index]) + self._shards[ShardName.weights] + .split(self._stage_weight_shard_sizes)[shard_index] + .split(self._fsdp_weight_shard_sizes[shard_index]) if self._mode.on_device and shard_index is not None else None # noqa ) stage_grad_shards = ( - grad_shard_split[shard_index].split(self._fsdp_shard_sizes[shard_index]) # noqa + self._shards[ShardName.grads] + .split(self._stage_weight_shard_sizes)[shard_index] + .split(self._fsdp_weight_shard_sizes[shard_index]) if self._mode.support_backward and shard_index is not None else None ) @@ -318,13 +359,24 @@ def get_param_groups( assert self._is_setup assert self._mode.support_training # Setup the optimizer param groups. 
- optimizer_shards_split = [shard.split(self._stage_shard_sizes) for shard in self._optimizer_shard.unbind()] + optimizer_shards_split = { + shard_name: self._shards[shard_name].split( + self._stage_weight_shard_sizes if shard_name == ShardName.weights else self._stage_grad_shard_sizes + ) + for shard_name in self._optimizer_state_names + } param_groups, grads_for_norm = [], [] for stage_index, stage in self._stages_on_device.items(): shard_index = self._stage_shard_indices.get(stage_index) stage_optimizer_shards = { - name: shard_split[shard_index].split(self._fsdp_shard_sizes[shard_index]) - for name, shard_split in zip(self._state_shard_names[1:], optimizer_shards_split) + shard_name: shard_split[shard_index].split( + ( + self._fsdp_weight_shard_sizes + if shard_name == ShardName.weights + else self._stage_grad_shard_sizes + )[shard_index] + ) + for shard_name, shard_split in optimizer_shards_split.items() } stage_param_groups, stage_grads_for_norm = stage.get_param_groups( stage_optimizer_shards, @@ -337,9 +389,15 @@ def get_param_groups( return param_groups, grads_for_norm - @property - def state_shard_meta(self) -> TensorMeta: - return self._state_shard_meta + # @property + # def state_shard_meta(self) -> TensorMeta: + # return self._state_shard_meta + + def get_shard_meta(self, name: str) -> TensorMeta: + assert self._is_setup + if name not in self._shards_names: + raise KeyError(f"Unknown shard name {name}") + return self._weight_shard_meta if name == ShardName.weights else self._grad_shard_meta @property def support_forward(self) -> bool: @@ -364,17 +422,17 @@ def base_model(self) -> BaseModel: def stages(self) -> list[Stage]: return self._stages - @property - def state_shard(self) -> torch.Tensor: - return self._state_shard + # @property + # def state_shards(self) -> tuple[torch.Tensor]: + # return self._shards[:-1] if self._mode.support_backward else self._shards - @property - def num_shards(self) -> int: - return len(self._state_shard_names) + 1 + # @property + # def num_shards(self) -> int: + # return len(self._state_shard_names) + 1 - @property - def num_state_shards(self) -> int: - return len(self._state_shard_names) + # @property + # def num_state_shards(self) -> int: + # return len(self._state_shard_names) @property def stages_on_device(self) -> dict[int, Stage]: @@ -394,11 +452,11 @@ def grad_buffer_indices(self) -> dict[int, int]: @property def state_shard_names(self) -> tuple[str, ...]: - return self._state_shard_names + return self._shards_names[:-1] if self._mode.support_backward else self._shards_names @property def stage_shard_sizes(self) -> list[int]: - return self._stage_shard_sizes + return self._stage_weight_shard_sizes @property def parameter_names(self) -> list[str]: @@ -427,13 +485,13 @@ def train(self, mode: bool = True) -> None: def get_state_tensor_iterator( self, shard_names: list[str], data_type: DataType | None = None ) -> typing.Generator[tuple[str, str, torch.Tensor], None, None]: - for i, shard_name in enumerate(shard_names): - shard_split = self._state_shard[i].split(self._stage_shard_sizes, 0) + for shard_name in shard_names: + shard_split = self._shards[shard_name].split(self._stage_weight_shard_sizes, 0) for shard_index, (stage, shard) in enumerate( zip(self._stages_on_device.values(), shard_split, strict=True) ): for name, tensor in stage._export_shard( - shard.split(self._fsdp_shard_sizes[shard_index]), data_type=data_type + shard.split(self._fsdp_weight_shard_sizes[shard_index]), data_type=data_type ): # noqa yield name, shard_name, tensor @@ 
-447,9 +505,9 @@ def import_state_tensor(self, parameter_name: str, shard_name: str, tensor: torc return 0 shard_index = self._stage_shard_indices[self._parameter_stages[parameter_name]] stage_shards = ( - self._state_shard[self._state_shard_names.index(shard_name)] - .split(self._stage_shard_sizes, 0)[shard_index] - .split(self._fsdp_shard_sizes[shard_index]) + self._shards[shard_name] + .split(self._stage_weight_shard_sizes, 0)[shard_index] + .split(self._fsdp_weight_shard_sizes[shard_index]) ) return self.get_parameter_stage(parameter_name).import_state_tensor(parameter_name, stage_shards, tensor) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 43d01eec..4eb70f62 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -25,15 +25,6 @@ class StageBase(Configurable[StageConfig]): _distributed: Distributed _mode: StageMode - # _weight_shard: torch.Tensor - # _grad_shard: torch.Tensor - # _weight_buffer: torch.Tensor - # _grad_buffer: torch.Tensor - # _sequence_parallel_grads: torch.Tensor - # _weight_buffer_local_shard: torch.Tensor - # _grad_buffer_local_shard: torch.Tensor - # _parameter_buffers: list[torch.nn.Parameter] - def __init__( self, *, diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 74b60c39..b2bf2504 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -419,16 +419,18 @@ def test_load_pretrained_in_dp2_match_checkpoint(): ref_model = TEST_MODEL_CLS(config_ref) test_model = TEST_MODEL_CLS(config_test) - weight_shard_ref_split = shard_ref[0].split(ref_model._stage_shard_sizes) - weight_shards_test_split = [shard_test[0].split(test_model._stage_shard_sizes) for shard_test in shards_test] + weight_shard_ref_split = shard_ref[0].split(ref_model._stage_weight_shard_sizes) + weight_shards_test_split = [ + shard_test[0].split(test_model._stage_weight_shard_sizes) for shard_test in shards_test + ] for shard_test in shards_test: assert (shard_test[1:] == 0).all() # noqa - assert len(ref_model._stage_shard_sizes) == len(test_model._stage_shard_sizes) + assert len(ref_model._stage_weight_shard_sizes) == len(test_model._stage_weight_shard_sizes) for i, stage_shard_ref in enumerate(weight_shard_ref_split): assert ( - test_model._stage_shard_sizes[i] - == ref_model._stage_shard_sizes[i] // 2 + (-ref_model._stage_shard_sizes[i] // 2) % 32 + test_model._stage_weight_shard_sizes[i] + == ref_model._stage_weight_shard_sizes[i] // 2 + (-ref_model._stage_weight_shard_sizes[i] // 2) % 32 ) stage_shard_test = torch.concatenate( From 9313c88dfe6b4fd65e468c2a4eaa5f3c0be5d9e6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 17 Mar 2025 21:41:22 -0400 Subject: [PATCH 09/16] separate shards --- fast_llm/engine/checkpoint/config.py | 3 - fast_llm/engine/checkpoint/distributed.py | 80 ++++----- fast_llm/engine/checkpoint/safe_load.py | 85 +++++----- fast_llm/engine/checkpoint/state_dict.py | 4 +- fast_llm/engine/multi_stage/fsdp.py | 11 +- fast_llm/engine/multi_stage/multi_stage.py | 184 ++++++++++----------- tests/test_checkpoint.py | 47 +++--- 7 files changed, 197 insertions(+), 217 deletions(-) diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 92f1165d..f0a9aeea 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -251,8 +251,5 @@ def save(self, config: CheckpointSaveConfig, metadata: "CheckpointMetadata"): def load(self, config: CheckpointLoadConfig, metadata: 
"CheckpointMetadata"): pass - def get_num_shards(self, config: CheckpointStateConfigBase) -> int: - return len(self._model.state_shard_names) if config.optimizer_state else 1 - def get_shard_names(self, config: CheckpointStateConfigBase) -> tuple[str, ...]: return self._model.state_shard_names if config.optimizer_state else self._model.state_shard_names[:1] diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index 2d334cbd..f0bb900e 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -36,7 +36,7 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No if self._model.config.distributed.rank == 0: yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w")) safetensors.torch.save_file( - tensors={"state_shard": self._model.state_shard[: self.get_num_shards(config)]}, + tensors={f"{shard_name}_shard": self._model.get_shard(shard_name) for shard_name in metadata.shards}, filename=config.path / f"rank_{self._model.config.distributed.rank}.safetensors", metadata=export_safetensors_metadata(serialized_metadata), ) @@ -45,8 +45,9 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No # TODO: More safety checks loaded_config_dict = config.to_copy({"load_config": ModelConfigType.fast_llm}) loaded_config = self._model.config_class.from_metadata(loaded_config_dict, metadata) - # num_shards = self.get_num_shards(config) shard_names = self.get_shard_names(config) + # Make sure all shards to load are in the checkpoint. + Assert.leq(set(self.get_shard_names(config)), set(metadata.shards)) Assert.eq(metadata.shards[: len(shard_names)], list(shard_names)) same_format = ( @@ -65,14 +66,22 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No framework="pt", device=str(self._model.distributed.device), ) as f: - # TODO: Does this copy twice? - self._model.state_shard[:num_shards].copy_(f.get_slice("state_shard")[:num_shards]) - # self._model.state_shard_sizes + if "state_shard" in f.keys(): + # Old format `state_shard` with shape `(num_shards, shard_size) + # TODO v0.3: Use checkpoint version? Drop support? + for shard_name in shard_names: + self._model.get_shard(shard_name).copy_( + f.get_slice("state_shard")[metadata.shards.index(shard_name)] + ) + else: + # TODO: Does this copy twice? + for shard_name in shard_names: + self._model.get_shard(shard_name).copy_(f.get_tensor(f"{shard_name}_shard")) else: log_main_rank("Checkpoint format doesn't match, using safe load") self._model.config.base_model.compare_architecture(loaded_config.base_model, config.compare_log_fn) - with SafeLoad(self._model, num_shards=num_shards, timeout=config.timeout) as context: + with SafeLoad(self._model, shard_names=shard_names, timeout=config.timeout) as context: for rank in range(loaded_config.distributed.world_size): loaded_model = self._model.__class__( loaded_config.to_copy({("distributed", "rank"): rank}), @@ -84,41 +93,32 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No # TODO: skip shards without overlap. with safetensors.safe_open(path, framework="pt", device=str(self._model.distributed.device)) as f: # TODO: Use self_shard - loaded_shard = f.get_slice("state_shard")[:num_shards] - loaded_model.state_shard_meta.validate(loaded_shard) + if "state_shard" in f.keys(): + # Old format `state_shard` with shape `(num_shards, shard_size) + # TODO v0.3: Use checkpoint version? Drop support? 
+ loaded_shards = { + shard_name: f.get_slice("state_shard")[metadata.shards.index(shard_name)] + for shard_name in shard_names + } + else: + loaded_shards = { + shard_name: f.get_tensor(f"{shard_name}_shard") for shard_name in shard_names + } - # TODO: Improve num shard selection. - self_shard_split = self._model.state_shard[: loaded_shard.size(0)].split( - self._model.stage_shard_sizes, 1 - ) - loaded_shard_split = loaded_shard.split(loaded_model.stage_shard_sizes, 1) + for shard_name, loaded_shard in loaded_shards.items(): + loaded_model.get_shard_meta(shard_name).validate(loaded_shard) + + self_shards = {shard_name: self._model.get_shard(shard_name) for shard_name in shard_names} counter = torch.zeros(1, dtype=torch.int64, device=self._model.distributed.device) - for loaded_shard_index, loaded_stage in enumerate(loaded_model.stages_on_device.values()): - loaded_stage_shards = loaded_shard_split[loaded_shard_index].to( - self._model.distributed.device - ) - for loaded_fsdp, loaded_fsdp_shards in zip( - loaded_stage.fsdps, - loaded_stage_shards.split( - loaded_model._fsdp_weight_shard_sizes[loaded_shard_index], 1 - ), - strict=True, - ): - for self_shard_index, self_stage in enumerate(self._model.stages_on_device.values()): - self_stage_shards = self_shard_split[self_shard_index] - for self_fsdp, self_fsdp_shards in zip( - self_stage.fsdps, - self_stage_shards.split( - self._model._fsdp_weight_shard_sizes[self_shard_index], 1 - ), - strict=True, - ): - self_fsdp._copy_shard_overlaps( # noqa - loaded_fsdp, - self_fsdp_shards.unbind(0), - loaded_fsdp_shards, - counter, - self._model.distributed.device, - ) + for _, loaded_fsdp, loaded_fsdp_shards in loaded_model.split_shards_by_fsdp(loaded_shards): + for _, self_fsdp, self_fsdp_shards in self._model.split_shards_by_fsdp(self_shards): + self_fsdp.copy_shard_overlaps( + loaded_fsdp, + self_fsdp_shards, + loaded_fsdp_shards, + counter, + self._model.distributed.device, + ) + context.mark_as_loaded(counter.item()) diff --git a/fast_llm/engine/checkpoint/safe_load.py b/fast_llm/engine/checkpoint/safe_load.py index 10f25c6c..2eec57e0 100644 --- a/fast_llm/engine/checkpoint/safe_load.py +++ b/fast_llm/engine/checkpoint/safe_load.py @@ -24,11 +24,11 @@ class SafeLoad: In case of failure, it will attempt to find out as precisely as possible where the problem comes from. """ - def __init__(self, model: "FastLLMModel", *, num_shards: int, timeout: float | None = None): + def __init__(self, model: "FastLLMModel", *, shard_names: tuple[str, ...], timeout: float | None = None): self._model = model self._distributed = self._model.distributed - self._num_shards = num_shards - self._self_shard = self._model.state_shard[: self._num_shards] + # self._num_shards = num_shards + self._self_shards = {shard_name: self._model.get_shard(shard_name) for shard_name in shard_names} self._timeout = timeout def __enter__(self) -> "SafeLoad": @@ -36,17 +36,12 @@ def __enter__(self) -> "SafeLoad": self._loaded_parameters = {} # Track the number of loaded entries. # Use nan to mark non-loaded entries. 
- triton_fill(self._self_shard, math.nan) + for self_shard in self._self_shards.values(): + triton_fill(self_shard, math.nan) # Reset and count shard pads - for shard in self._model.state_shard[: self._num_shards]: - shard_split = shard.split(self._model.stage_shard_sizes, 0) - for shard_index, (stage, stage_shard) in enumerate( - zip(self._model.stages_on_device.values(), shard_split) - ): - for fsdp, fsdp_shard in zip( - stage.fsdps, stage_shard.split(self._model._fsdp_weight_shard_sizes[shard_index]), strict=True - ): - self._loaded += fsdp.reset_shard_pad(fsdp_shard) + for _, fsdp, fsdp_shards in self._model.split_shards_by_fsdp(self._self_shards): + for fsdp_shard in fsdp_shards.values(): + self._loaded += fsdp.reset_shard_pad(fsdp_shard) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -75,18 +70,19 @@ def _validate(self) -> None: logger.info(f"{self._loaded:,} state entries loaded successfully") def _check_counter(self, errors: list[str]) -> None: - to_load = self._self_shard.numel() + to_load = sum(self_shard.numel() for self_shard in self._self_shards.values()) if self._loaded != to_load: # Ensure the right amount of weights is loaded. errors.append(f"Loaded a total of {self._loaded:,}, state entries, expected {to_load:,}") def _check_missing(self, errors: list[str]) -> None: # Ensure the loaded weights have a 1-1 mapping by looking for nans. - missing = self._self_shard.new_zeros([], dtype=torch.int64) + missing = torch.zeros([], dtype=torch.int64, device=self._distributed.device) # Count nans in slices of 100M parameters to limit memory usage. # TODO: Find better solution (triton kernel?) - for shard_slice in self._self_shard.flatten().split(100000000): - missing += shard_slice.isnan().sum() + for shard in self._self_shards.values(): + for shard_slice in shard.flatten().split(100000000): + missing += shard_slice.isnan().sum() local_missing = missing.item() if self._distributed.world_group is not None: all_reduce(missing, group=self._distributed.world_group) @@ -95,37 +91,32 @@ def _check_missing(self, errors: list[str]) -> None: errors.append(f"{global_missing:,} state entries failed to load or corrupted (local={local_missing:,}).") # Determine where the missing values are coming from. 
global_total, local_total = 0, 0 - for shard_name, shard_ in zip(self._model.state_shard_names[: self._num_shards], self._self_shard): - shard_split = shard_.split(self._model.stage_shard_sizes, 0) - for shard_index, (stage, stage_shard) in enumerate( - zip(self._model.stages_on_device.values(), shard_split) - ): - for fsdp, fsdp_shard in zip( - stage.fsdps, stage_shard.split(self._model._fsdp_weight_shard_sizes[shard_index]), strict=True - ): - buffer = fsdp.reconstruct_from_shard(fsdp_shard) - for parameter_name, parameter in fsdp.split_buffer(buffer).items(): - missing_for_param = parameter.isnan().sum().item() - if missing_for_param > 0: - global_total += missing_for_param - local_values = fsdp.split_shard(fsdp_shard)[parameter_name] - local_missing_for_param = local_values.isnan().sum().item() - local_total += local_missing_for_param - errors.append( - f"{missing_for_param:,} values missing out of {parameter.numel():,} for parameter {parameter_name} in stage {stage.index}, shard {shard_name}" - f" (locally {local_missing_for_param:,} out of {local_values.numel():,})" - ) - missing_for_pad = buffer[-fsdp._global_pad :].isnan().sum().item() - if missing_for_pad > 0: - global_total += missing_for_pad - local_missing_for_pad = ( - fsdp_shard[-fsdp._shard_pad :].isnan().sum().item() if fsdp._shard_pad > 0 else 0 - ) - local_total += local_missing_for_pad + for stage, fsdp, fsdp_shards in self._model.split_shards_by_fsdp(self._self_shards): + for shard_name, fsdp_shard in fsdp_shards.items(): + buffer = fsdp.reconstruct_from_shard(fsdp_shard) + for parameter_name, parameter in fsdp.split_buffer(buffer).items(): + missing_for_param = parameter.isnan().sum().item() + if missing_for_param > 0: + global_total += missing_for_param + local_values = fsdp.split_shard(fsdp_shard)[parameter_name] + local_missing_for_param = local_values.isnan().sum().item() + local_total += local_missing_for_param errors.append( - f"{missing_for_pad:,} values missing out of {fsdp._global_pad:,} for padding in stage {stage.index}, shard {shard_name}" - f" (locally {local_missing_for_pad:,} out of {fsdp._shard_pad:,})" + f"{missing_for_param:,} values missing out of {parameter.numel():,} for parameter {parameter_name} in stage {stage.index}, shard {shard_name}" + f" (locally {local_missing_for_param:,} out of {local_values.numel():,})" ) + missing_for_pad = buffer[-fsdp._global_pad :].isnan().sum().item() + if missing_for_pad > 0: + global_total += missing_for_pad + local_missing_for_pad = ( + fsdp_shard[-fsdp._shard_pad :].isnan().sum().item() if fsdp._shard_pad > 0 else 0 + ) + local_total += local_missing_for_pad + errors.append( + f"{missing_for_pad:,} values missing out of {fsdp._global_pad:,} for padding in stage {stage.index}, shard {shard_name}" + f" (locally {local_missing_for_pad:,} out of {fsdp._shard_pad:,})" + ) + if global_total != global_missing: errors.append( f"Incorrect global breakdown of missing state entries (expected {global_missing:,}, got {global_total:,})" @@ -137,7 +128,7 @@ def _check_missing(self, errors: list[str]) -> None: def _check_parameters(self, errors: list[str]) -> None: loaded_shard_names = set(self._loaded_parameters) - shard_names = set(self._model.state_shard_names[: self._num_shards]) + shard_names = set(self._self_shards) if loaded_shard_names != shard_names: errors.append(f"Incorrect loaded shards: {loaded_shard_names}!={shard_names}") for shard_name in shard_names & loaded_shard_names: diff --git a/fast_llm/engine/checkpoint/state_dict.py 
b/fast_llm/engine/checkpoint/state_dict.py index 5d2e913c..f9cfe237 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -72,7 +72,7 @@ def _serialize_metadata( return metadata.to_serialized() def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: - with SafeLoad(self._model, num_shards=self.get_num_shards(config), timeout=config.timeout) as context: + with SafeLoad(self._model, shard_names=self.get_shard_names(config), timeout=config.timeout) as context: # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from # `state_dict` that are ready for conversion, # and return a dict containing the converted tensors(s). @@ -145,7 +145,7 @@ def _load_weights( ) -> typing.Iterator[tuple[str, str, torch.Tensor | SafeTensorSlice]]: metadata = self.load_metadata(config) shard_names = self.get_shard_names(config) - Assert.eq(metadata.shards[: self.get_num_shards(config)], list(shard_names)) + Assert.leq(set(shard_names), set(metadata.shards)) for file_name in set(metadata.metadata["state_index"].values()): logger.info(f"Loading from {config.path / file_name}") with safetensors.safe_open( diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index ffd2d8b0..7c8344ef 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -413,11 +413,11 @@ def _get_parameter_shard_indices_in_full_weight(self, parameter_name: str, devic ) return index - def _copy_shard_overlaps( + def copy_shard_overlaps( self, loaded_fsdp: "FSDP", - shards: list[torch.Tensor], - loaded_shards: list[torch.Tensor], + shards: dict[str, torch.Tensor], + loaded_shards: dict[str, torch.Tensor], counter: torch.Tensor, device: torch.device, ) -> None: @@ -425,6 +425,7 @@ def _copy_shard_overlaps( See MultiStage._load_partial. TODO: Not intended to work with frozen weights, need to enforce. """ + Assert.eq(set(shards), set(loaded_shards)) index_overlap = [name for name in loaded_fsdp._parameter_metas if name in self._parameter_metas] for name in index_overlap: overlap_index_map = self.parameter_global_to_shard( @@ -435,6 +436,6 @@ def _copy_shard_overlaps( overlap_count = overlap_mask.sum() begin, end = self._parameter_range_in_shard(name) - for shard, loaded_shard in zip(shards, loaded_shards): - shard[begin:end][overlap_mask] = loaded_shard[overlap_index_map_masked] + for shard_name, shard in shards.items(): + shard[begin:end][overlap_mask] = loaded_shards[shard_name][overlap_index_map_masked] counter += overlap_count diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 343cba45..c38a46b0 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -15,6 +15,7 @@ from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode +from fast_llm.engine.multi_stage.fsdp import FSDP from fast_llm.engine.multi_stage.stage import Stage from fast_llm.engine.optimizer.config import ParamGroup from fast_llm.tensor import ParameterMeta, SafeTensorSlice, TensorMeta @@ -34,11 +35,7 @@ class MultiStageModel[ConfigType: FastLLMModelConfig](Configurable[ConfigType]): _is_setup: bool = False _flat_shard: torch.Tensor _shards: dict[str, torch.Tensor] - _shards_names: tuple[str, ...] 
- # _state_shard: torch.Tensor - # _weight_shard: torch.Tensor - # _grad_shard: torch.Tensor - # _optimizer_shards: list[torch.Tensor] + _shard_names: tuple[str, ...] _distributed: Distributed _mode: StageMode @@ -58,6 +55,7 @@ def __init__( self._verbose = verbose self._stage_filter = stage_filter self._optimizer_state_names = optimizer_state_names + self._all_shard_names = (ShardName.weights, *optimizer_state_names, ShardName.grads) stage_splits = self._split_into_stages() self._num_stages = len(stage_splits) - 1 @@ -158,31 +156,6 @@ def __init__( dtype=self._shard_dtype, ) - # state_shard_sizes={"weights":self._weight_shard_size, **{optimizer_state_name:self._grad_shard_size for optimizer_state_name in optimizer_state_names}, "grads":self._grad_shard_size} - - # self._shard_sizes=(weight_shard_size,)+(grad_shard_size,)*self.num_state_shards - # self._state_shard_sizes=self._shard_sizes[:-1] - - # TODO: Avoid unnecessary shards (frozen weights or shard identical to buffer) - # self._weight_shard_meta = TensorMeta.from_dims( - # (TensorDim("weight_shard", weight_shard_size),), - # tensor_name=f"multi_stage_weight_shard", - # dtype=self._shard_dtype, - # ) - # self._state_shard_meta = TensorMeta.from_dims( - # (TensorDim("state_shards", sum(self._state_shard_sizes)),), - # tensor_name=f"multi_stage_state_shard", - # dtype=self._shard_dtype, - # ) - # self._full_shards_meta = TensorMeta.from_dims( - # (TensorDim("shards", sum(self._shard_sizes)),), - # tensor_name=f"multi_stage_state_shard", - # dtype=self._shard_dtype, - # ) - - # contents: buffer_index -> set[stage_index] - # indices: stage_index -> buffer_index - # Pre-compute buffer specs. # TODO: Reduce code duplication. self._weight_buffer_contents, self._weight_buffer_indices = self._get_buffer_placement( @@ -190,6 +163,7 @@ def __init__( ) if self._verbose: log_model_parallel_main_rank(f"Weight buffer placement:\n{self._weight_buffer_indices}") + # TODO: Let stages worry about their buffer splitting? self._fsdp_weight_buffer_sizes = [ [fsdp.weight_buffer_meta.numel() for fsdp in stage.fsdps] for stage in self._stages ] @@ -241,68 +215,77 @@ def setup(self, distributed: Distributed, mode: StageMode = StageMode.training) self._is_setup = True self._distributed = distributed self._mode = mode - self._base_model.setup(distributed) allocated = 0 # Allocate and split shards and buffers. 
if self._mode.support_forward: - allocated += (mem := self._weight_buffer_meta.memory_usage) - if self._verbose: - log_model_parallel_main_rank( - f">>> Allocating {len(self._weight_buffer_sizes)} weight buffers ({mem / 2 ** 20:,.2f} MiB)" - ) - weight_buffers = torch.empty_like(self._weight_buffer_meta, device=self._distributed.device).split( - self._weight_buffer_sizes - ) + weight_buffers, mem = self._allocate_buffers(self._weight_buffer_meta, self._weight_buffer_sizes, "weight") + allocated += mem + else: + weight_buffers = None if self._mode.support_backward: - allocated += (mem := self._grad_buffer_meta.memory_usage) - if self._verbose: - log_model_parallel_main_rank( - f">>> Allocating {len(self._grad_buffer_sizes)} grad buffers ({mem / 2 ** 20:,.2f} MiB)" - ) - grad_buffers = torch.empty_like(self._grad_buffer_meta, device=self._distributed.device).split( - self._grad_buffer_sizes - ) + grad_buffers, mem = self._allocate_buffers(self._grad_buffer_meta, self._grad_buffer_sizes, "grad") + allocated += mem + else: + grad_buffers = None - self._shards_names = () + self._shard_names = () if self._mode.on_device: - self._shards_names += (ShardName.weights,) + self._shard_names += (ShardName.weights,) if self._mode.support_training: - self._shards_names += self._optimizer_state_names + self._shard_names += self._optimizer_state_names if self._mode.support_backward: - self._shards_names += (ShardName.grads,) + self._shard_names += (ShardName.grads,) if self._mode.on_device: - shard_sizes = [ - self._weight_shard_size if shard_name == ShardName.weights else self._grad_shard_size - for shard_name in self._shards_names - ] - full_shards_meta = TensorMeta.from_dims( - (TensorDim("", sum(shard_sizes)),), - tensor_name=f"", - dtype=self._shard_dtype, - ) - allocated += (mem := full_shards_meta.memory_usage) - if self._verbose: - log_model_parallel_main_rank( - f">>> Allocating {len(self._shards_names)} shards ({mem / 2 ** 20:,.2f} MiB)" - ) - self._flat_shard = torch.empty_like(full_shards_meta, device=self._distributed.device) - if self._verbose: - log_model_parallel_main_rank(f"Total allocated: {allocated / 2 ** 20:,.2f} MiB") + allocated += self._allocate_shards() - self._shards = { - shard_name: shard - for shard_name, shard in zip(self._shards_names, self._flat_shard.split(shard_sizes), strict=True) - } + if self._verbose: + log_model_parallel_main_rank(f"Total allocated: {allocated / 2 ** 20:,.2f} MiB") # Setup the tied parameter process groups for tied_parameter in self._tied_parameters.values(): tied_parameter.setup(self._distributed) # Setup the layer shards and buffers. 
+ self._setup_stages(weight_buffers, grad_buffers) + + self.train(self._mode.support_backward) + + def _allocate_buffers( + self, buffer_meta: TensorMeta, sizes: list[int], name: str + ) -> tuple[tuple[torch.Tensor, ...], int]: + mem = buffer_meta.memory_usage + if self._verbose: + log_model_parallel_main_rank(f">>> Allocating {len(sizes)} {name} buffers ({mem / 2 ** 20:,.2f} MiB)") + return torch.empty_like(buffer_meta, device=self._distributed.device).split(sizes), mem + + def _allocate_shards(self) -> int: + shard_sizes = [ + self._weight_shard_size if shard_name == ShardName.weights else self._grad_shard_size + for shard_name in self._shard_names + ] + full_shards_meta = TensorMeta.from_dims( + (TensorDim("", sum(shard_sizes)),), + tensor_name=f"", + dtype=self._shard_dtype, + ) + mem = full_shards_meta.memory_usage + if self._verbose: + log_model_parallel_main_rank(f">>> Allocating {len(self._shard_names)} shards ({mem / 2 ** 20:,.2f} MiB)") + self._flat_shard = torch.empty_like(full_shards_meta, device=self._distributed.device) + + self._shards = { + shard_name: shard + for shard_name, shard in zip(self._shard_names, self._flat_shard.split(shard_sizes), strict=True) + } + return mem + + def _setup_stages( + self, weight_buffers: tuple[torch.Tensor, ...] | None, grad_buffers: tuple[torch.Tensor, ...] | None + ) -> None: for stage_index, stage in enumerate(self._stages): shard_index = self._stage_shard_indices.get(stage_index) weight_buffer_index = self._weight_buffer_indices.get(stage_index) @@ -341,7 +324,7 @@ def setup(self, distributed: Distributed, mode: StageMode = StageMode.training) else [] ) stage.setup( - distributed=distributed, + distributed=self._distributed, weight_shards=stage_weight_shards, grad_shards=stage_grad_shards, weight_buffers=stage_weight_buffers, @@ -351,8 +334,6 @@ def setup(self, distributed: Distributed, mode: StageMode = StageMode.training) weight_buffer_shared_with=weight_buffer_shared_with, ) - self.train(self._mode.support_backward) - def get_param_groups( self, param_group_cls: type[ParamGroup] = ParamGroup ) -> tuple[list[ParamGroup], list[torch.Tensor]]: @@ -389,16 +370,17 @@ def get_param_groups( return param_groups, grads_for_norm - # @property - # def state_shard_meta(self) -> TensorMeta: - # return self._state_shard_meta - def get_shard_meta(self, name: str) -> TensorMeta: - assert self._is_setup - if name not in self._shards_names: + if name not in self._all_shard_names: raise KeyError(f"Unknown shard name {name}") return self._weight_shard_meta if name == ShardName.weights else self._grad_shard_meta + def get_shard(self, name: str) -> torch.Tensor: + assert self._is_setup + if name not in self._shard_names: + raise KeyError(f"Unknown shard name {name}") + return self._shards[name] + @property def support_forward(self) -> bool: assert self._is_setup @@ -422,18 +404,6 @@ def base_model(self) -> BaseModel: def stages(self) -> list[Stage]: return self._stages - # @property - # def state_shards(self) -> tuple[torch.Tensor]: - # return self._shards[:-1] if self._mode.support_backward else self._shards - - # @property - # def num_shards(self) -> int: - # return len(self._state_shard_names) + 1 - - # @property - # def num_state_shards(self) -> int: - # return len(self._state_shard_names) - @property def stages_on_device(self) -> dict[int, Stage]: return self._stages_on_device @@ -452,11 +422,13 @@ def grad_buffer_indices(self) -> dict[int, int]: @property def state_shard_names(self) -> tuple[str, ...]: - return self._shards_names[:-1] if 
self._mode.support_backward else self._shards_names + return self._shard_names[:-1] if self._mode.support_backward else self._shard_names - @property - def stage_shard_sizes(self) -> list[int]: - return self._stage_weight_shard_sizes + def _get_stage_shard_sizes(self, shard_name: str) -> list[int]: + return self._stage_weight_shard_sizes if shard_name == ShardName.weights else self._stage_grad_shard_sizes + + def _get_fsdp_shard_sizes(self, shard_name: str) -> list[list[int]]: + return self._fsdp_weight_shard_sizes if shard_name == ShardName.weights else self._fsdp_grad_shard_sizes @property def parameter_names(self) -> list[str]: @@ -483,7 +455,7 @@ def train(self, mode: bool = True) -> None: self._training = mode def get_state_tensor_iterator( - self, shard_names: list[str], data_type: DataType | None = None + self, shard_names: tuple[str, ...], data_type: DataType | None = None ) -> typing.Generator[tuple[str, str, torch.Tensor], None, None]: for shard_name in shard_names: shard_split = self._shards[shard_name].split(self._stage_weight_shard_sizes, 0) @@ -511,6 +483,22 @@ def import_state_tensor(self, parameter_name: str, shard_name: str, tensor: torc ) return self.get_parameter_stage(parameter_name).import_state_tensor(parameter_name, stage_shards, tensor) + def split_shards_by_fsdp( + self, shards: dict[str, torch.Tensor] + ) -> typing.Generator[tuple[Stage, FSDP, dict[str, torch.Tensor]], None, None]: + stage_shards = { + shard_name: shard.split(self._get_stage_shard_sizes(shard_name)) for shard_name, shard in shards.items() + } + for shard_index, stage in enumerate(self.stages_on_device.values()): + fsdp_shards = { + shard_name: stage_shards_[shard_index].split(self._get_fsdp_shard_sizes(shard_name)[shard_index]) + for shard_name, stage_shards_ in stage_shards.items() + } + for fsdp_index, fsdp in enumerate(stage.fsdps): + yield stage, fsdp, { + shard_name: fsdp_shards_[fsdp_index] for shard_name, fsdp_shards_ in fsdp_shards.items() + } + def _split_into_stages(self) -> list[int]: # Create stages (greedy split, could do better). 
stage_splits = [0] diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index b2bf2504..87460490 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -15,6 +15,7 @@ ModelConfigType, ) from fast_llm.engine.multi_stage.config import StageMode +from fast_llm.engine.multi_stage.multi_stage import ShardName from fast_llm.models.auto import model_registry from fast_llm.tools.convert import ConversionConfig from tests.common import ( @@ -206,12 +207,13 @@ def test_converted_distributed(): w = safetensors.torch.load_file(_CKPT_PATH / "rank_0.safetensors") w0 = safetensors.torch.load_file(_CONVERT_PATH / "distributed_0" / "rank_0.safetensors") w1 = safetensors.torch.load_file(_CONVERT_PATH / "distributed_1" / "rank_0.safetensors") - assert w.keys() == w0.keys() == w1.keys() == {"state_shard"} - for key in w: - assert w[key][:1].shape == w0[key].shape, (key, w[key][:1].shape, w0[key].shape) - assert (w[key][:1] == w0[key]).all(), (w[key][:1], w0[key]) - assert w[key][:1].shape == w1[key].shape, (key, w[key][:1].shape, w1[key].shape) - assert (w[key][:1] == w1[key]).all(), (w[key][:1], w1[key]) + assert w.keys() >= {f"{ShardName.weights}_shard"} + assert w0.keys() == w1.keys() == {f"{ShardName.weights}_shard"} + for key in w0: + assert w[key].shape == w0[key].shape, (key, w[key].shape, w0[key].shape) + assert (w[key] == w0[key]).all(), (w[key], w0[key]) + assert w[key].shape == w1[key].shape, (key, w[key].shape, w1[key].shape) + assert (w[key] == w1[key]).all(), (w[key], w1[key]) @pytest.mark.depends(on=["test_convert_distributed_to_fast_llm", "test_convert_huggingface_to_fast_llm"]) @@ -251,10 +253,11 @@ def test_load_pretrained_distributed_checkpoint(): ) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_ref) _compare_configs(config.base_model, model.config.base_model) - weight_shard = safetensors.torch.load_file( - _CKPT_PATH / "rank_0.safetensors", device=str(model._state_shard.device) - )["state_shard"] - assert (weight_shard == model._state_shard).all() + state_shards = safetensors.torch.load_file( + _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) + ) + for shard_name in model.state_shard_names: + assert (state_shards[f"{shard_name}_shard"] == model.get_shard(shard_name)).all() @pytest.mark.depends(on=["test_load_pretrained_distributed_checkpoint"]) @@ -274,9 +277,9 @@ def test_load_converted_distributed_checkpoint(): _compare_configs(config.base_model, model.config.base_model) _compare_configs(config.base_model, config_1.base_model) weight_shard = safetensors.torch.load_file( - _CKPT_PATH / "rank_0.safetensors", device=str(model._state_shard.device) - )["state_shard"][:1] - assert (weight_shard == model._state_shard).all() + _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) + )[f"{ShardName.weights}_shard"] + assert (weight_shard == model.get_shard(ShardName.weights)).all() @pytest.mark.depends(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) @@ -290,9 +293,9 @@ def test_load_converted_fast_llm_checkpoint(): _compare_configs(config.base_model, model.config.base_model) _compare_configs(config.base_model, config_1.base_model) weight_shard = safetensors.torch.load_file( - _CKPT_PATH / "rank_0.safetensors", device=str(model._state_shard.device) - )["state_shard"][:1] - assert (weight_shard == model._state_shard).all() + _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) + )[f"{ShardName.weights}_shard"] + assert (weight_shard == 
model.get_shard(ShardName.weights)).all() @pytest.mark.depends(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) @@ -315,9 +318,9 @@ def test_load_converted_huggingface_checkpoint(): _compare_configs(config.base_model, model.config.base_model) _compare_configs(config.base_model, config_1.base_model) weight_shard = safetensors.torch.load_file( - _CKPT_PATH / "rank_0.safetensors", device=str(model._state_shard.device) - )["state_shard"][:1] - assert (weight_shard == model._state_shard).all() + _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) + )[f"{ShardName.weights}_shard"] + assert (weight_shard == model.get_shard(ShardName.weights)).all() @pytest.mark.depends(on=["test_load_converted_fast_llm_checkpoint", "test_load_converted_huggingface_checkpoint"]) @@ -458,9 +461,9 @@ def test_load_distributed_checkpoint_dp2(): model = TEST_MODEL_CLS.from_pretrained(pretrained_config_test, mode=StageMode.weights) _compare_configs(config.base_model, model.config.base_model) weight_shard = safetensors.torch.load_file( - _CKPT_PATH / "rank_0.safetensors", device=str(model._state_shard.device) - )["state_shard"][:1] - assert (weight_shard == model._state_shard).all() + _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) + )[f"{ShardName.weights}_shard"] + assert (weight_shard == model.get_shard(ShardName.weights)).all() @pytest.mark.slow From b9b017f7c92ed0c77788911ccc3780165fb43e6e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 17 Mar 2025 22:12:47 -0400 Subject: [PATCH 10/16] fix --- tests/test_checkpoint.py | 42 ++++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 87460490..32f0c3f7 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -37,6 +37,8 @@ TEST_BASE_MODEL_CONFIG_CLS = TEST_MODEL_CONFIG_CLS.get_base_model_config_class() TEST_ARCHITECTURE_CONFIG_CLS = TEST_BASE_MODEL_CONFIG_CLS.architecture_class +WEIGHT_SHARD_SAVE_NAME = f"{ShardName.weights}_shard" + @requires_cuda @pytest.mark.depends() @@ -207,8 +209,8 @@ def test_converted_distributed(): w = safetensors.torch.load_file(_CKPT_PATH / "rank_0.safetensors") w0 = safetensors.torch.load_file(_CONVERT_PATH / "distributed_0" / "rank_0.safetensors") w1 = safetensors.torch.load_file(_CONVERT_PATH / "distributed_1" / "rank_0.safetensors") - assert w.keys() >= {f"{ShardName.weights}_shard"} - assert w0.keys() == w1.keys() == {f"{ShardName.weights}_shard"} + assert w.keys() >= {WEIGHT_SHARD_SAVE_NAME} + assert w0.keys() == w1.keys() == {WEIGHT_SHARD_SAVE_NAME} for key in w0: assert w[key].shape == w0[key].shape, (key, w[key].shape, w0[key].shape) assert (w[key] == w0[key]).all(), (w[key], w0[key]) @@ -278,7 +280,7 @@ def test_load_converted_distributed_checkpoint(): _compare_configs(config.base_model, config_1.base_model) weight_shard = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) - )[f"{ShardName.weights}_shard"] + )[WEIGHT_SHARD_SAVE_NAME] assert (weight_shard == model.get_shard(ShardName.weights)).all() @@ -294,7 +296,7 @@ def test_load_converted_fast_llm_checkpoint(): _compare_configs(config.base_model, config_1.base_model) weight_shard = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) - )[f"{ShardName.weights}_shard"] + )[WEIGHT_SHARD_SAVE_NAME] assert (weight_shard == model.get_shard(ShardName.weights)).all() @@ -319,7 +321,7 
@@ def test_load_converted_huggingface_checkpoint(): _compare_configs(config.base_model, config_1.base_model) weight_shard = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) - )[f"{ShardName.weights}_shard"] + )[WEIGHT_SHARD_SAVE_NAME] assert (weight_shard == model.get_shard(ShardName.weights)).all() @@ -415,19 +417,19 @@ def test_load_pretrained_in_dp2_match_checkpoint(): config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) config_test = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_test) _compare_configs(config_ref.base_model, config_test.base_model) - shard_ref = safetensors.torch.load_file(_CKPT_PATH / "rank_0.safetensors")["state_shard"] - shards_test = [ - safetensors.torch.load_file(test_ckpt_path / f"rank_{i}.safetensors")["state_shard"] for i in range(2) - ] + shards_ref = safetensors.torch.load_file(_CKPT_PATH / "rank_0.safetensors") + shards_test = [safetensors.torch.load_file(test_ckpt_path / f"rank_{i}.safetensors") for i in range(2)] ref_model = TEST_MODEL_CLS(config_ref) test_model = TEST_MODEL_CLS(config_test) - weight_shard_ref_split = shard_ref[0].split(ref_model._stage_weight_shard_sizes) + weight_shard_ref_split = shards_ref[WEIGHT_SHARD_SAVE_NAME].split(ref_model._stage_weight_shard_sizes) weight_shards_test_split = [ - shard_test[0].split(test_model._stage_weight_shard_sizes) for shard_test in shards_test + shard_test[WEIGHT_SHARD_SAVE_NAME].split(test_model._stage_weight_shard_sizes) for shard_test in shards_test ] for shard_test in shards_test: - assert (shard_test[1:] == 0).all() # noqa + for shard_name, shard in shard_test.items(): + if shard_name != WEIGHT_SHARD_SAVE_NAME: + assert (shard == 0).all() # noqa assert len(ref_model._stage_weight_shard_sizes) == len(test_model._stage_weight_shard_sizes) for i, stage_shard_ref in enumerate(weight_shard_ref_split): @@ -462,7 +464,7 @@ def test_load_distributed_checkpoint_dp2(): _compare_configs(config.base_model, model.config.base_model) weight_shard = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) - )[f"{ShardName.weights}_shard"] + )[WEIGHT_SHARD_SAVE_NAME] assert (weight_shard == model.get_shard(ShardName.weights)).all() @@ -488,15 +490,16 @@ def test_load_pretrained_fast_llm_in_dp2(): / "checkpoint" / "1" / f"rank_{rank}.safetensors" - )["state_shard"] + ) test_shard = safetensors.torch.load_file( TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_fast_llm_in_dp2" / "checkpoint" / "1" / f"rank_{rank}.safetensors" - )["state_shard"] - assert (ref_shard == test_shard).all() + ) + for name in set(ref_shard) | set(test_shard): + assert (ref_shard[name] == test_shard[name]).all() @pytest.mark.slow @@ -521,12 +524,13 @@ def test_load_pretrained_huggingface_in_dp2(): / "checkpoint" / "1" / f"rank_{rank}.safetensors" - )["state_shard"] + ) test_shard = safetensors.torch.load_file( TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_huggingface_in_dp2" / "checkpoint" / "1" / f"rank_{rank}.safetensors" - )["state_shard"] - assert (ref_shard == test_shard).all() + ) + for name in set(ref_shard) | set(test_shard): + assert (ref_shard[name] == test_shard[name]).all() From e086908b9fcb96b675e2c310d7347c169f8b8696 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 17 Mar 2025 23:15:37 -0400 Subject: [PATCH 11/16] fixes --- fast_llm/engine/multi_stage/fsdp.py | 32 ++++++++++------------ fast_llm/engine/multi_stage/multi_stage.py | 6 ++-- fast_llm/engine/multi_stage/stage.py | 2 
++ fast_llm/functional/linear.py | 6 ++-- 4 files changed, 24 insertions(+), 22 deletions(-) diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 7c8344ef..d45566fc 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -218,17 +218,15 @@ def setup( # Precompute the buffer slice for each parameter. # Use `.data` to hide the restore ops from autograd. self._parameter_buffers = {} - for weight_buffer, grad_buffer, parameter_name in zip( - self.split_buffer(self._weight_buffer.data).values(), - self.split_buffer( - self._grad_buffer if self._mode.support_backward else self._grad_buffer_meta - ).values(), - self._parameter_metas, - ): - parameter_buffer = torch.nn.Parameter(weight_buffer, requires_grad=self._mode.support_backward) - if self._mode.support_backward: - parameter_buffer.grad_buffer = grad_buffer - # TODO: This is only needed for Megatron initialization + for parameter_name in self._parameter_metas: + parameter_buffer = torch.nn.Parameter( + self._get_parameter_in_buffer(self._weight_buffer.data, parameter_name), + requires_grad=self._mode.support_backward and self._parameter_metas[parameter_name].requires_grad, + ) + if self._mode.support_backward and self._requires_grad: + parameter_buffer.grad_buffer = self._get_parameter_in_buffer( + self._grad_buffer.data, parameter_name + ) self._parameter_buffers[parameter_name] = parameter_buffer def reset_shard_pad(self, shard: torch.Tensor) -> int: @@ -245,12 +243,12 @@ def reset_shard_pad(self, shard: torch.Tensor) -> int: def split_buffer(self, buffer: torch.Tensor) -> dict[str, torch.Tensor]: # Split a buffer into appropriately shaped parameters. - return { - name: buffer[self.get_parameter_begin_in_buffer(name) : self.get_parameter_end_in_buffer(name)].view( - meta.shape - ) - for name, meta in self._parameter_metas.items() - } + return {name: self._get_parameter_in_buffer(buffer, name) for name in self._parameter_metas} + + def _get_parameter_in_buffer(self, buffer: torch.Tensor, name: str) -> torch.Tensor: + return buffer[self.get_parameter_begin_in_buffer(name) : self.get_parameter_end_in_buffer(name)].view( + self._parameter_metas[name].shape + ) def split_shard(self, shard: torch.Tensor) -> dict[str, torch.Tensor]: # Split a shard into flat (possibly empty) parameter slices. 
diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index c38a46b0..ec1dd242 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -313,8 +313,8 @@ def _setup_stages( ) stage_grad_shards = ( self._shards[ShardName.grads] - .split(self._stage_weight_shard_sizes)[shard_index] - .split(self._fsdp_weight_shard_sizes[shard_index]) + .split(self._stage_grad_shard_sizes)[shard_index] + .split(self._fsdp_grad_shard_sizes[shard_index]) if self._mode.support_backward and shard_index is not None else None ) @@ -354,7 +354,7 @@ def get_param_groups( ( self._fsdp_weight_shard_sizes if shard_name == ShardName.weights - else self._stage_grad_shard_sizes + else self._fsdp_grad_shard_sizes )[shard_index] ) for shard_name, shard_split in optimizer_shards_split.items() diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 8d343125..a60fafd3 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -66,6 +66,8 @@ def setup( # noqa with torch.enable_grad(): for meta in self._parameter_metas: buffer = self.get_parameter_buffer(meta.tensor_name) + if not buffer.requires_grad: + continue # We want to replace the grad accumulation function with ours, but pytorch won't let us do that. # Instead, we let a trivial accumulation run its course (sets .grad), # then run the actual accumulation. diff --git a/fast_llm/functional/linear.py b/fast_llm/functional/linear.py index d583d1a9..dbc05184 100644 --- a/fast_llm/functional/linear.py +++ b/fast_llm/functional/linear.py @@ -42,7 +42,9 @@ def update_linear_gradients( input_ = input_.flatten(0, -2) lhs, rhs = (input_.t(), grad_output) if transposed_weight else (grad_output.t(), input_) - if TritonConfig.TRITON_LINEAR or sparse_map is not None: + if not weight.requires_grad: + pass + elif TritonConfig.TRITON_LINEAR or sparse_map is not None: # This assumes the transposed_weight is True for input_sparse, False for output_sparse. input_row_sparse_matmul( lhs, @@ -63,7 +65,7 @@ def update_linear_gradients( ) else: accumulate_gradient(weight, torch.mm(lhs, rhs)) - if bias is not None: + if bias is not None and bias.requires_grad: accumulate_gradient(bias, grad_output.sum(dim=0)) From 59c1f8d83e902a5d3db4424d93af723e028207f7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 18 Mar 2025 16:49:45 -0400 Subject: [PATCH 12/16] fix --- fast_llm/engine/multi_stage/stage_base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 4eb70f62..b24f9720 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -236,6 +236,8 @@ def get_param_groups( grouped_parameter_slices = {} param_groups = [] for i, fsdp in enumerate(self._fsdps): + if not fsdp.requires_grad: + continue for parameter_name in fsdp.parameter_names: # If needed, chunk the parameter on the first dimension. 
parameter_meta = fsdp.get_parameter_meta(parameter_name) From cc192d59031da27f8eedd6949887bd185e033582 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 18 Mar 2025 17:52:32 -0400 Subject: [PATCH 13/16] fix --- fast_llm/engine/multi_stage/stage_base.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index b24f9720..4885e516 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -51,6 +51,11 @@ def __init__( parameter_metas, frozen_metas = self._get_parameter_metas() self._parameter_metas = parameter_metas + frozen_metas self._fsdps = [] + gradient_buffer_dtype = ( + self._distributed_config.optimization_dtype + if self._config.full_precision_gradients + else self._distributed_config.training_dtype + ) if parameter_metas: self._fsdps.append( FSDP( @@ -58,11 +63,7 @@ def __init__( parameter_metas, self._distributed_config.get_distributed_dim(DistributedDimNames.data), training_dtype=self._distributed_config.training_dtype, - gradient_buffer_dtype=( - self._distributed_config.optimization_dtype - if self._config.full_precision_gradients - else self._distributed_config.training_dtype - ), + gradient_buffer_dtype=gradient_buffer_dtype, optimization_dtype=self._distributed_config.optimization_dtype, ) ) @@ -73,7 +74,7 @@ def __init__( frozen_metas, self._distributed_config.get_distributed_dim(DistributedDimNames.data), training_dtype=self._distributed_config.training_dtype, - gradient_buffer_dtype=self._distributed_config.training_dtype, + gradient_buffer_dtype=gradient_buffer_dtype, optimization_dtype=( self._distributed_config.optimization_dtype if self._config.store_frozen_weights_in_optimization_precision From e878656e5b7414a2ce20d21c55eb506ccf9f1ddd Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 18 Mar 2025 18:20:10 -0400 Subject: [PATCH 14/16] Add test --- fast_llm/engine/multi_stage/multi_stage.py | 24 +++++------ tests/test_multi_stage.py | 48 ++++++++++++++++++++++ 2 files changed, 60 insertions(+), 12 deletions(-) create mode 100644 tests/test_multi_stage.py diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index ec1dd242..238bd865 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -36,6 +36,8 @@ class MultiStageModel[ConfigType: FastLLMModelConfig](Configurable[ConfigType]): _flat_shard: torch.Tensor _shards: dict[str, torch.Tensor] _shard_names: tuple[str, ...] + _weight_buffers: tuple[torch.Tensor, ...] + _grad_buffers: tuple[torch.Tensor, ...] _distributed: Distributed _mode: StageMode @@ -221,15 +223,13 @@ def setup(self, distributed: Distributed, mode: StageMode = StageMode.training) # Allocate and split shards and buffers. 
if self._mode.support_forward: - weight_buffers, mem = self._allocate_buffers(self._weight_buffer_meta, self._weight_buffer_sizes, "weight") + self._weight_buffers, mem = self._allocate_buffers( + self._weight_buffer_meta, self._weight_buffer_sizes, "weight" + ) allocated += mem - else: - weight_buffers = None if self._mode.support_backward: - grad_buffers, mem = self._allocate_buffers(self._grad_buffer_meta, self._grad_buffer_sizes, "grad") + self._grad_buffers, mem = self._allocate_buffers(self._grad_buffer_meta, self._grad_buffer_sizes, "grad") allocated += mem - else: - grad_buffers = None self._shard_names = () if self._mode.on_device: @@ -250,7 +250,7 @@ def setup(self, distributed: Distributed, mode: StageMode = StageMode.training) tied_parameter.setup(self._distributed) # Setup the layer shards and buffers. - self._setup_stages(weight_buffers, grad_buffers) + self._setup_stages() self.train(self._mode.support_backward) @@ -283,22 +283,22 @@ def _allocate_shards(self) -> int: } return mem - def _setup_stages( - self, weight_buffers: tuple[torch.Tensor, ...] | None, grad_buffers: tuple[torch.Tensor, ...] | None - ) -> None: + def _setup_stages(self) -> None: for stage_index, stage in enumerate(self._stages): shard_index = self._stage_shard_indices.get(stage_index) weight_buffer_index = self._weight_buffer_indices.get(stage_index) grad_buffer_index = self._grad_buffer_indices.get(stage_index) stage_weight_buffers = ( - weight_buffers[weight_buffer_index][: self._stage_weight_buffer_sizes[stage_index]].split( # noqa + self._weight_buffers[weight_buffer_index][ + : self._stage_weight_buffer_sizes[stage_index] + ].split( # noqa self._fsdp_weight_buffer_sizes[stage_index] ) if self._mode.support_forward and weight_buffer_index is not None else None ) stage_grad_buffers = ( - grad_buffers[grad_buffer_index][: self._stage_grad_buffer_sizes[stage_index]].split( # noqa + self._grad_buffers[grad_buffer_index][: self._stage_grad_buffer_sizes[stage_index]].split( # noqa self._fsdp_grad_buffer_sizes[stage_index] ) if self._mode.support_backward and grad_buffer_index is not None diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py new file mode 100644 index 00000000..42e74513 --- /dev/null +++ b/tests/test_multi_stage.py @@ -0,0 +1,48 @@ +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.training.config import TrainerConfig +from fast_llm.engine.training.trainer import Trainer +from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.tools.train import CliTrainingConfig +from fast_llm.utils import Assert +from tests.common import CONFIG_COMMON + + +def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: + parsed, unparsed = CliTrainingConfig._get_parser().parse_known_args([model_type] + args) + config: TrainerConfig = CliTrainingConfig._from_parsed_args(parsed, unparsed) + distributed = Distributed(config.model.distributed) + trainer = config.get_trainer_class()(config=config) + trainer.setup(distributed, config.get_run(distributed)) + return trainer + + +def test_frozen_weights(): + args = CONFIG_COMMON + ["run.tensor_logs.save=False"] + model_ref = _get_trainer_from_args(args)._multi_stage + model_frozen = _get_trainer_from_args(args + ["model.base_model.transformer.mlp_lr_scale=[0]"])._multi_stage + + Assert.eq( + model_ref._num_stages, + model_frozen._num_stages, + ) + diff_by_layer = [ + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, TransformerLayer) else 0 + for layer 
in model_ref.base_model.layers + ] + assert all((diff_by_layer[i] == 0) == (i in (0, len(diff_by_layer) - 1)) for i in range(len(diff_by_layer))) + total_diff = sum(diff_by_layer) + + for weight_buffer_ref, weight_buffer_frozen in zip( + model_ref._weight_buffers, model_frozen._weight_buffers, strict=True + ): + assert weight_buffer_ref.numel() == weight_buffer_frozen.numel() + + for grad_buffer_ref, grad_buffer_frozen, diff in zip( + model_ref._grad_buffers, model_frozen._grad_buffers, diff_by_layer, strict=True + ): + Assert.eq(grad_buffer_ref.numel() - grad_buffer_frozen.numel(), diff) + + for shard_name, shard_diff in zip( + model_ref._shard_names, [0] + [total_diff] * (len(model_ref._all_shard_names) - 1), strict=True + ): + Assert.eq(model_ref.get_shard(shard_name).numel() - model_frozen.get_shard(shard_name).numel(), shard_diff) From 3f79798576c879eaa75259cc233515e29c0e211d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 18 Mar 2025 19:09:44 -0400 Subject: [PATCH 15/16] fix --- tests/test_multi_stage.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index 42e74513..bb468ceb 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -4,7 +4,7 @@ from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.tools.train import CliTrainingConfig from fast_llm.utils import Assert -from tests.common import CONFIG_COMMON +from tests.common import CONFIG_COMMON, requires_cuda def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @@ -16,6 +16,7 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: return trainer +@requires_cuda def test_frozen_weights(): args = CONFIG_COMMON + ["run.tensor_logs.save=False"] model_ref = _get_trainer_from_args(args)._multi_stage From f9f288382a13dae653f727c3290496aa0a8c79ad Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 21 Mar 2025 18:53:42 -0400 Subject: [PATCH 16/16] Add warning --- fast_llm/engine/checkpoint/distributed.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index f0bb900e..0ded53ba 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -59,7 +59,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No same_format = broadcast_scalar(same_format, torch.uint8, self._model.distributed.world_group) if same_format: - log_main_rank("Checkpoint format matches, using fast load") + log_main_rank("Checkpoint format matches, using fast load", log_fn=logger.info) # TODO: Add version without optimizer state? with safetensors.safe_open( config.path / f"rank_{self._model.config.distributed.rank}.safetensors", framework="pt", device=str(self._model.distributed.device), ) as f: if "state_shard" in f.keys(): # Old format `state_shard` with shape `(num_shards, shard_size) # TODO v0.3: Use checkpoint version? Drop support? 
+ log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning) for shard_name in shard_names: self._model.get_shard(shard_name).copy_( f.get_slice("state_shard")[metadata.shards.index(shard_name)] @@ -79,7 +80,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No self._model.get_shard(shard_name).copy_(f.get_tensor(f"{shard_name}_shard")) else: - log_main_rank("Checkpoint format doesn't match, using safe load") + log_main_rank("Checkpoint format doesn't match, using safe load", log_fn=logger.info) self._model.config.base_model.compare_architecture(loaded_config.base_model, config.compare_log_fn) with SafeLoad(self._model, shard_names=shard_names, timeout=config.timeout) as context: for rank in range(loaded_config.distributed.world_size): @@ -89,13 +90,14 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No verbose=False, ) path = config.path / f"rank_{rank}.safetensors" - log_main_rank(f"Loading from {path}") + log_main_rank(f"Loading from {path}", log_fn=logger.info) # TODO: skip shards without overlap. with safetensors.safe_open(path, framework="pt", device=str(self._model.distributed.device)) as f: # TODO: Use self_shard if "state_shard" in f.keys(): # Old format `state_shard` with shape `(num_shards, shard_size) # TODO v0.3: Use checkpoint version? Drop support? + log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning()) loaded_shards = { shard_name: f.get_slice("state_shard")[metadata.shards.index(shard_name)] for shard_name in shard_names