Support frozen weights #185

Open · wants to merge 16 commits into main
3 changes: 0 additions & 3 deletions fast_llm/engine/checkpoint/config.py
@@ -251,8 +251,5 @@ def save(self, config: CheckpointSaveConfig, metadata: "CheckpointMetadata"):
def load(self, config: CheckpointLoadConfig, metadata: "CheckpointMetadata"):
pass

def get_num_shards(self, config: CheckpointStateConfigBase) -> int:
return len(self._model.state_shard_names) if config.optimizer_state else 1

def get_shard_names(self, config: CheckpointStateConfigBase) -> tuple[str, ...]:
return self._model.state_shard_names if config.optimizer_state else self._model.state_shard_names[:1]
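
For context, the surviving get_shard_names helper selects shards by name rather than by count. A minimal sketch of that selection logic, using hypothetical shard names rather than fast_llm internals:

# Sketch only: hypothetical shard names, not the real fast_llm state shards.
state_shard_names = ("weights", "momentum", "variance")

def get_shard_names(optimizer_state: bool) -> tuple[str, ...]:
    # With optimizer state, every shard is wanted; otherwise only the weight shard.
    return state_shard_names if optimizer_state else state_shard_names[:1]

assert get_shard_names(False) == ("weights",)
assert get_shard_names(True) == ("weights", "momentum", "variance")
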
62 changes: 40 additions & 22 deletions fast_llm/engine/checkpoint/distributed.py
@@ -36,7 +36,7 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No
if self._model.config.distributed.rank == 0:
yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w"))
safetensors.torch.save_file(
tensors={"state_shard": self._model.state_shard[: self.get_num_shards(config)]},
tensors={f"{shard_name}_shard": self._model.get_shard(shard_name) for shard_name in metadata.shards},
filename=config.path / f"rank_{self._model.config.distributed.rank}.safetensors",
metadata=export_safetensors_metadata(serialized_metadata),
)
@@ -45,9 +45,10 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No
# TODO: More safety checks
loaded_config_dict = config.to_copy({"load_config": ModelConfigType.fast_llm})
loaded_config = self._model.config_class.from_metadata(loaded_config_dict, metadata)
num_shards = self.get_num_shards(config)
shard_names = self.get_shard_names(config)
Assert.eq(metadata.shards[:num_shards], list(shard_names))
# Make sure all shards to load are in the checkpoint.
Assert.leq(set(self.get_shard_names(config)), set(metadata.shards))
Assert.eq(metadata.shards[: len(shard_names)], list(shard_names))

same_format = (
loaded_config.to_serialized(verbose=None) == self._model.config.to_serialized(verbose=None)
@@ -65,12 +66,22 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No
framework="pt",
device=str(self._model.distributed.device),
) as f:
# TODO: Does this copy twice?
self._model.state_shard[:num_shards].copy_(f.get_slice("state_shard")[:num_shards])
if "state_shard" in f.keys():
# Old format `state_shard` with shape `(num_shards, shard_size)`
# TODO v0.3: Use checkpoint version? Drop support?
for shard_name in shard_names:
self._model.get_shard(shard_name).copy_(
f.get_slice("state_shard")[metadata.shards.index(shard_name)]
)
else:
# TODO: Does this copy twice?
for shard_name in shard_names:
self._model.get_shard(shard_name).copy_(f.get_tensor(f"{shard_name}_shard"))

else:
log_main_rank("Checkpoint format doesn't match, using safe load")
self._model.config.base_model.compare_architecture(loaded_config.base_model, config.compare_log_fn)
with SafeLoad(self._model, num_shards=num_shards, timeout=config.timeout) as context:
with SafeLoad(self._model, shard_names=shard_names, timeout=config.timeout) as context:
for rank in range(loaded_config.distributed.world_size):
loaded_model = self._model.__class__(
loaded_config.to_copy({("distributed", "rank"): rank}),
@@ -82,25 +93,32 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No
# TODO: skip shards without overlap.
with safetensors.safe_open(path, framework="pt", device=str(self._model.distributed.device)) as f:
# TODO: Use self_shard
loaded_shard = f.get_slice("state_shard")[:num_shards]
loaded_model.state_shard_meta.validate(loaded_shard)
if "state_shard" in f.keys():
# Old format `state_shard` with shape `(num_shards, shard_size)`
# TODO v0.3: Use checkpoint version? Drop support?
loaded_shards = {
shard_name: f.get_slice("state_shard")[metadata.shards.index(shard_name)]
for shard_name in shard_names
}
else:
loaded_shards = {
shard_name: f.get_tensor(f"{shard_name}_shard") for shard_name in shard_names
}

# TODO: Improve num shard selection.
self_shard_split = self._model.state_shard[: loaded_shard.size(0)].split(
self._model.stage_shard_sizes, 1
)
loaded_shard_split = loaded_shard.split(loaded_model.stage_shard_sizes, 1)
for shard_name, loaded_shard in loaded_shards.items():
loaded_model.get_shard_meta(shard_name).validate(loaded_shard)

self_shards = {shard_name: self._model.get_shard(shard_name) for shard_name in shard_names}

counter = torch.zeros(1, dtype=torch.int64, device=self._model.distributed.device)
for loaded_shard_index, loaded_stage in enumerate(loaded_model.stages_on_device.values()):
loaded_shards = (
loaded_shard_split[loaded_shard_index].to(self._model.distributed.device).unbind(0)
)
for self_shard_index, self_stage in enumerate(self._model.stages_on_device.values()):
self_stage._copy_shard_overlaps( # noqa
loaded_stage,
self_shard_split[self_shard_index].unbind(0),
loaded_shards,
for _, loaded_fsdp, loaded_fsdp_shards in loaded_model.split_shards_by_fsdp(loaded_shards):
for _, self_fsdp, self_fsdp_shards in self._model.split_shards_by_fsdp(self_shards):
self_fsdp.copy_shard_overlaps(
loaded_fsdp,
self_fsdp_shards,
loaded_fsdp_shards,
counter,
self._model.distributed.device,
)

context.mark_as_loaded(counter.item())
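
A rough sketch of the two on-disk layouts this load path has to handle, using toy tensors and hypothetical shard names rather than real fast_llm shards:

import torch
import safetensors.torch
from safetensors import safe_open

shard_names = ("weights", "momentum")  # hypothetical shard names
shards = {name: torch.randn(8) for name in shard_names}

# Old format: a single stacked tensor of shape (num_shards, shard_size).
safetensors.torch.save_file({"state_shard": torch.stack(list(shards.values()))}, "old.safetensors")
# New format: one named tensor per shard, e.g. "weights_shard".
safetensors.torch.save_file({f"{name}_shard": tensor for name, tensor in shards.items()}, "new.safetensors")

def load_shards(path: str) -> dict[str, torch.Tensor]:
    with safe_open(path, framework="pt") as f:
        if "state_shard" in f.keys():
            # Backward compatibility: index the stacked tensor by shard position.
            return {name: f.get_slice("state_shard")[i] for i, name in enumerate(shard_names)}
        return {name: f.get_tensor(f"{name}_shard") for name in shard_names}

assert all(torch.equal(load_shards("old.safetensors")[n], load_shards("new.safetensors")[n]) for n in shard_names)
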
49 changes: 25 additions & 24 deletions fast_llm/engine/checkpoint/safe_load.py
@@ -24,24 +24,24 @@ class SafeLoad:
In case of failure, it will attempt to find out as precisely as possible where the problem comes from.
"""

def __init__(self, model: "FastLLMModel", *, num_shards: int, timeout: float | None = None):
def __init__(self, model: "FastLLMModel", *, shard_names: tuple[str, ...], timeout: float | None = None):
self._model = model
self._distributed = self._model.distributed
self._num_shards = num_shards
self._self_shard = self._model.state_shard[: self._num_shards]
# self._num_shards = num_shards
self._self_shards = {shard_name: self._model.get_shard(shard_name) for shard_name in shard_names}
self._timeout = timeout

def __enter__(self) -> "SafeLoad":
self._loaded = 0
self._loaded_parameters = {}
# Track the number of loaded entries.
# Use nan to mark non-loaded entries.
triton_fill(self._self_shard, math.nan)
for self_shard in self._self_shards.values():
triton_fill(self_shard, math.nan)
# Reset and count shard pads
for shard in self._model.state_shard[: self._num_shards]:
shard_split = shard.split(self._model.stage_shard_sizes, 0)
for stage, stage_shard in zip(self._model.stages_on_device.values(), shard_split):
self._loaded += stage.reset_shard_pad(stage_shard)
for _, fsdp, fsdp_shards in self._model.split_shards_by_fsdp(self._self_shards):
for fsdp_shard in fsdp_shards.values():
self._loaded += fsdp.reset_shard_pad(fsdp_shard)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
@@ -70,18 +70,19 @@ def _validate(self) -> None:
logger.info(f"{self._loaded:,} state entries loaded successfully")

def _check_counter(self, errors: list[str]) -> None:
to_load = self._self_shard.numel()
to_load = sum(self_shard.numel() for self_shard in self._self_shards.values())
if self._loaded != to_load:
# Ensure the right amount of weights is loaded.
errors.append(f"Loaded a total of {self._loaded:,}, state entries, expected {to_load:,}")

def _check_missing(self, errors: list[str]) -> None:
# Ensure the loaded weights have a 1-1 mapping by looking for nans.
missing = self._self_shard.new_zeros([], dtype=torch.int64)
missing = torch.zeros([], dtype=torch.int64, device=self._distributed.device)
# Count nans in slices of 100M parameters to limit memory usage.
# TODO: Find better solution (triton kernel?)
for shard_slice in self._self_shard.flatten().split(100000000):
missing += shard_slice.isnan().sum()
for shard in self._self_shards.values():
for shard_slice in shard.flatten().split(100000000):
missing += shard_slice.isnan().sum()
local_missing = missing.item()
if self._distributed.world_group is not None:
all_reduce(missing, group=self._distributed.world_group)
@@ -90,32 +91,32 @@ def _check_missing(self, errors: list[str]) -> None:
errors.append(f"{global_missing:,} state entries failed to load or corrupted (local={local_missing:,}).")
# Determine where the missing values are coming from.
global_total, local_total = 0, 0
for shard_name, shard_ in zip(self._model.state_shard_names[: self._num_shards], self._self_shard):
shard_split = shard_.split(self._model.stage_shard_sizes, 0)
for stage, shard in zip(self._model.stages_on_device.values(), shard_split):
buffer = stage._reconstruct_from_shard(shard)
for i, parameter in enumerate(stage._split_buffer(buffer)):
for stage, fsdp, fsdp_shards in self._model.split_shards_by_fsdp(self._self_shards):
for shard_name, fsdp_shard in fsdp_shards.items():
buffer = fsdp.reconstruct_from_shard(fsdp_shard)
for parameter_name, parameter in fsdp.split_buffer(buffer).items():
missing_for_param = parameter.isnan().sum().item()
if missing_for_param > 0:
global_total += missing_for_param
local_values = stage._split_shard(shard)[i]
local_values = fsdp.split_shard(fsdp_shard)[parameter_name]
local_missing_for_param = local_values.isnan().sum().item()
local_total += local_missing_for_param
errors.append(
f"{missing_for_param:,} values missing out of {parameter.numel():,} for parameter {stage.parameter_names[i]} in stage {stage.index}, shard {shard_name}"
f"{missing_for_param:,} values missing out of {parameter.numel():,} for parameter {parameter_name} in stage {stage.index}, shard {shard_name}"
f" (locally {local_missing_for_param:,} out of {local_values.numel():,})"
)
missing_for_pad = buffer[-stage._global_pad :].isnan().sum().item()
missing_for_pad = buffer[-fsdp._global_pad :].isnan().sum().item()
if missing_for_pad > 0:
global_total += missing_for_pad
local_missing_for_pad = (
shard[-stage._shard_pad :].isnan().sum().item() if stage._shard_pad > 0 else 0
fsdp_shard[-fsdp._shard_pad :].isnan().sum().item() if fsdp._shard_pad > 0 else 0
)
local_total += local_missing_for_pad
errors.append(
f"{missing_for_pad:,} values missing out of {stage._global_pad:,} for padding in stage {stage.index}, shard {shard_name}"
f" (locally {local_missing_for_pad:,} out of {stage._shard_pad:,})"
f"{missing_for_pad:,} values missing out of {fsdp._global_pad:,} for padding in stage {stage.index}, shard {shard_name}"
f" (locally {local_missing_for_pad:,} out of {fsdp._shard_pad:,})"
)

if global_total != global_missing:
errors.append(
f"Incorrect global breakdown of missing state entries (expected {global_missing:,}, got {global_total:,})"
@@ -127,7 +128,7 @@ def _check_missing(self, errors: list[str]) -> None:

def _check_parameters(self, errors: list[str]) -> None:
loaded_shard_names = set(self._loaded_parameters)
shard_names = set(self._model.state_shard_names[: self._num_shards])
shard_names = set(self._self_shards)
if loaded_shard_names != shard_names:
errors.append(f"Incorrect loaded shards: {loaded_shard_names}!={shard_names}")
for shard_name in shard_names & loaded_shard_names:
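
The NaN-marking trick used by SafeLoad, stripped down to plain torch: toy shard sizes, no triton or FSDP machinery, so this is an illustrative sketch rather than the actual implementation.

import math
import torch

shards = {"weights": torch.empty(1_000), "momentum": torch.empty(1_000)}  # hypothetical shards

# Mark every entry as "not loaded yet".
for shard in shards.values():
    shard.fill_(math.nan)

# A (partial) load writes real values over some of the entries.
shards["weights"][:600].copy_(torch.randn(600))

# Count entries that were never written, in bounded slices to cap temporary memory.
missing = 0
for shard in shards.values():
    for shard_slice in shard.flatten().split(100_000_000):
        missing += int(shard_slice.isnan().sum())
print(f"{missing:,} state entries failed to load")  # 1,400 in this toy example
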
4 changes: 2 additions & 2 deletions fast_llm/engine/checkpoint/state_dict.py
@@ -72,7 +72,7 @@ def _serialize_metadata(
return metadata.to_serialized()

def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None:
with SafeLoad(self._model, num_shards=self.get_num_shards(config), timeout=config.timeout) as context:
with SafeLoad(self._model, shard_names=self.get_shard_names(config), timeout=config.timeout) as context:
# The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from
# `state_dict` that are ready for conversion,
# and returns a dict containing the converted tensor(s).
@@ -145,7 +145,7 @@ def _load_weights(
) -> typing.Iterator[tuple[str, str, torch.Tensor | SafeTensorSlice]]:
metadata = self.load_metadata(config)
shard_names = self.get_shard_names(config)
Assert.eq(metadata.shards[: self.get_num_shards(config)], list(shard_names))
Assert.leq(set(shard_names), set(metadata.shards))
for file_name in set(metadata.metadata["state_index"].values()):
logger.info(f"Loading from {config.path / file_name}")
with safetensors.safe_open(
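
The switch from the old prefix-equality assertion to Assert.leq amounts to a subset check; a small illustration with hypothetical shard names:

# Hypothetical shard names, for illustration only.
metadata_shards = ["weights", "momentum", "variance"]  # shards present in the checkpoint
requested = ("weights",)                               # shards selected for this load

# Old check: the requested names must be an exact prefix of the checkpoint's shard list.
assert metadata_shards[: len(requested)] == list(requested)
# New check: every requested shard must simply exist somewhere in the checkpoint.
assert set(requested) <= set(metadata_shards)
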
8 changes: 8 additions & 0 deletions fast_llm/engine/multi_stage/config.py
@@ -70,6 +70,14 @@ class StageConfig(Config):
desc="Reduce and accumulate gradients in fp32 to improve numerical stability.",
hint=FieldHint.optional,
)
store_frozen_weights_in_optimization_precision: bool = Field(
# TODO: Implement and set default to False
default=True,
desc="Store frozen weights in full precision even if not not needed."
"Allows preserving the precision for saved checkpoints,"
" at the cost of memory and compute (copy) overheads.",
hint=FieldHint.optional,
)
debug_layer_outputs: int = Field(
default=0,
desc="Log the output of each layer.",
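
Assuming the usual keyword-style construction of these Field-based config classes (an assumption; only the field name and its default come from this diff), the new flag would be toggled roughly like this:

from fast_llm.engine.multi_stage.config import StageConfig

# Keep frozen weights in full precision (the current, and for now only, behavior).
stage_config = StageConfig(store_frozen_weights_in_optimization_precision=True)
# Once the TODO is implemented, False would keep frozen weights in the reduced
# training precision, saving memory and a copy at the cost of checkpoint precision.
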
2 changes: 1 addition & 1 deletion fast_llm/engine/multi_stage/fast_llm_model.py
@@ -98,7 +98,7 @@ def initialize_weights(self, timeout: float | None = None) -> None:

def _finalize_load(self, reset_optimizer: bool = True) -> None:
if reset_optimizer:
triton_fill(self._state_shard[1:], 0.0)
triton_fill(self._flat_shard[self._weight_shard_size :], 0.0)
if self._mode.support_forward:
self.invalidate_buffers()
self._is_loaded = True
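
A toy illustration of the flat-buffer layout behind the new _finalize_load line (hypothetical sizes; the real buffer and its weight-shard size live on the model):

import torch

weight_shard_size = 4
# Flat buffer: the weight shard comes first, optimizer-state shards follow it.
flat_shard = torch.arange(12, dtype=torch.float32)

# Resetting the optimizer zeroes everything past the weight shard, replacing the old
# row indexing of a (num_shards, shard_size) state_shard tensor.
flat_shard[weight_shard_size:].zero_()
print(flat_shard)  # tensor([0., 1., 2., 3., 0., 0., 0., 0., 0., 0., 0., 0.])
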