
Make the specified config parameters update the pretrained config #211


Merged: 32 commits, merged Apr 17, 2025

Commits (32)
5137757
stuff
jlamypoirier Mar 26, 2025
f0cb32a
Merge remote-tracking branch 'origin/main' into config_updates
jlamypoirier Mar 26, 2025
f26010e
Update pretrained config
jlamypoirier Mar 27, 2025
b930a39
stuff
jlamypoirier Mar 27, 2025
918a7a8
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Mar 27, 2025
8117c47
fixes
jlamypoirier Mar 27, 2025
1c995d3
fix
jlamypoirier Mar 27, 2025
3f90475
Merge branch 'main' into config_updates
jlamypoirier Mar 27, 2025
e389058
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Mar 27, 2025
506fe92
fixes
jlamypoirier Mar 27, 2025
971d3ef
fixes
jlamypoirier Mar 27, 2025
6bf20cb
Tests wip
jlamypoirier Mar 28, 2025
c13fb19
misc
jlamypoirier Mar 29, 2025
a20fcec
tests
jlamypoirier Apr 1, 2025
9af26a7
Merge branch 'main' into config_updates
jlamypoirier Apr 1, 2025
9af372d
Tests, fixes, remove tuple format
jlamypoirier Apr 1, 2025
dded00a
fix
jlamypoirier Apr 2, 2025
42d5ca4
Merge remote-tracking branch 'origin/main' into config_updates
jlamypoirier Apr 2, 2025
986f9f3
fix
jlamypoirier Apr 2, 2025
5abc087
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Apr 2, 2025
8e3e795
fixes
jlamypoirier Apr 2, 2025
da6eb7b
fixes
jlamypoirier Apr 3, 2025
67e08aa
Merge branch 'main' into config_updates
jlamypoirier Apr 3, 2025
a09e6f3
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Apr 3, 2025
baad705
fix
jlamypoirier Apr 3, 2025
b702837
Test, fixes
jlamypoirier Apr 5, 2025
7c2933a
Merge remote-tracking branch 'origin/main' into config_updates
jlamypoirier Apr 14, 2025
a017c11
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Apr 14, 2025
368a6bf
Merge remote-tracking branch 'origin/main' into update_pretrained_config
jlamypoirier Apr 14, 2025
cff9892
fixes
jlamypoirier Apr 14, 2025
48141e5
Merge remote-tracking branch 'origin/main' into update_pretrained_config
jlamypoirier Apr 17, 2025
e6701aa
Merge branch 'main' into update_pretrained_config
tscholak Apr 17, 2025
30 changes: 22 additions & 8 deletions fast_llm/engine/checkpoint/config.py
@@ -201,7 +201,7 @@ class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateSaveConf
@config_class()
class CheckpointLoadMetadataConfig(CheckpointPathConfigBase):
_abstract = False

# TODO: Set default to model? (Not backward compatible)
load_config: ModelConfigType = Field(
default=ModelConfigType.architecture,
desc="Configuration to save/load.",
@@ -213,10 +213,6 @@ def _validate(self) -> None:
if self.format.enforce_architecture_match:
assert self.load_config.load_architecture

@property
def compare_log_fn(self):
return ValueError if self.load_config.load_architecture else logger.warning


@config_class()
class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase):
@@ -237,19 +233,37 @@ class CheckpointHandler(abc.ABC):
def __init__(self, model: "FastLLMModel"):
self._model = model

# TODO: save_metadata?

@classmethod
@abc.abstractmethod
def save_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: "CheckpointMetadata"):
pass

@classmethod
def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata":
updates = {}
metadata = cls._load_metadata(config)
if not config.load_config.load_fast_llm:
updates[("config", "multi_stage")] = {}
updates[("config", "distributed")] = {}
if not config.load_config.load_architecture:
updates[("config", "base_model")] = {}
elif not config.load_config.load_base_model:
updates[("config", "base_model")] = metadata.config.base_model.get_architecture().to_dict()
if updates:
metadata = metadata.to_copy(updates)
return metadata

@classmethod
@abc.abstractmethod
def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata":
pass

@abc.abstractmethod
def save(self, config: CheckpointSaveConfig, metadata: "CheckpointMetadata"):
pass

@abc.abstractmethod
def load(self, config: CheckpointLoadConfig, metadata: "CheckpointMetadata"):
def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None:
pass

def get_shard_names(self, config: CheckpointStateConfigBase) -> tuple[str, ...]:
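
The reworked CheckpointHandler.load_metadata turns the load_config selection into plain dictionary updates on the loaded metadata. A minimal usage sketch follows, assuming these import paths and a hypothetical checkpoint location (neither is taken from the PR itself):

import pathlib

from fast_llm.engine.checkpoint.config import (  # assumed import location
    CheckpointLoadMetadataConfig,
    DistributedCheckpointFormat,
    ModelConfigType,
)
from fast_llm.engine.checkpoint.distributed import DistributedCheckpointHandler

# Hypothetical checkpoint directory.
metadata_config = CheckpointLoadMetadataConfig(
    path=pathlib.Path("/checkpoints/my_run"),
    format=DistributedCheckpointFormat,
    load_config=ModelConfigType.architecture,  # the default: keep only the architecture
)

# With `architecture`, the multi-stage and distributed sections are blanked and the
# base-model section is reduced to its architecture fields, so any explicitly
# specified parameters can later update the pretrained config.
metadata = DistributedCheckpointHandler.load_metadata(metadata_config)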
32 changes: 20 additions & 12 deletions fast_llm/engine/checkpoint/distributed.py
@@ -12,6 +12,7 @@
CheckpointLoadConfig,
CheckpointLoadMetadataConfig,
CheckpointSaveConfig,
CheckpointSaveMetadataConfig,
DistributedCheckpointFormat,
ModelConfigType,
export_safetensors_metadata,
@@ -28,7 +29,13 @@ class DistributedCheckpointHandler(CheckpointHandler):
format: typing.ClassVar[type[CheckpointFormat]] = DistributedCheckpointFormat

@classmethod
def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata:
def save_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata):
config.path.mkdir(parents=True, exist_ok=True)
serialized_metadata = metadata.to_dict()
yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w"))

@classmethod
def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata:
return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").open("r")))

def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None:
@@ -41,17 +48,16 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No
metadata=export_safetensors_metadata(serialized_metadata),
)

def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None:
def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None:
# TODO: More safety checks
loaded_config_dict = config.to_copy({"load_config": ModelConfigType.fast_llm})
loaded_config = self._model.config_class.from_metadata(loaded_config_dict, metadata)
loaded_metadata = self._model.config.load_metadata(config.to_copy({"load_config": ModelConfigType.fast_llm}))
shard_names = self.get_shard_names(config)
# Make sure all shards to load are in the checkpoint.
Assert.leq(set(self.get_shard_names(config)), set(metadata.shards))
Assert.eq(metadata.shards[: len(shard_names)], list(shard_names))
Assert.leq(set(self.get_shard_names(config)), set(loaded_metadata.shards))
Assert.eq(loaded_metadata.shards[: len(shard_names)], list(shard_names))

# Using `log_fn=bool` sets the output to true if the error list is non-empty.
same_format = config.optimizer_state and not loaded_config.compare(self._model.config, log_fn=bool)
same_format = config.optimizer_state and not loaded_metadata.config.compare(self._model.config, log_fn=bool)
# Make sure all nodes agree on which loading scheme to use.
# Note: they may not agree before the broadcast because of the rank comparison, but that's ok.
same_format = broadcast_scalar(same_format, torch.uint8, self._model.distributed.world_group)
@@ -70,7 +76,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No
log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning)
for shard_name in shard_names:
self._model.get_shard(shard_name).copy_(
f.get_slice("state_shard")[metadata.shards.index(shard_name)]
f.get_slice("state_shard")[loaded_metadata.shards.index(shard_name)]
)
else:
# TODO: Does this copy twice?
@@ -79,11 +85,11 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No

else:
log_main_rank("Checkpoint format doesn't match, using safe load", log_fn=logger.info)
self._model.config.base_model.compare_architecture(loaded_config.base_model, config.compare_log_fn)
self._model.config.base_model.compare_architecture(loaded_metadata.config.base_model, logger.warning)
with SafeLoad(self._model, shard_names=shard_names, timeout=config.timeout) as context:
for rank in range(loaded_config.distributed.world_size):
for rank in range(loaded_metadata.config.distributed.world_size):
loaded_model = self._model.__class__(
loaded_config.to_copy({("distributed", "rank"): rank}),
loaded_metadata.config.to_copy({("distributed", "rank"): rank}),
optimizer_state_names=shard_names[1:],
verbose=False,
)
@@ -97,7 +103,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No
# TODO v0.3: Use checkpoint version? Drop support?
log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning)
loaded_shards = {
shard_name: f.get_slice("state_shard")[metadata.shards.index(shard_name)]
shard_name: f.get_slice("state_shard")[loaded_metadata.shards.index(shard_name)]
for shard_name in shard_names
}
else:
@@ -122,3 +128,5 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No
)

context.mark_as_loaded(counter.item())

return loaded_metadata.metadata
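
The new save_metadata classmethod writes metadata.yaml on its own, without a full checkpoint save. A rough sketch under the same assumptions about import paths; checkpoint_metadata stands in for an existing CheckpointMetadata instance (for example, one built with the model config's to_metadata):

import pathlib

from fast_llm.engine.checkpoint.config import CheckpointSaveMetadataConfig, DistributedCheckpointFormat
from fast_llm.engine.checkpoint.distributed import DistributedCheckpointHandler

save_config = CheckpointSaveMetadataConfig(
    path=pathlib.Path("/checkpoints/my_run"),  # hypothetical output directory
    format=DistributedCheckpointFormat,
)

# `checkpoint_metadata` is a placeholder for a pre-built CheckpointMetadata object.
DistributedCheckpointHandler.save_metadata(save_config, checkpoint_metadata)
# Creates the directory if needed and writes /checkpoints/my_run/metadata.yaml.

Note that load also changed shape: it now takes only the CheckpointLoadConfig, re-loads the metadata internally with load_config=fast_llm, and returns the free-form metadata dictionary (or None).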
2 changes: 1 addition & 1 deletion fast_llm/engine/checkpoint/external.py
@@ -226,7 +226,7 @@ def __init__(self, model: "FastLLMModel"):
}

@classmethod
def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata:
def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata:
imported_model_config = cls._import_config(cls._load_config(config.path), True)
return CheckpointMetadata(
fast_llm_version=__version__,
13 changes: 8 additions & 5 deletions fast_llm/engine/checkpoint/huggingface.py
@@ -22,8 +22,10 @@

class HuggingfaceStateDictCheckpointHandler(ExternalStateDictCheckpointHandler, abc.ABC):

def _save_serialized_metadata(self, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None:
path = config.path / f"{self.base_file_name}.safetensors.index.json"
@classmethod
def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None:
config.path.mkdir(parents=True, exist_ok=True)
path = config.path / f"{cls.base_file_name}.safetensors.index.json"
logger.info(f"Saving index to {path}")
# Save the index.
json.dump(
@@ -41,10 +43,11 @@ def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: Ch
"format": "pt",
}

def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None:
def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None:
assert not config.optimizer_state
self._model.config.base_model.compare_architecture(metadata.config.base_model, config.compare_log_fn)
super().load(config, metadata)
metadata = self._model.config.load_metadata(config)
self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning)
super().load(config)

@classmethod
def get_huggingface_model_type(self) -> str:
21 changes: 16 additions & 5 deletions fast_llm/engine/checkpoint/state_dict.py
@@ -30,6 +30,13 @@
class StateDictCheckpointHandler(CheckpointHandler):
base_file_name: typing.ClassVar[str] = "model"

@classmethod
def save_metadata(
cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata, index: dict | None = None
):
serialized_metadata = cls._serialize_metadata(config, metadata)
cls._save_serialized_metadata(config, serialized_metadata, {} if index is None else index)

def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None:
serialized_metadata = self._serialize_metadata(config, metadata)
saver = StateDictSaver(
@@ -64,16 +71,18 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No
if self._model.config.distributed.rank == 0:
self._save_serialized_metadata(config, serialized_metadata, index)

@classmethod
@abc.abstractmethod
def _save_serialized_metadata(self, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None:
def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None:
pass

@classmethod
def _serialize_metadata(
self, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata
cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata
) -> dict[str, typing.Any]:
return metadata.to_dict()

def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None:
def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None:
with SafeLoad(self._model, shard_names=self.get_shard_names(config), timeout=config.timeout) as context:
# The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from
# `state_dict` that are ready for conversion,
@@ -116,14 +125,16 @@ class FastLLMCheckpointHandler(StateDictCheckpointHandler):
format: typing.ClassVar[type[CheckpointFormat]] = FastLLMCheckpointFormat

@classmethod
def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata:
def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata:
path = config.path / f"metadata.yaml"
logger.warning(f"Loading metadata from {path}")
return CheckpointMetadata.from_dict(yaml.safe_load(path.open("r")))

@classmethod
def _save_serialized_metadata(
self, config: CheckpointSaveMetadataConfig, serialized_metadata: dict, index: dict
cls, config: CheckpointSaveMetadataConfig, serialized_metadata: dict, index: dict
) -> None:
config.path.mkdir(parents=True, exist_ok=True)
path = config.path / f"metadata.yaml"
logger.info(f"Saving metadata to {path}")
if "metadata" not in serialized_metadata:
6 changes: 3 additions & 3 deletions fast_llm/engine/inference/huggingface.py
@@ -54,14 +54,14 @@ def from_pretrained(
format=FastLLMCheckpointFormat,
)

config_updates = {}
updates = {}
torch_dtype = kwargs.pop("torch_dtype", None)
if torch_dtype is not None:
config_updates[("distributed", "training_dtype")] = torch_dtype
updates[("distributed", "training_dtype")] = torch_dtype

# Create the model
fast_llm_model = cls.runner_class.model_class.from_pretrained(
pretrained_model_name_or_path, config_updates=config_updates, mode=mode
pretrained_model_name_or_path, updates, mode=mode
)
config = cls.config_class(fast_llm_model.config)

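
After the rename, the Hugging Face wrapper simply forwards torch_dtype as one more update on the pretrained config. A hedged sketch of the call site; the wrapper class name, its import path, and the checkpoint path are assumptions rather than part of this diff:

import pathlib

import torch

from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat
from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM  # assumed class and location

model = HuggingfaceGPTModelForCausalLM.from_pretrained(
    CheckpointLoadConfig(
        path=pathlib.Path("/checkpoints/my_run"),  # hypothetical checkpoint
        format=FastLLMCheckpointFormat,
    ),
    torch_dtype=torch.bfloat16,
)
# Internally this becomes updates = {("distributed", "training_dtype"): torch.bfloat16},
# passed positionally to the model class's from_pretrained.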
61 changes: 16 additions & 45 deletions fast_llm/engine/multi_stage/config.py
@@ -10,6 +10,7 @@
Field,
FieldHint,
NoAutoValidate,
UpdateType,
ValidationError,
check_field,
config_class,
@@ -186,11 +187,12 @@ class MultiStageConfig(StageConfig):
def _validate(self) -> None:
super()._validate()
if self.zero_stage is not None:
Assert.in_range_incl(self.zero_stage, 1, 3)
if self.zero_stage >= 2:
self.num_grad_buffers = 2
if self.zero_stage >= 3:
self.num_weight_buffers = 2
with self._set_implicit_default():
Assert.in_range_incl(self.zero_stage, 1, 3)
if self.zero_stage >= 2:
self.num_grad_buffers = 2
if self.zero_stage >= 3:
self.num_weight_buffers = 2
if self.num_grad_buffers is not None:
Assert.geq(self.num_grad_buffers, 1)
if self.num_weight_buffers is not None:
@@ -254,49 +256,13 @@ def get_base_model_config_class(cls) -> type[BaseModelConfig]:

@classmethod
def from_pretrained(
cls,
pretrained: CheckpointLoadMetadataConfig,
default: typing.Self | None = None,
) -> typing.Self:
# TODO: Add *updates?
assert pretrained.path is not None
metadata = cls.load_metadata(pretrained)
return cls.from_metadata(pretrained, metadata, default)

@classmethod
def from_metadata(
cls,
pretrained: CheckpointLoadMetadataConfig,
metadata: "CheckpointMetadata",
default: typing.Self | None = None,
updates: dict[str | tuple[str, ...], typing.Any] | None = None,
cls, pretrained: CheckpointLoadMetadataConfig, *updates: Config | dict[str | tuple[str, ...], typing.Any]
) -> typing.Self:
# TODO: Standardize to *updates?
# TODO v0.3: Update, remove support for older checkpoints.
if metadata.fast_llm_version.major != 0 or metadata.fast_llm_version.minor not in (0, 1, 2):
raise ValueError(f"Invalid checkpoint version: {metadata.fast_llm_version}")
pretrained_config = cls.from_dict(metadata.config)
if not pretrained.load_config.load_architecture:
assert default is not None
config = default.to_copy()
config.base_model.compare_architecture(pretrained_config.base_model, pretrained.compare_log_fn)
elif pretrained.load_config.load_fast_llm:
config = pretrained_config
else:
with NoAutoValidate():
config = cls() if default is None else default.to_copy()
if pretrained.load_config.load_base_model:
config.base_model = pretrained_config.base_model
else:
config.base_model = config.base_model.to_copy(pretrained_config.base_model.get_architecture())
config.validate()

if updates:
config = config.to_copy(updates)
return config
return cls.from_dict(cls.load_metadata(pretrained).config, *updates, update_type=UpdateType.update)

@classmethod
def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata":
assert config.path is not None
with NoAutoValidate():
metadata = config.format.get_handler_class().load_metadata(config)
try:
@@ -316,6 +282,9 @@ def to_metadata(self, config: CheckpointSaveMetadataConfig, **kwargs) -> "Checkp
**kwargs,
)

def save_metadata(self, config: CheckpointSaveMetadataConfig, **kwargs) -> None:
self.get_checkpoint_handler_class(config.format).save_metadata(config, self.to_metadata(config, **kwargs))


@config_class()
class PretrainedFastLLMModelConfig(Config):
Expand All @@ -336,7 +305,7 @@ def _validate(self) -> None:
self.pretrained.setup(self.model)
self.pretrained.validate()
if self.pretrained.path is not None:
self.model = self.model.from_pretrained(self.pretrained, default=self.model)
self.model = self.model.from_pretrained(self.pretrained, self.model)
self._setup()
super()._validate()

@@ -388,6 +357,8 @@ def _validate(self) -> None:

self.format = self.model.get_checkpoint_format(self.format)
super()._validate()
if self.fast_llm_version.major != 0 or self.fast_llm_version.minor not in (0, 1, 2):
raise ValueError(f"Invalid checkpoint version: {self.fast_llm_version}")
Assert.eq(self.config.__class__, self.model)

@classmethod
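
The core of the PR is the simplified from_pretrained: instead of the old default and from_metadata machinery, the pretrained config is loaded from the checkpoint metadata and every explicitly specified parameter is applied on top of it as an update (UpdateType.update). A minimal sketch, assuming a concrete config class named GPTModelConfig and a hypothetical checkpoint path:

import pathlib

from fast_llm.engine.checkpoint.config import (
    CheckpointLoadMetadataConfig,
    DistributedCheckpointFormat,
    ModelConfigType,
)
from fast_llm.models.gpt.config import GPTModelConfig  # assumed concrete FastLLMModelConfig subclass

pretrained = CheckpointLoadMetadataConfig(
    path=pathlib.Path("/checkpoints/my_run"),  # hypothetical checkpoint
    format=DistributedCheckpointFormat,
    load_config=ModelConfigType.fast_llm,      # start from the full saved config
)

config = GPTModelConfig.from_pretrained(
    pretrained,
    {("multi_stage", "zero_stage"): 3},        # hypothetical override of a saved value
)

Updates can also be full Config objects, which is how PretrainedFastLLMModelConfig._validate now passes self.model, so that parameters the user specifies explicitly override the values stored in the pretrained checkpoint.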