From 513775730095c5e53e4d982139a250aba9c99b52 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 26 Mar 2025 00:10:47 -0400 Subject: [PATCH 01/18] stuff --- fast_llm/config.py | 177 ++++++++++++++++++----- fast_llm/data/data/config.py | 11 +- fast_llm/data/data/gpt/config.py | 12 +- fast_llm/data/dataset/config.py | 26 ++-- fast_llm/data/dataset/gpt/config.py | 13 +- fast_llm/engine/checkpoint/config.py | 5 +- fast_llm/engine/checkpoint/external.py | 4 +- fast_llm/engine/distributed/config.py | 3 +- fast_llm/engine/schedule/config.py | 7 +- fast_llm/engine/training/config.py | 4 +- fast_llm/layers/language_model/config.py | 23 +-- fast_llm/layers/transformer/config.py | 94 ++++++------ fast_llm/profile.py | 4 +- fast_llm/utils.py | 34 ----- tests/data/common.py | 7 +- 15 files changed, 241 insertions(+), 183 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index f1c88965..326845f0 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1,3 +1,4 @@ +import contextlib import dataclasses import enum import logging @@ -9,7 +10,7 @@ import yaml -from fast_llm.utils import Assert, Tag, get_type_name, header, log, pop_nested_dict_value, set_nested_dict_value +from fast_llm.utils import Assert, Tag, get_type_name, header, log logger = logging.getLogger(__name__) @@ -43,6 +44,13 @@ class _ConfigDictFormat(str, enum.Enum): tuple = "tuple" +class UpdateType(str, enum.Enum): + # Override entries no matter what they contais. + override = "override" + # Override atomic entries and lists, but update dicts recursively by setting or overriding only the specified entries. + update = "update" + + class FieldHint: """ A label defined for each config field, to let the user and some methods know how important each field is. @@ -125,6 +133,9 @@ def __init__( # Should raise an Exception in case of failure, and return the validated value. # Run before the default validation (type check). valid: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None, + # Option to skip (postpone) instantiation of a `Config` field. + # Note: The config still needs to be instantiated for validation to succeed. + # auto_instantiate: bool = True, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init: bool = True, @@ -152,6 +163,7 @@ def __init__( self.doc = doc self.hint = hint self.valid = valid + # self.auto_instantiate = auto_instantiate class FieldUpdate(dict): @@ -254,7 +266,16 @@ def config_class(cls=None): def wrap(cls): Assert.custom(issubclass, cls, Config) - return _process_config_class(dataclasses.dataclass(cls)) + wrapped = _process_config_class(dataclasses.dataclass(cls)) + + wrapped_init = cls.__init__ + + def __init__(self, **kwargs): + wrapped_init(self, **kwargs) + self._explicit_fields = set(kwargs) + + cls.__init__ = __init__ + return wrapped # See if we're being called as @config_class or @config_class(). if cls is None: @@ -277,9 +298,17 @@ class Config: # We can't use @config_class on this one because it needs this class to be defined, so we assume this one is OK. __class_validated__: typing.ClassVar[bool] = True + # Set to true to prevent instantiation. _abstract: typing.ClassVar[bool] = False + # Keep track of whether an instance has been validated _validated: bool = Field(init=False, repr=False) + # Keep track of unknown fields so they can be reported during validation. _unknown_fields: dict[str, typing.Any] = Field(init=False, repr=False) + # Keep track of explicitly set fields to ensure they get serialized and used as config updates. 
+ _explicit_fields: set[str] = Field(init=False, repr=False) + # Used within `_set_implicit_default` to set implicit defaults for fields + # without them being automatically added to `_explicit_fields`. + _setting_implicit_default: bool = Field(init=False, repr=False) def __post_init__(self): """ @@ -288,6 +317,7 @@ def __post_init__(self): and all post-processing should be done in `_validate` """ self._validated = False + self._setting_implicit_default = False if _AUTO_VALIDATE: self.validate() @@ -305,6 +335,12 @@ def __setattr__(self, key: str, value: typing.Any) -> None: f"Cannot set attribute `{key}`" f" in configuration class `{get_type_name(type(self))}` after validation." ) + elif not getattr(self, "_setting_implicit_default", True): + field = self.get_field(key) + if field.init and field._field_type != dataclasses._FIELD_CLASSVAR: + # Adding to explicit field list except within `_set_implicit_default` context + # and during dataclass initialization (`_setting_implicit_default` not yet set). + self._explicit_fields.add(key) super().__setattr__(key, value) def __delattr__(self, key: str) -> None: @@ -318,6 +354,12 @@ def __delattr__(self, key: str) -> None: ) super().__delattr__(key) + @contextlib.contextmanager + def _set_implicit_default(self): + self._setting_implicit_default = True + yield + self._setting_implicit_default = False + def validate[T](self: T, *, _is_validating: bool = False) -> T: """ Validate a class and mark it as read-only @@ -332,6 +374,7 @@ def validate[T](self: T, *, _is_validating: bool = False) -> T: else: raise type(e)("\n".join(e.args)) from None self._validated = True + print("WLIEHGIUWERGNHBWIO", self.__class__.__name__, self._explicit_fields) return self def _validate(self) -> None: @@ -344,16 +387,17 @@ def _validate(self) -> None: """ self._check_abstract() errors = [] - for name, field in self.fields(): - if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa - continue - value = getattr(self, name) - if value is DEFAULT: - # Replace the value with its default. - # We still need to validate because some fields have invalid defaults. - value = field.default - new_value = self._validate_nested(value, field.type, field.name, field.valid, errors, False) - setattr(self, name, new_value) + with self._set_implicit_default(): + for name, field in self.fields(): + if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa + continue + value = getattr(self, name) + if value is DEFAULT: + # Replace the value with its default. + # We still need to validate because some fields have invalid defaults. + value = field.default + new_value = self._validate_nested(value, field.type, field.name, field.valid, errors, False) + setattr(self, name, new_value) for name in getattr(self, "_unknown_fields", {}): errors.append(f"Unknown field `{name}` in class {self._get_class_name()}") if errors: @@ -555,9 +599,8 @@ def _to_dict( return arg_dict - @classmethod def _add_field_to_args( - cls, + self, args: dict | list, name: str | None, field: Field | None, @@ -574,46 +617,48 @@ def _add_field_to_args( ): # Exclude class variables and derived fields unless requested explicitly. 
return - elif isinstance(value, Config): + explicit_field = ( + field is None + or name in self._explicit_fields + or (verbose is not None and verbose >= FieldHintImportance[field.hint]) + ) + if isinstance(value, Config): field_value = value._to_dict( verbose=verbose, all_fields=all_fields, format_=format_, serializable=serializable, ) + # Empty configs can safely be trimmed. + explicit_field = all_fields elif isinstance(value, (list, tuple, set)): field_value = {} if format_ == _ConfigDictFormat.tuple else [] for i, list_value in enumerate(value): - cls._add_field_to_args( + self._add_field_to_args( field_value, str(i), None, list_value, verbose, all_fields, format_, serializable ) elif isinstance(value, dict): field_value = {} for dict_name, dict_value in value.items(): - cls._add_field_to_args( + self._add_field_to_args( field_value, dict_name, None, dict_value, verbose, all_fields, format_, serializable ) - elif ( - verbose is not None - and field is not None - and FieldHintImportance[field.hint] > verbose - and value == field.default - ): - # Exclude unimportant default values. - return - else: + elif explicit_field: field_value = value if serializable: - field_value = cls._serialize_value(value) + field_value = self._serialize_value(value) if format_ == _ConfigDictFormat.tuple: field_value = {(): field_value} + else: + # Exclude unimportant (implicit or explicit) default values. + return if serializable: - name = cls._serialize_value(name) + name = self._serialize_value(name) if format_ == _ConfigDictFormat.tuple: args.update({(name,) + name_: value_ for name_, value_ in field_value.items()}) elif format_ == _ConfigDictFormat.nested: - if not isinstance(field_value, (dict, list)) or len(field_value) > 0 or all_fields: + if not isinstance(field_value, (dict, list)) or len(field_value) > 0 or explicit_field or all_fields: if isinstance(args, dict): args[name] = field_value else: @@ -671,6 +716,7 @@ def from_dict( default: typing.Union["Config", dict[str, typing.Any]], *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True, + update_type: UpdateType = UpdateType.override, ) -> typing.Self: if isinstance(default, Config): default = default._to_dict() @@ -678,7 +724,7 @@ def from_dict( if isinstance(update, Config): update = update._to_dict(format_=_ConfigDictFormat.tuple) for keys, value in update.items(): - set_nested_dict_value(default, keys, value) + set_nested_dict_value(default, keys, value, update_type) return cls._from_dict(default, strict) @@ -712,10 +758,7 @@ def _from_dict( continue if flat: if isinstance(field.type, type) and issubclass(field.type, Config): - if flat: - out_arg_dict[name] = field.type._from_dict(default, False, True) - else: - out_arg_dict[name] = field.type._from_dict(default.pop(name, {}), strict) + out_arg_dict[name] = field.type._from_dict(default, False, True) elif name in default: out_arg_dict[name] = default.pop(name) else: @@ -916,3 +959,69 @@ def __init__(self, config: ConfigType, *args, **kwargs): @property def config(self) -> ConfigType: return self._config + + +def set_nested_dict_value[ + KeyType, ValueType +]( + d: dict[KeyType, ValueType], + keys: KeyType | tuple[KeyType, ...], + value: ValueType, + update_type: UpdateType = UpdateType.override, +) -> None: + if isinstance(keys, tuple): + for key in keys[:-1]: + d = d.setdefault(key, {}) + assert isinstance(d, dict) + key = keys[-1] + else: + key = keys + if update_type == UpdateType.override: + d[key] = value + elif update_type == UpdateType.update: + # 
TODO: Improve error messages, ex. for nested cases? + if isinstance(d[key], Config): + raise ValueError("Cannot update an already instantiated config.") + elif isinstance(value, Config): + raise ValueError("Cannot update a config dict with an already instantiated config.") + elif isinstance(d, dict): + if key in d: + Assert.custom(isinstance, d[key], dict) + else: + d[key] = {} + for key_, value_ in value.items(): + set_nested_dict_value(d, key_, value_, update_type) + elif ( + isinstance(value, (list, set, tuple)) + and any(isinstance(value_, (list, set, tuple, dict, Config)) for value_ in value) + ) or ( + isinstance(d[key], (list, set, tuple)) + and any(isinstance(value_, (list, set, tuple, dict, Config)) for value_ in d[key]) + ): + raise ValueError("Update not supported for nested lists.") + else: + d[key] = value + else: + raise NotImplementedError(update_type) + + +def get_nested_dict_value[ + KeyType, ValueType +](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType: + if isinstance(keys, tuple): + for key in keys: + d = d[key] + return d + else: + return d[keys] + + +def pop_nested_dict_value[ + KeyType, ValueType +](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType: + if isinstance(keys, tuple): + for key in keys[:-1]: + d = d[key] + return d.pop(keys[-1]) + else: + return d.pop(keys) diff --git a/fast_llm/data/data/config.py b/fast_llm/data/data/config.py index 752fdfd1..25850ac3 100644 --- a/fast_llm/data/data/config.py +++ b/fast_llm/data/data/config.py @@ -1,18 +1,9 @@ import typing -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, config_class +from fast_llm.config import Config, Field, config_class from fast_llm.data.dataset.config import SamplingConfig, SamplingData -@config_class() -class SamplingDefaultConfig(SamplingConfig): - seed: int = FieldUpdate( - default=784569, - desc="Seed for random sampling.", - hint=FieldHint.feature, - ) - - @config_class() class DataConfig(Config): _abstract = True diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index cbbfa036..d1d6bd40 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -2,13 +2,12 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.config import MultiprocessingContext, TokenizerConfig -from fast_llm.data.data.config import DataConfig, SamplingDefaultConfig +from fast_llm.data.data.config import DataConfig from fast_llm.data.dataset.gpt.config import ( GPTLegacyConfig, GPTLegacyDatasetConfig, GPTSampledDatasetConfig, GPTSamplingConfig, - ShufflingType, ) from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert @@ -16,13 +15,6 @@ logger = logging.getLogger(__name__) -@config_class() -class GPTSamplingDefaultConfig(SamplingDefaultConfig, GPTSamplingConfig): - gpu: bool = FieldUpdate(default=True) - use_loss_masking_spans: bool = FieldUpdate(default=False) - shuffle: ShufflingType = FieldUpdate(default=ShufflingType.epoch) - - @config_class() class GPTDataConfig(DataConfig, GPTLegacyConfig): """ @@ -44,7 +36,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Configuration for the dataset(s).", hint=FieldHint.core, ) - sampling: GPTSamplingDefaultConfig = FieldUpdate(default_factory=GPTSamplingDefaultConfig) + sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) data_sample_warn_time_ms: float = Field( default=1000, desc="Warn if a sample takes too long to load.", 
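
A note on the `UpdateType` enum introduced in `fast_llm/config.py` above: it distinguishes two merge behaviours for nested config dictionaries ("override" replaces an entry wholesale, "update" merges dicts recursively, touching only the specified keys). The sketch below is a self-contained illustration of those semantics only; it does not reuse the patch's `set_nested_dict_value` helper, and the example dictionaries are hypothetical.

# Conceptual sketch of "override" vs. "update" merge semantics for nested config dicts.
def merge(default: dict, update: dict, recursive: bool) -> dict:
    out = dict(default)
    for key, value in update.items():
        if recursive and isinstance(value, dict) and isinstance(out.get(key), dict):
            # "update": recurse, keeping entries that the update does not mention.
            out[key] = merge(out[key], value, recursive)
        else:
            # "override": replace the entry wholesale.
            out[key] = value
    return out

base = {"sampling": {"seed": 1, "gpu": True}}
print(merge(base, {"sampling": {"seed": 2}}, recursive=False))  # {'sampling': {'seed': 2}}
print(merge(base, {"sampling": {"seed": 2}}, recursive=True))   # {'sampling': {'seed': 2, 'gpu': True}}
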
diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 431a28a0..7808158b 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -5,7 +5,7 @@ import pathlib import typing -from fast_llm.config import Config, Field, FieldHint, FieldVerboseLevel, check_field, config_class +from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert, normalize_probabilities @@ -17,20 +17,12 @@ @config_class() class SamplingConfig(Config): - seed: int | None = Field( - default=None, + seed: int = Field( + default=784569, desc="Seed for random sampling.", hint=FieldHint.feature, ) - @property - def updates(self) -> dict[str, typing.Any]: - return { - key: value - for key, value in self.to_serialized(verbose=FieldVerboseLevel.everything).items() - if value is not None - } - @dataclasses.dataclass(kw_only=True) class SamplingData: @@ -44,10 +36,10 @@ class SamplingData: # Using a mutable rather than an int so it's shared with all copies made with `update`. _rank_counter: typing.Iterator[int] = itertools.count - def update(self, config: SamplingConfig, **kwargs): - if config_updates := config.updates: - kwargs["config"] = self.config.to_copy(config_updates) - return dataclasses.replace(self, **kwargs) if kwargs else self + def update_config(self, update: SamplingConfig): + return dataclasses.replace( + self, config=self.config.from_dict(self.config, update, update_type=UpdateType.update) + ) def get_next_rank(self) -> int: # Counter that loops over ranks to try to distribute workloads evenly between ranks. @@ -163,7 +155,7 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig): Only explicitly set parameters (not None) will be updated, other will still be taken from `build_and_sample`'s argument. """ - _abstract = False + _abstract = True sampling: SamplingConfig = Field( default_factory=SamplingConfig, desc="Optional override to sampling configuration parameters.", @@ -176,7 +168,7 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig): ) def build_and_sample(self, data: SamplingData) -> SampledDataset: - return self.dataset.build_and_sample(data.update(self.sampling)) + return self.dataset.build_and_sample(data.update_config(self.sampling)) @config_class() diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 74d8a0c3..118b3039 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -45,20 +45,20 @@ class ShufflingType(str, enum.Enum): @config_class() class GPTSamplingConfig(SamplingConfig): - gpu: bool | None = Field( - default=None, + gpu: bool = Field( + default=True, desc="Enable fast sampling on GPU." 
" Note that random sampling works differently on GPU," " so the sample won't match the CPU equivalent.", hint=FieldHint.feature, ) - use_loss_masking_spans: bool | None = Field( - default=None, + use_loss_masking_spans: bool = Field( + default=False, desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) - shuffle: ShufflingType | None = Field( - default=None, + shuffle: ShufflingType = Field( + default=ShufflingType.epoch, desc="Shuffling strategy.", hint=FieldHint.feature, ) @@ -210,6 +210,7 @@ def build(self) -> "GPTDatasetSlice": @config_class() class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig): + _abstract = False type_: typing.ClassVar[str | None] = "sampled" sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) dataset: GPTSampledDatasetConfig = FieldUpdate(default_factory=GPTSampledDatasetConfig) diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 92f1165d..46c8f483 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -164,8 +164,9 @@ class CheckpointStateSaveConfigBase(CheckpointSaveConfigBase, CheckpointStateCon def _validate(self) -> None: if self.optimizer_state is None: - # TODO: Make sure it's a type - self.optimizer_state = self.format.support_optimizer + with self._set_implicit_default(): + # TODO: Make sure it's a type + self.optimizer_state = self.format.support_optimizer super()._validate() if self.optimizer_state: assert self.format.support_optimizer diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 83514c86..76f5e336 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -7,14 +7,14 @@ import torch from fast_llm import __version__ -from fast_llm.config import MISSING +from fast_llm.config import MISSING, get_nested_dict_value, set_nested_dict_value from fast_llm.engine.base_model.config import BaseModelArchitectureConfig from fast_llm.engine.checkpoint.config import CheckpointLoadMetadataConfig from fast_llm.engine.checkpoint.state_dict import StateDictCheckpointHandler from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.tensor import SafeTensorSlice -from fast_llm.utils import Assert, get_nested_dict_value, set_nested_dict_value +from fast_llm.utils import Assert logger = logging.getLogger(__name__) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 1b3e73bb..76c496ac 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -279,7 +279,8 @@ def _validate(self) -> None: self.tensor_rank = self.rank % self.tensor_parallel if self.tensor_parallel == 1: - self.sequence_tensor_parallel = False + with self._set_implicit_default(): + self.sequence_tensor_parallel = False self.distributed_dims = {} diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 83d3d51a..91256deb 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -79,10 +79,6 @@ def setup(self, distributed_config: DistributedConfig) -> None: def num_inputs(self) -> int: return self.sequential_micro_batches * self.num_micro_sequences - @property - def _is_setup(self) -> bool: - return hasattr(self, "_distributed") - def _validate(self) -> None: # Use the distributed properties to 
determine the batch size and its breakdown. # Requires post-processed distributed config args @@ -133,7 +129,8 @@ def _validate(self) -> None: " Use at your own risk." ) if self.micro_sequence_length is None: - self.micro_sequence_length = self.sequence_length + with self._set_implicit_default(): + self.micro_sequence_length = self.sequence_length self.num_micro_sequences = div(self.sequence_length, self.micro_sequence_length) super()._validate() diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 30add2f4..3a65bbc9 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -42,7 +42,8 @@ class IntervalConfig(Config): def _validate(self) -> None: if self.interval: - self.offset %= self.interval + with self._set_implicit_default(): + self.offset %= self.interval super()._validate() def enabled(self, iteration: int | None = None) -> bool: @@ -109,6 +110,7 @@ class WandbAlertConfig(IntervalConfig): "The update may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.feature, ) + post_alerts: bool = Field(init=False, repr=False) def _validate(self) -> None: if self.status_updates is None: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 8e3a467c..fa5d4920 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -60,7 +60,8 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): def _validate(self) -> None: if self.use_position_embeddings is None: - self.use_position_embeddings = not self.transformer.rotary.enabled + with self._set_implicit_default(): + self.use_position_embeddings = not self.transformer.rotary.enabled super()._validate() def setup_tensor_space(self, tensor_space: TensorSpace) -> None: @@ -175,14 +176,14 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): ) def _validate(self) -> None: - if self.transformer.init_method_std is None: - self.transformer.init_method_std = self.transformer.hidden_size**-0.5 - if self.init_method_std_embed is None: - self.init_method_std_embed = self.transformer.init_method_std - if self.init_method_max_embed is None: - self.init_method_max_embed = self.transformer.init_method_max - if self.init_method_min_embed is None: - self.init_method_min_embed = self.transformer.init_method_min - if self.init_method_max_embed is not None and self.init_method_min_embed is not None: - Assert.leq(self.init_method_min_embed, self.init_method_max_embed) + self.transformer.validate() + with self._set_implicit_default(): + if self.init_method_std_embed is None: + self.init_method_std_embed = self.transformer.init_method_std + if self.init_method_max_embed is None: + self.init_method_max_embed = self.transformer.init_method_max + if self.init_method_min_embed is None: + self.init_method_min_embed = self.transformer.init_method_min + if self.init_method_max_embed is not None and self.init_method_min_embed is not None: + Assert.leq(self.init_method_min_embed, self.init_method_max_embed) super()._validate() diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 1352c7f0..13983137 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -250,12 +250,13 @@ class TransformerArchitectureConfig(BaseModelArchitectureConfig): ) def _validate(self) -> None: - if self.ffn_hidden_size is None: - self.ffn_hidden_size = 4 * self.hidden_size - if 
self.kv_channels is None: - self.kv_channels = div(self.hidden_size, self.num_attention_heads) - if self.activation_type is None: - self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + with self._set_implicit_default(): + if self.ffn_hidden_size is None: + self.ffn_hidden_size = 4 * self.hidden_size + if self.kv_channels is None: + self.kv_channels = div(self.hidden_size, self.num_attention_heads) + if self.activation_type is None: + self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu self.projection_size = self.num_attention_heads * self.kv_channels self.num_unshared_experts = self.num_experts - self.num_shared_experts @@ -569,46 +570,47 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): ) def _validate(self) -> None: - if self.init_method_std is None: - self.init_method_std = self.hidden_size**-0.5 - if self.init_method_std_qkv is None: - self.init_method_std_qkv = self.init_method_std - if self.init_method_std_attn_proj is None: - self.init_method_std_attn_proj = self.init_method_std / (2 * self.num_layers) ** 0.5 - if self.init_method_std_mlp_1 is None: - self.init_method_std_mlp_1 = self.init_method_std - if self.init_method_std_mlp_2 is None: - self.init_method_std_mlp_2 = self.init_method_std / (2 * self.num_layers) ** 0.5 - if self.mlp_lr_scale is None or len(self.mlp_lr_scale) == 0: - self.mlp_lr_scale = [None] - if self.init_method_max_qkv is None: - self.init_method_max_qkv = self.init_method_max - if self.init_method_min_qkv is None: - self.init_method_min_qkv = self.init_method_min - if self.init_method_max_attn_proj is None: - self.init_method_max_attn_proj = self.init_method_max - if self.init_method_min_attn_proj is None: - self.init_method_min_attn_proj = self.init_method_min - if self.init_method_max_mlp_1 is None: - self.init_method_max_mlp_1 = self.init_method_max - if self.init_method_min_mlp_1 is None: - self.init_method_min_mlp_1 = self.init_method_min - if self.init_method_max_mlp_2 is None: - self.init_method_max_mlp_2 = self.init_method_max - if self.init_method_min_mlp_2 is None: - self.init_method_min_mlp_2 = self.init_method_min - if self.init_method_min is not None and self.init_method_max is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) - if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: - Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) - if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: - Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) - if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: - Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) + with self._set_implicit_default(): + if self.init_method_std is None: + self.init_method_std = self.hidden_size**-0.5 + if self.init_method_std_qkv is None: + self.init_method_std_qkv = self.init_method_std + if self.init_method_std_attn_proj is None: + self.init_method_std_attn_proj = self.init_method_std / (2 * self.num_layers) ** 0.5 + if self.init_method_std_mlp_1 is None: + self.init_method_std_mlp_1 = self.init_method_std + if 
self.init_method_std_mlp_2 is None: + self.init_method_std_mlp_2 = self.init_method_std / (2 * self.num_layers) ** 0.5 + if self.mlp_lr_scale is None or len(self.mlp_lr_scale) == 0: + self.mlp_lr_scale = [None] + if self.init_method_max_qkv is None: + self.init_method_max_qkv = self.init_method_max + if self.init_method_min_qkv is None: + self.init_method_min_qkv = self.init_method_min + if self.init_method_max_attn_proj is None: + self.init_method_max_attn_proj = self.init_method_max + if self.init_method_min_attn_proj is None: + self.init_method_min_attn_proj = self.init_method_min + if self.init_method_max_mlp_1 is None: + self.init_method_max_mlp_1 = self.init_method_max + if self.init_method_min_mlp_1 is None: + self.init_method_min_mlp_1 = self.init_method_min + if self.init_method_max_mlp_2 is None: + self.init_method_max_mlp_2 = self.init_method_max + if self.init_method_min_mlp_2 is None: + self.init_method_min_mlp_2 = self.init_method_min + if self.init_method_min is not None and self.init_method_max is not None: + Assert.leq(self.init_method_min, self.init_method_max) + if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + Assert.leq(self.init_method_min, self.init_method_max) + if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) + if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: + Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) + if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: + Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) + if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: + Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) super()._validate() Assert.geq(self.attention_dropout, 0) Assert.geq(self.hidden_dropout, 0) diff --git a/fast_llm/profile.py b/fast_llm/profile.py index a0fc3946..a3902cf1 100644 --- a/fast_llm/profile.py +++ b/fast_llm/profile.py @@ -94,7 +94,9 @@ def _validate(self) -> None: self.global_attention_layers = set() profile_ranks = set(self.ranks or []) Assert.eq(len(profile_ranks), len(self.ranks or [])) - self.ranks = profile_ranks # noqa + with self._set_implicit_default(): + self.ranks = profile_ranks # noqa + super()._validate() def get_profiler( self, *, distributed_config: DistributedConfig | None = None, start_step: int = 0 diff --git a/fast_llm/utils.py b/fast_llm/utils.py index d650fa94..b0e48231 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -249,40 +249,6 @@ def normalize_probabilities(p: "npt.ArrayLike", return_array: bool = False) -> " return out if return_array else out.tolist() -def set_nested_dict_value[ - KeyType, ValueType -](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...], value: ValueType) -> None: - if isinstance(keys, tuple): - for key in keys[:-1]: - d = d.setdefault(key, {}) - assert isinstance(d, dict) - d[keys[-1]] = value - else: - d[keys] = value - - -def get_nested_dict_value[ - KeyType, ValueType -](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType: - if isinstance(keys, tuple): - for key in keys: - d = d[key] - return d - else: - return d[keys] - - -def pop_nested_dict_value[ - KeyType, ValueType -](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType: - if isinstance(keys, tuple): - for key in keys[:-1]: - d = d[key] - return d.pop(keys[-1]) 
- else: - return d.pop(keys) - - class InvalidObject: """ Store an error and raise it if accessed. diff --git a/tests/data/common.py b/tests/data/common.py index 917b4914..bdfd54a7 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -5,12 +5,13 @@ import torch from fast_llm.config import Field, FieldHint, NoAutoValidate, config_class -from fast_llm.data.data.gpt.config import GPTDataConfig, GPTSamplingDefaultConfig +from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import ( GPTIndexedDatasetConfig, GPTSampledDatasetConfig, + GPTSamplingConfig, GPTSamplingData, ShufflingType, ) @@ -39,7 +40,7 @@ def get_sampling_data( ) -> GPTSamplingData: # Config with convenient defaults. return GPTSamplingData( - config=GPTSamplingDefaultConfig( + config=GPTSamplingConfig( seed=seed, gpu=gpu, shuffle=shuffle, @@ -76,7 +77,7 @@ def get_test_data_and_compare_samples( distributed_config = DistributedConfig(seed=seed if legacy else 87522) distributed = Distributed(distributed_config, use_cpu=True) assert "sampling" not in config - config["sampling"] = GPTSamplingDefaultConfig( + config["sampling"] = GPTSamplingConfig( seed=87522 if legacy else seed, gpu=gpu, shuffle=shuffle, From f26010ef9f8cfd070734751f9dec45a364496308 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 26 Mar 2025 21:21:45 -0400 Subject: [PATCH 02/18] Update pretrained config --- fast_llm/config.py | 5 +-- fast_llm/engine/checkpoint/config.py | 1 + fast_llm/engine/checkpoint/distributed.py | 6 +-- fast_llm/engine/huggingface/config.py | 5 +-- fast_llm/engine/huggingface/model.py | 8 ++-- fast_llm/engine/multi_stage/config.py | 42 +++---------------- fast_llm/engine/multi_stage/fast_llm_model.py | 7 ++-- fast_llm/layers/transformer/config.py | 17 ++++---- fast_llm/layers/transformer/mlp.py | 4 +- 9 files changed, 28 insertions(+), 67 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 326845f0..5436a294 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -374,7 +374,6 @@ def validate[T](self: T, *, _is_validating: bool = False) -> T: else: raise type(e)("\n".join(e.args)) from None self._validated = True - print("WLIEHGIUWERGNHBWIO", self.__class__.__name__, self._explicit_fields) return self def _validate(self) -> None: @@ -713,8 +712,8 @@ def _get_class_name(cls) -> str: @classmethod def from_dict( cls, - default: typing.Union["Config", dict[str, typing.Any]], - *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], + default: "Config| dict[str, typing.Any]]", + *updates: "Config| dict[str | tuple[str, ...], typing.Any]", strict: bool = True, update_type: UpdateType = UpdateType.override, ) -> typing.Self: diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 46c8f483..621f7fe8 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -200,6 +200,7 @@ class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateSaveConf @config_class() class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): + # TODO!!!!!!! 
_abstract = False load_config: ModelConfigType = Field( diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index 9c171bef..a920a52c 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -13,7 +13,6 @@ CheckpointLoadMetadataConfig, CheckpointSaveConfig, DistributedCheckpointFormat, - ModelConfigType, export_safetensors_metadata, ) from fast_llm.engine.checkpoint.safe_load import SafeLoad @@ -43,15 +42,14 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: # TODO: More safety checks - loaded_config_dict = config.to_copy({"load_config": ModelConfigType.fast_llm}) - loaded_config = self._model.config_class.from_metadata(loaded_config_dict, metadata) num_shards = self.get_num_shards(config) shard_names = self.get_shard_names(config) Assert.eq(metadata.shards[:num_shards], list(shard_names)) same_format = ( - loaded_config.to_serialized(verbose=None) == self._model.config.to_serialized(verbose=None) + type(metadata.config) == type(self._model.config) and config.optimizer_state + and metadata.config.to_serialized(verbose=None) == self._model.config.to_serialized(verbose=None) ) # Make sure all nodes agree on which loading scheme to use. # Note: they may not agree before the broadcast because of the rank comparison, but that's ok. diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index e02abc28..e79857c9 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -73,10 +73,7 @@ def _get_config_dict( torch_dtype = kwargs.pop("torch_dtype", None) if torch_dtype is not None: updates[("distributed", "training_dtype")] = torch_dtype - fast_llm_config = cls.model_config_class.from_metadata( - pretrained, metadata, default=kwargs.pop("fast_llm_config", None), updates=updates - ) - + fast_llm_config = cls.model_config_class.from_dict(metadata.config, kwargs.pop("fast_llm_config", {}), updates) config_dict = {"fast_llm_config": fast_llm_config} return config_dict, kwargs diff --git a/fast_llm/engine/huggingface/model.py b/fast_llm/engine/huggingface/model.py index 499f0af1..e4f2cd99 100644 --- a/fast_llm/engine/huggingface/model.py +++ b/fast_llm/engine/huggingface/model.py @@ -73,15 +73,13 @@ def from_pretrained( format=FastLLMCheckpointFormat, ) - config_updates = {} + updates = {} torch_dtype = kwargs.pop("torch_dtype", None) if torch_dtype is not None: - config_updates[("distributed", "training_dtype")] = torch_dtype + updates[("distributed", "training_dtype")] = torch_dtype # Create the model - fast_llm_model = cls.model_class.from_pretrained( - pretrained_model_name_or_path, config_updates=config_updates, mode=mode - ) + fast_llm_model = cls.model_class.from_pretrained(pretrained_model_name_or_path, updates, mode=mode) config = cls.config_class(fast_llm_model.config) return cls(config, fast_llm_model, **kwargs) diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index d6997105..d8333c9b 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -246,46 +246,12 @@ def get_base_model_config_class(cls) -> type[BaseModelConfig]: @classmethod def from_pretrained( - cls, - pretrained: CheckpointLoadMetadataConfig, - default: typing.Self | None = None, + cls, pretrained: CheckpointLoadMetadataConfig, *updates: Config | dict[str | tuple[str, 
...], typing.Any] ) -> typing.Self: # TODO: Add *updates? assert pretrained.path is not None metadata = cls.load_metadata(pretrained) - return cls.from_metadata(pretrained, metadata, default) - - @classmethod - def from_metadata( - cls, - pretrained: CheckpointLoadMetadataConfig, - metadata: "CheckpointMetadata", - default: typing.Self | None = None, - updates: dict[str | tuple[str, ...], typing.Any] | None = None, - ) -> typing.Self: - # TODO: Standardize to *updates? - # TODO v0.3: Update, remove support for older checkpoints. - if metadata.fast_llm_version.major != 0 or metadata.fast_llm_version.minor not in (0, 1, 2): - raise ValueError(f"Invalid checkpoint version: {metadata.fast_llm_version}") - pretrained_config = cls.from_dict(metadata.config) - if not pretrained.load_config.load_architecture: - assert default is not None - config = default.to_copy() - config.base_model.compare_architecture(pretrained_config.base_model, pretrained.compare_log_fn) - elif pretrained.load_config.load_fast_llm: - config = pretrained_config - else: - with NoAutoValidate(): - config = cls() if default is None else default.to_copy() - if pretrained.load_config.load_base_model: - config.base_model = pretrained_config.base_model - else: - config.base_model = config.base_model.to_copy(pretrained_config.base_model.get_architecture()) - config.validate() - - if updates: - config = config.to_copy(updates) - return config + return cls.from_dict(metadata.config, *updates) @classmethod def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata": @@ -328,7 +294,7 @@ def _validate(self) -> None: self.pretrained.setup(self.model) self.pretrained.validate() if self.pretrained.path is not None: - self.model = self.model.from_pretrained(self.pretrained, default=self.model) + self.model = self.model.from_pretrained(self.pretrained, self.model) self._setup() super()._validate() @@ -380,6 +346,8 @@ def _validate(self) -> None: self.format = self.model.get_checkpoint_format(self.format) super()._validate() + if self.fast_llm_version.major != 0 or self.fast_llm_version.minor not in (0, 1, 2): + raise ValueError(f"Invalid checkpoint version: {self.fast_llm_version}") Assert.eq(self.config.__class__, self.model) @classmethod diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index b268ec29..22e5ccac 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -1,6 +1,7 @@ import logging import typing +from fast_llm.config import UpdateType from fast_llm.core.distributed import broadcast from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig from fast_llm.engine.distributed.distributed import Distributed @@ -45,9 +46,7 @@ def load_checkpoint(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] def from_pretrained( cls, pretrained_config: CheckpointLoadConfig, - default_config: FastLLMModelConfig = None, - *, - config_updates: dict[str | tuple[str, ...], typing.Any] | None = None, + *updates: dict[str | tuple[str, ...], typing.Any], optimizer_state_names: tuple[str, ...] 
| None = None, setup: bool = True, mode: StageMode = StageMode.training, @@ -55,7 +54,7 @@ def from_pretrained( stage_filter: set | None = None, ) -> typing.Self: metadata = cls.config_class.load_metadata(pretrained_config) - config = cls.config_class.from_metadata(pretrained_config, metadata, default_config, config_updates) + config = cls.config_class.from_dict(metadata.config, *updates, update_type=UpdateType.update) if mode.support_training: # TODO v0.3: Make metadata.shards mandatory? if metadata.shards: diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 13983137..9410157b 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -532,8 +532,8 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - mlp_lr_scale: list[float | None] = Field( - default_factory=list, + mlp_lr_scale: float | None | list[float | None] = Field( + default=None, desc="Custom learning rate scale for each expert.", doc="May be used to freeze some experts by setting their scale to zero.", hint=FieldHint.feature, @@ -581,8 +581,6 @@ def _validate(self) -> None: self.init_method_std_mlp_1 = self.init_method_std if self.init_method_std_mlp_2 is None: self.init_method_std_mlp_2 = self.init_method_std / (2 * self.num_layers) ** 0.5 - if self.mlp_lr_scale is None or len(self.mlp_lr_scale) == 0: - self.mlp_lr_scale = [None] if self.init_method_max_qkv is None: self.init_method_max_qkv = self.init_method_max if self.init_method_min_qkv is None: @@ -614,10 +612,13 @@ def _validate(self) -> None: super()._validate() Assert.geq(self.attention_dropout, 0) Assert.geq(self.hidden_dropout, 0) - Assert.incl(len(self.mlp_lr_scale), (1, self.num_experts)) - for scale in self.mlp_lr_scale: - if scale is not None: - Assert.geq(scale, 0) + if isinstance(self.mlp_lr_scale, list): + Assert.eq(len(self.mlp_lr_scale), self.num_experts) + for scale in self.mlp_lr_scale: + if scale is not None: + Assert.geq(scale, 0) + elif self.mlp_lr_scale is not None: + Assert.geq(self.mlp_lr_scale, 0) def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: use_flash_attention = self.use_flash_attention and distributed_config.training_dtype in ( diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index adc6242d..ff4eaf26 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -45,7 +45,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, - lr_scale=tuple(config.mlp_lr_scale), + lr_scale=tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale, ) self.layer_2 = LinearBase( self._intermediate_dim, @@ -55,7 +55,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s bias_init_method=init_method_2 if config.random_bias_init else init_zeros_, auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, - lr_scale=tuple(config.mlp_lr_scale), + lr_scale=tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale, ) From b930a391b37703e7dce23fdb544b08fe98d42084 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 26 Mar 2025 21:27:40 -0400 Subject: [PATCH 03/18] stuff --- fast_llm/config.py | 
5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 326845f0..5436a294 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -374,7 +374,6 @@ def validate[T](self: T, *, _is_validating: bool = False) -> T: else: raise type(e)("\n".join(e.args)) from None self._validated = True - print("WLIEHGIUWERGNHBWIO", self.__class__.__name__, self._explicit_fields) return self def _validate(self) -> None: @@ -713,8 +712,8 @@ def _get_class_name(cls) -> str: @classmethod def from_dict( cls, - default: typing.Union["Config", dict[str, typing.Any]], - *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], + default: "Config| dict[str, typing.Any]]", + *updates: "Config| dict[str | tuple[str, ...], typing.Any]", strict: bool = True, update_type: UpdateType = UpdateType.override, ) -> typing.Self: From 8117c47b483c26853bf5015ef85b4e94472de1b1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 26 Mar 2025 21:40:37 -0400 Subject: [PATCH 04/18] fixes --- fast_llm/engine/multi_stage/config.py | 7 +++---- fast_llm/engine/multi_stage/fast_llm_model.py | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index d8333c9b..342a453b 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -10,6 +10,7 @@ Field, FieldHint, NoAutoValidate, + UpdateType, ValidationError, check_field, config_class, @@ -248,13 +249,11 @@ def get_base_model_config_class(cls) -> type[BaseModelConfig]: def from_pretrained( cls, pretrained: CheckpointLoadMetadataConfig, *updates: Config | dict[str | tuple[str, ...], typing.Any] ) -> typing.Self: - # TODO: Add *updates? - assert pretrained.path is not None - metadata = cls.load_metadata(pretrained) - return cls.from_dict(metadata.config, *updates) + return cls.from_dict(cls.load_metadata(pretrained).config, *updates, update_type=UpdateType.update) @classmethod def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata": + assert config.path is not None with NoAutoValidate(): metadata = config.format.get_handler_class().load_metadata(config) try: diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 22e5ccac..2dec7959 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -36,11 +36,11 @@ def load_checkpoint(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] # TODO: Test with more distributed configs. # TODO: Safety checks # TODO: Handle barriers, ok file, etc. 
here - fast_llm_metadata = self.config_class.load_metadata(config) + metadata = self.config_class.load_metadata(config) converter = config.format.get_handler_class()(self) - converter.load(config, fast_llm_metadata) + converter.load(config, metadata) self._finalize_load(reset_optimizer=not config.optimizer_state) - return fast_llm_metadata.metadata + return metadata.metadata @classmethod def from_pretrained( From 1c995d3e76be57ec80f9f305d83b613e0c8bdba3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 26 Mar 2025 21:50:00 -0400 Subject: [PATCH 05/18] fix --- fast_llm/engine/checkpoint/config.py | 16 ---------------- fast_llm/engine/checkpoint/distributed.py | 6 +++--- fast_llm/engine/checkpoint/huggingface.py | 2 +- 3 files changed, 4 insertions(+), 20 deletions(-) diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 621f7fe8..a3472523 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -200,24 +200,8 @@ class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateSaveConf @config_class() class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): - # TODO!!!!!!! _abstract = False - load_config: ModelConfigType = Field( - default=ModelConfigType.architecture, - desc="Configuration to save/load.", - hint=FieldHint.core, - ) - - def _validate(self) -> None: - super()._validate() - if self.format.enforce_architecture_match: - assert self.load_config.load_architecture - - @property - def compare_log_fn(self): - return ValueError if self.load_config.load_architecture else logger.warning - @config_class() class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase): diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index a920a52c..953cdef8 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -67,11 +67,11 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No self._model.state_shard[:num_shards].copy_(f.get_slice("state_shard")[:num_shards]) else: log_main_rank("Checkpoint format doesn't match, using safe load") - self._model.config.base_model.compare_architecture(loaded_config.base_model, config.compare_log_fn) + self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning) with SafeLoad(self._model, num_shards=num_shards, timeout=config.timeout) as context: - for rank in range(loaded_config.distributed.world_size): + for rank in range(metadata.config.distributed.world_size): loaded_model = self._model.__class__( - loaded_config.to_copy({("distributed", "rank"): rank}), + metadata.config.to_copy({("distributed", "rank"): rank}), optimizer_state_names=shard_names[1:], verbose=False, ) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 87651dc4..d4533663 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -41,7 +41,7 @@ def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: Ch def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: assert not config.optimizer_state - self._model.config.base_model.compare_architecture(metadata.config.base_model, config.compare_log_fn) + self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning) super().load(config, metadata) @classmethod From 
506fe92917b28fc2d865edf69bad9827c5f92bfa Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 27 Mar 2025 16:04:35 -0400 Subject: [PATCH 06/18] fixes --- fast_llm/config.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 5436a294..222a3ec7 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -266,13 +266,21 @@ def config_class(cls=None): def wrap(cls): Assert.custom(issubclass, cls, Config) - wrapped = _process_config_class(dataclasses.dataclass(cls)) + if hasattr(cls, "__post_init__"): + raise TypeError(f"`__post_init__` should not be implemented for `Config` classes") + + wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True)) wrapped_init = cls.__init__ def __init__(self, **kwargs): + # This is similar to `__post_init__`, but has access to the list of arguments passed to `__init__`. wrapped_init(self, **kwargs) self._explicit_fields = set(kwargs) + self._validated = False + self._setting_implicit_default = False + if _AUTO_VALIDATE: + self.validate() cls.__init__ = __init__ return wrapped @@ -310,17 +318,6 @@ class Config: # without them being automatically added to `_explicit_fields`. _setting_implicit_default: bool = Field(init=False, repr=False) - def __post_init__(self): - """ - Perform validation unless prevented with `NoAutoValidate`. - In general this should not be overridden in derived classes, - and all post-processing should be done in `_validate` - """ - self._validated = False - self._setting_implicit_default = False - if _AUTO_VALIDATE: - self.validate() - def __setattr__(self, key: str, value: typing.Any) -> None: """ Make the class read-only after validation. @@ -983,13 +980,15 @@ def set_nested_dict_value[ raise ValueError("Cannot update an already instantiated config.") elif isinstance(value, Config): raise ValueError("Cannot update a config dict with an already instantiated config.") - elif isinstance(d, dict): + elif isinstance(value, dict): if key in d: Assert.custom(isinstance, d[key], dict) else: d[key] = {} for key_, value_ in value.items(): set_nested_dict_value(d, key_, value_, update_type) + elif isinstance(d[key], dict): + raise ValueError("Cannot replace a dict with a non-dict value.") elif ( isinstance(value, (list, set, tuple)) and any(isinstance(value_, (list, set, tuple, dict, Config)) for value_ in value) From 971d3ef23297f7dd64550facff25f8609c0fb097 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 27 Mar 2025 18:32:07 -0400 Subject: [PATCH 07/18] fixes --- fast_llm/config.py | 14 +++++++++----- fast_llm/engine/huggingface/config.py | 5 +++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 222a3ec7..7cb54919 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -90,12 +90,12 @@ class FieldHint: class FieldVerboseLevel: - nothing = -1 + explicit = None core = 0 optional = 10 performance = 20 debug = 50 - everything = None + everything = 2**31 FieldHintDoc = { @@ -680,7 +680,7 @@ def to_copy[ ](self: T, *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True,) -> T: return self.from_dict(self, *updates, strict=strict) - def to_serialized(self, verbose: int | None = FieldVerboseLevel.core) -> dict[str, typing.Any]: + def to_serialized(self, verbose: int | None = FieldVerboseLevel.explicit) -> dict[str, typing.Any]: return self._to_dict(verbose=verbose, format_=_ConfigDictFormat.nested, serializable=True) def to_logs[ @@ 
-863,8 +863,12 @@ def _handle_renamed_field( def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typing.Callable] = ValueError): # TODO: Check classes? - self_dict = self._to_dict(format_=_ConfigDictFormat.tuple, serializable=True) - other_dict = other._to_dict(format_=_ConfigDictFormat.tuple, serializable=True) + self_dict = self._to_dict( + format_=_ConfigDictFormat.tuple, serializable=True, verbose=FieldVerboseLevel.everything + ) + other_dict = other._to_dict( + format_=_ConfigDictFormat.tuple, serializable=True, verbose=FieldVerboseLevel.everything + ) compare = { key: (self_dict.get(key, MISSING), other_dict.get(key, MISSING)) for key in self_dict.keys() | other_dict.keys() diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index e02abc28..2b240e4b 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -5,6 +5,7 @@ import transformers +from fast_llm.config import FieldVerboseLevel from fast_llm.engine.checkpoint.config import CheckpointLoadMetadataConfig, FastLLMCheckpointFormat from fast_llm.engine.multi_stage.config import FastLLMModelConfig @@ -90,12 +91,12 @@ def __eq__(self, other) -> bool: def to_dict(self) -> dict[str, typing.Any]: out = super().to_dict() - out["fast_llm_config"] = self.fast_llm_config.to_serialized(verbose=None) + out["fast_llm_config"] = self.fast_llm_config.to_serialized(verbose=FieldVerboseLevel.everything) return out def to_diff_dict(self) -> dict[str, typing.Any]: out = super().to_diff_dict() - out["fast_llm_config"] = self.fast_llm_config.to_serialized() + out["fast_llm_config"] = self.fast_llm_config.to_serialized(verbose=FieldVerboseLevel.explicit) return out def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True) -> None: From 6bf20cb2d72faabbf5eb6eea4de4f46180f836f8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 27 Mar 2025 21:26:59 -0400 Subject: [PATCH 08/18] Tests wip --- fast_llm/config.py | 40 ++++-- tests/config/__init__.py | 0 tests/config/common.py | 37 ++++++ tests/config/test_field.py | 176 ++++++++++++++++++++++++++ tests/data/test_dataset_from_file.py | 1 - tests/data/test_prepare_gpt_memmap.py | 1 - tests/test_config.py | 58 +++------ 7 files changed, 258 insertions(+), 55 deletions(-) create mode 100644 tests/config/__init__.py create mode 100644 tests/config/common.py create mode 100644 tests/config/test_field.py diff --git a/fast_llm/config.py b/fast_llm/config.py index 7cb54919..67aa5b7a 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1,4 +1,5 @@ import contextlib +import copy import dataclasses import enum import logging @@ -316,7 +317,7 @@ class Config: _explicit_fields: set[str] = Field(init=False, repr=False) # Used within `_set_implicit_default` to set implicit defaults for fields # without them being automatically added to `_explicit_fields`. - _setting_implicit_default: bool = Field(init=False, repr=False) + _setting_implicit_default: bool | None = Field(init=False, repr=False) def __setattr__(self, key: str, value: typing.Any) -> None: """ @@ -332,12 +333,20 @@ def __setattr__(self, key: str, value: typing.Any) -> None: f"Cannot set attribute `{key}`" f" in configuration class `{get_type_name(type(self))}` after validation." 
) - elif not getattr(self, "_setting_implicit_default", True): - field = self.get_field(key) - if field.init and field._field_type != dataclasses._FIELD_CLASSVAR: - # Adding to explicit field list except within `_set_implicit_default` context - # and during dataclass initialization (`_setting_implicit_default` not yet set). - self._explicit_fields.add(key) + if getattr(self, "_setting_implicit_default", None) is not None: + if self._setting_implicit_default: + if key in self._explicit_fields: + raise RuntimeError( + f"Trying to set an implicit default for field `{key}`," + f"but the field has already been set explicitly." + ) + else: + field = self.get_field(key) + if field.init and field._field_type != dataclasses._FIELD_CLASSVAR: + # Adding to explicit field list except within `_set_implicit_default` context, + # during dataclass initialization (`_setting_implicit_default` not yet set) + # and during automated config validation (`_setting_implicit_default=None`) + self._explicit_fields.add(key) super().__setattr__(key, value) def __delattr__(self, key: str) -> None: @@ -352,8 +361,9 @@ def __delattr__(self, key: str) -> None: super().__delattr__(key) @contextlib.contextmanager - def _set_implicit_default(self): - self._setting_implicit_default = True + def _set_implicit_default(self, _value: bool | int = True): + assert self._setting_implicit_default is False + self._setting_implicit_default = _value yield self._setting_implicit_default = False @@ -383,7 +393,7 @@ def _validate(self) -> None: """ self._check_abstract() errors = [] - with self._set_implicit_default(): + with self._set_implicit_default(None): for name, field in self.fields(): if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa continue @@ -567,7 +577,7 @@ def get_field(cls, name: str) -> Field: def _to_dict( self, - verbose: int | None = None, + verbose: int | None = FieldVerboseLevel.explicit, all_fields: bool = False, format_: _ConfigDictFormat = _ConfigDictFormat.nested, serializable: bool = False, @@ -716,6 +726,8 @@ def from_dict( ) -> typing.Self: if isinstance(default, Config): default = default._to_dict() + else: + default = copy.deepcopy(default) for update in updates: if isinstance(update, Config): update = update._to_dict(format_=_ConfigDictFormat.tuple) @@ -980,7 +992,7 @@ def set_nested_dict_value[ d[key] = value elif update_type == UpdateType.update: # TODO: Improve error messages, ex. for nested cases? 
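        # Illustrative sketch (not part of the patch), matching the behaviour exercised in
        # tests/config/test_update.py later in this series. Starting from d = {"a": {"x": 1}, "b": 2}:
        #   set_nested_dict_value(d, "a", {"y": 3}, UpdateType.override)  # d["a"] becomes {"y": 3}
        #   set_nested_dict_value(d, "a", {"y": 3}, UpdateType.update)    # d["a"] becomes {"x": 1, "y": 3}
        #   set_nested_dict_value(d, "b", 5, UpdateType.update)           # d["b"] becomes 5 (atomic entries are replaced)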
- if isinstance(d[key], Config): + if isinstance(d.get(key), Config): raise ValueError("Cannot update an already instantiated config.") elif isinstance(value, Config): raise ValueError("Cannot update a config dict with an already instantiated config.") @@ -991,13 +1003,13 @@ def set_nested_dict_value[ d[key] = {} for key_, value_ in value.items(): set_nested_dict_value(d, key_, value_, update_type) - elif isinstance(d[key], dict): + elif isinstance(d.get(key), dict): raise ValueError("Cannot replace a dict with a non-dict value.") elif ( isinstance(value, (list, set, tuple)) and any(isinstance(value_, (list, set, tuple, dict, Config)) for value_ in value) ) or ( - isinstance(d[key], (list, set, tuple)) + isinstance(d.get(key), (list, set, tuple)) and any(isinstance(value_, (list, set, tuple, dict, Config)) for value_ in d[key]) ): raise ValueError("Update not supported for nested lists.") diff --git a/tests/config/__init__.py b/tests/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/config/common.py b/tests/config/common.py new file mode 100644 index 00000000..3109175a --- /dev/null +++ b/tests/config/common.py @@ -0,0 +1,37 @@ +import enum +import pathlib + +from fast_llm.config import Config, Field, FieldHint, config_class + + +class TestEnum(str, enum.Enum): + a = "a" + b = "b" + c = "c" + + +@config_class +class TestConfig(Config): + int_field: int = Field(default=0, hint=FieldHint.optional) + bool_field: bool = Field(default=False, hint=FieldHint.optional) + str_field: str = Field(default="", hint=FieldHint.optional) + path_field: pathlib.Path = Field(default="", hint=FieldHint.optional) + float_field: float = Field(default=4.0, hint=FieldHint.optional) + optional_field: str | None = Field(default=None, hint=FieldHint.optional) + union_field: str | int = Field(default=7, hint=FieldHint.optional) + implicit_field: str = Field(default=None, hint=FieldHint.optional) + list_field: list[int] = Field(default_factory=list, hint=FieldHint.optional) + tuple_field: tuple[int, ...] 
= Field(default=(), hint=FieldHint.optional) + # tuple_fixed_length_field: tuple[int, str] = Field(default=(5, "text"), hint=FieldHint.optional) + set_field: set[int] = Field(default_factory=set, hint=FieldHint.optional) + dict_field: dict[int, int] = Field(default_factory=dict, hint=FieldHint.optional) + type_field: type[int] = Field(default=int, hint=FieldHint.optional) + enum_field: TestEnum = Field(default=TestEnum.a, hint=FieldHint.optional) + core_field: int = Field(default=4, hint=FieldHint.core) + complex_field: dict[str | int, list[tuple[str, int]] | None] = Field(default_factory=dict, hint=FieldHint.optional) + + def _validate(self) -> None: + with self._set_implicit_default(): + if self.implicit_field is None: + self.implicit_field = "implicit" + super()._validate() diff --git a/tests/config/test_field.py b/tests/config/test_field.py new file mode 100644 index 00000000..27e7c8b5 --- /dev/null +++ b/tests/config/test_field.py @@ -0,0 +1,176 @@ +import math +import pathlib + +import numpy +import pytest + +from fast_llm.config import FieldVerboseLevel, ValidationError +from fast_llm.utils import Assert +from tests.config.common import TestConfig, TestEnum + + +def check_config(internal_config, *alternate, serialized_config=None): + serialized_config = serialized_config if serialized_config else alternate[0] if alternate else internal_config + for init_config in (internal_config, *alternate): + config = TestConfig.from_dict(init_config) + Assert.eq(config.to_serialized(), serialized_config) + Assert.eq(config._to_dict(), internal_config) + + +def check_invalid_config(config): + with pytest.raises(ValidationError): + TestConfig.from_dict(config) + + +def test_create_and_serialize_config(): + Assert.eq(TestConfig.from_dict({}).to_serialized(), {}) + + +@pytest.mark.parametrize("value", (0, -6, 3, True)) +def test_int_field(value): + check_config({"int_field": value}) + + +@pytest.mark.parametrize("value", (4.0, math.inf, "1", None, [4])) +def test_int_field_invalid(value): + check_invalid_config({"int_field": value}) + + +@pytest.mark.parametrize("value", (True, False)) +def test_bool_field(value): + check_config({"bool_field": value}) + + +@pytest.mark.parametrize("value", (1, "True", None, [True])) +def test_bool_field_invalid(value): + check_invalid_config({"bool_field": value}) + + +@pytest.mark.parametrize("value", ("", "text", "1", TestEnum.a)) +def test_str_field(value): + check_config({"str_field": value}) + + +@pytest.mark.parametrize("value", (1, True, None, ["text"], pathlib.Path("a"))) +def test_str_field_invalid(value): + check_invalid_config({"str_field": value}) + + +@pytest.mark.parametrize("value", (".", "text", "/a/b/c.d")) +def test_path_field(value): + check_config({"path_field": pathlib.Path(value)}, {"path_field": value}) + + +@pytest.mark.parametrize("value", (1, True, None, [pathlib.Path("a")])) +def test_path_field_invalid(value): + check_invalid_config({"path_field": value}) + + +@pytest.mark.parametrize("value", (4.0, math.pi, math.inf, 3, True, numpy.float64(3), math.nan)) +def test_float_field(value): + check_config({"float_field": value}) + + +@pytest.mark.parametrize("value", (None, [4.7], "0.0")) +def test_float_field_invalid(value): + check_invalid_config({"float_field": value}) + + +@pytest.mark.parametrize("value", ("", None, "text")) +def test_optional_field(value): + check_config({"optional_field": value}) + + +@pytest.mark.parametrize("value", (True, 6, [None])) +def test_optional_field_invalid(value): + check_invalid_config({"optional": value}) 
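
# Illustrative sketch (not part of the patch): the tests above exercise the explicit-field tracking
# introduced earlier in this series, where only values passed in explicitly are serialized by default.
# `SketchConfig` is a hypothetical class mirroring the test config defined in tests/config/common.py;
# the serialization method is still named `to_serialized` at this point in the series.
from fast_llm.config import Config, Field, FieldHint, config_class


@config_class
class SketchConfig(Config):
    int_field: int = Field(default=0, hint=FieldHint.optional)
    core_field: int = Field(default=4, hint=FieldHint.core)


config = SketchConfig.from_dict({"int_field": 3})
assert config.to_serialized() == {"int_field": 3}  # implicit defaults are dropped from the serialized dict
assert config.core_field == 4  # but they remain available on the validated instance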
+ + +@pytest.mark.parametrize("value", ("", 0, True, "text", 7)) +def test_union_field(value): + check_config({"union_field": value}) + + +@pytest.mark.parametrize("value", (6.0, [""])) +def test_union_field_invalid(value): + check_invalid_config({"optional": value}) + + +@pytest.mark.parametrize("value", ("implicit", "", "text")) +def test_implicit_field(value): + check_config({"implicit_field": value}) + + +TUPLE_VALUES = ((), (1,), (3, 4, 6), (4, 5, 4)) + + +@pytest.mark.parametrize("value", TUPLE_VALUES) +def test_list_field(value): + check_config( + {"list_field": list(value)}, + {"list_field": value}, + serialized_config={"list_field": list(value)}, + ) + + +@pytest.mark.parametrize("value", TUPLE_VALUES) +def test_tuple_field(value): + check_config( + {"tuple_field": list(value)}, + {"tuple_field": value}, + serialized_config={"tuple_field": list(value)}, + ) + + +@pytest.mark.parametrize("value", TUPLE_VALUES) +def test_set_field(value): + check_config( + {"set_field": list(set(value))}, + {"set_field": set(value)}, + {"set_field": list(value)}, + {"set_field": tuple(value)}, + serialized_config={"set_field": list(set(value))}, + ) + + +# @pytest.mark.parametrize("value", ((0, ""), (5, "text"), (True, "True"))) +# def test_tuple_fixed_length_field(value): +# expected_config = {"tuple_variable_length_field": value} +# Assert.eq(TestConfig.from_dict(expected_config).to_serialized(), expected_config) +# Assert.eq(TestConfig.from_dict({"tuple_variable_length_field": list(value)}).to_serialized(), expected_config) +# Assert.eq(TestConfig.from_dict({"tuple_variable_length_field": set(value)}).to_serialized(), {"tuple_variable_length_field": tuple(set(value))}) + + +@pytest.mark.parametrize("value", ({}, {True: 2}, {1: 2, 3: 4})) +def test_dict_field(value): + check_config({"dict_field": value}) + + +class IntClass(int): + pass + + +@pytest.mark.parametrize("value", (int, bool, IntClass)) +def test_type_field(value): + check_config({"type_field": value}, serialized_config={"type_field": str(value)}) + + +@pytest.mark.parametrize("value", (TestEnum.a, TestEnum.b, TestEnum.c)) +def test_enum_field(value): + check_config({"enum_field": value}, {"enum_field": value.value}) + + +def test_core_field(): + Assert.eq(TestConfig.from_dict({}).to_serialized(verbose=FieldVerboseLevel.core), {"core_field": 4}) + + +@pytest.mark.parametrize( + "value", + ( + {}, + {3: None, "text": [], False: [["", 3], ["a", -7]]}, + {0: [[".", 8]]}, + ), +) +def test_complex_field(value): + check_config({"complex_field": value}) diff --git a/tests/data/test_dataset_from_file.py b/tests/data/test_dataset_from_file.py index 4ac2fcdf..280b3413 100644 --- a/tests/data/test_dataset_from_file.py +++ b/tests/data/test_dataset_from_file.py @@ -8,5 +8,4 @@ def test_dataset_from_file(): get_test_dataset() dataset_config = {"type": "file", "path": str(DATASET_PREFIX.parent.joinpath("fast_llm_config.yaml"))} dataset = get_dataset_config(dataset_config, GPTDatasetFromFileConfig).build() - print("kjhbwiugfberibgiujebi", len(dataset)) compare_indexed_dataset(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES) diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 9a15a051..a6fd3246 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -148,7 +148,6 @@ def test_split_datasets_1(): { "training": { "type": "blended", - "name": "blended", "datasets": [ dataset_config_0.to_serialized(), { diff --git a/tests/test_config.py 
b/tests/test_config.py index 7141812a..5c45db0b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,19 +1,14 @@ import pathlib -import pytest import subprocess import unittest.mock -import yaml +import pytest +import yaml -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerArchitectureConfig, - AddLinearBiasChoices, -) -from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.data.dataset.gpt.config import GPTSamplingConfig from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.config import ValidationError - +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.transformer.config import AddLinearBiasChoices, TransformerArchitectureConfig, TransformerConfig from fast_llm.models.auto import trainer_registry @@ -90,33 +85,6 @@ def test_do_use_flash_attention(): config.do_use_flash_attention(mock_distributed_config) -def test_add_linear_biases_valid_values(): - # Valid boolean values - assert TransformerArchitectureConfig(add_linear_biases=True).add_linear_biases is True - assert TransformerArchitectureConfig(add_linear_biases=False).add_linear_biases is False - - # Valid enum values - assert TransformerArchitectureConfig(add_linear_biases="nowhere").add_linear_biases == AddLinearBiasChoices.nowhere - assert ( - TransformerArchitectureConfig(add_linear_biases="everywhere").add_linear_biases - == AddLinearBiasChoices.everywhere - ) - assert ( - TransformerArchitectureConfig(add_linear_biases="only_attn_qkv").add_linear_biases == AddLinearBiasChoices.only_attn_qkv - ) - - -def test_add_linear_biases_invalid_values(): - with pytest.raises(ValidationError): - TransformerArchitectureConfig(add_linear_biases="invalid_value") - - with pytest.raises(ValidationError): - TransformerArchitectureConfig(add_linear_biases=123) - - with pytest.raises(ValidationError): - TransformerArchitectureConfig(add_linear_biases=None) - - def test_add_mlp_bias(): assert TransformerArchitectureConfig(add_linear_biases=True).add_mlp_bias is True assert TransformerArchitectureConfig(add_linear_biases=False).add_mlp_bias is False @@ -130,7 +98,9 @@ def test_add_attn_qkv_bias(): assert TransformerArchitectureConfig(add_linear_biases=False).add_attn_qkv_bias is False assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_attn_qkv_bias is True assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_attn_qkv_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_qkv_bias is True + assert ( + TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_qkv_bias is True + ) def test_add_attn_dense_bias(): @@ -138,4 +108,14 @@ def test_add_attn_dense_bias(): assert TransformerArchitectureConfig(add_linear_biases=False).add_attn_dense_bias is False assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_attn_dense_bias is True assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_attn_dense_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_dense_bias is False + assert ( + TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_dense_bias + is False + ) + + +@pytest.mark.parametrize("cls", (GPTSamplingConfig,)) +def test_serialize_default_config_updates(cls): + # Config classes 
used as config updates should have a default that serializes to an empty dict + # so no value is incorrectly overridden. + assert cls.from_dict({}).to_serialized() == {} From c13fb19f8763b0aebe83058b375b9732e70721d2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 28 Mar 2025 22:28:57 -0400 Subject: [PATCH 09/18] misc --- fast_llm/config.py | 29 +++++----- fast_llm/data/dataset/gpt/config.py | 4 +- fast_llm/utils.py | 2 +- tests/config/common.py | 6 +- tests/config/test_field.py | 86 ++++++++++++++++++++++------- 5 files changed, 88 insertions(+), 39 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 67aa5b7a..c311abf4 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -468,10 +468,10 @@ def _validate_element(cls, value, type_, name: str): elif not isinstance(type_, type): raise FieldTypeError(f"Not a type.") elif issubclass(type_, Config): - cls._validate_element_type(value, type_, name) + cls._validate_element_type(value, type_, strict=False) value.validate(_is_validating=True) else: - value = cls._validate_simple(value, type_, name) + value = cls._validate_simple(value, type_) return value @classmethod @@ -491,7 +491,7 @@ def _validate_union(cls, value, type_, name: str): @classmethod def _validate_array(cls, value, type_, name: str): origin = type_.__origin__ - cls._validate_element_type(value, (origin, list, tuple), name) + cls._validate_element_type(value, (origin, list, tuple), strict=False) args = getattr(type_, "__args__", [typing.Any, ...] if origin is tuple else [typing.Any]) errors = [] if issubclass(origin, tuple) and not (len(args) == 2 and args[1] is ...): @@ -518,7 +518,7 @@ def _validate_dict(cls, value, type_, name: str): if len(args) > 2: raise FieldTypeError(f"Invalid dict specification `{get_type_name(type_)}` for field `{name}`") args.extend([typing.Any for _ in range(2 - len(args))]) - cls._validate_element_type(value, type_.__origin__, name) + cls._validate_element_type(value, type_.__origin__, strict=False) errors = [] new_value = {} old_keys = {} @@ -534,19 +534,22 @@ def _validate_dict(cls, value, type_, name: str): return new_value @classmethod - def _validate_simple(cls, value, type_, name: str): + def _validate_simple(cls, value, type_, strict: bool = True): if hasattr(type_, "__fast_llm_validator__"): value = type_.__fast_llm_validator__(value) - elif type_ is float and isinstance(value, int): + elif type_ is float and type(value) == int: # Ints are ok too. value = float(value) elif issubclass(type_, enum.Enum) and not isinstance(value, type_) and issubclass(type_, type(value)): # Enum values are ok too. value = type_(value) - elif issubclass(type_, pathlib.PurePath) and isinstance(value, str): - # Str paths are ok too. - value = type_(value) - cls._validate_element_type(value, type_, name) + elif issubclass(type_, pathlib.PurePath): + if isinstance(value, str): + # Str paths are ok too. + value = type_(value) + # Path type may depend on the OS. 
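            # For example, `pathlib.Path("a")` is actually a `PosixPath` or `WindowsPath` instance,
            # so an exact type match against `pathlib.Path` would fail; fall back to an `isinstance` check.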
+ strict = False + cls._validate_element_type(value, type_, strict) return value @classmethod @@ -560,9 +563,9 @@ def _validate_type(cls, value, type_: type | tuple[type, ...], name): raise ValidationError(f"Field value `{value} is not a subclass of `{get_type_name(type_)}`") @classmethod - def _validate_element_type(cls, value, type_: type | tuple[type, ...], name): - if not isinstance(value, type_): - raise ValidationError(f"Unexpected type `{get_type_name(type(value))}`") + def _validate_element_type(cls, value, type_: type | tuple[type, ...], strict: bool = True): + if not (type(value) == type_ if strict else isinstance(value, type_)): + raise ValidationError(f"Unexpected field type: {get_type_name(type(value))} != {get_type_name(type_)}") @classmethod def fields(cls) -> typing.Iterable[tuple[str, Field]]: diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 118b3039..4f15492a 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -484,8 +484,8 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset: "type": "slice", # TODO: this duplicates memmap datasets for each phase. "dataset": {"type": "memmap", "path": prefix}, - "begin": phase_splits[phase_index], - "end": phase_splits[phase_index + 1], + "begin": float(phase_splits[phase_index]), + "end": float(phase_splits[phase_index + 1]), } for prefix in dataset_prefixes ] diff --git a/fast_llm/utils.py b/fast_llm/utils.py index aac6f607..4edd8b98 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -86,7 +86,7 @@ class Assert: @staticmethod def eq(x, *args, msg=None): for arg in args: - assert x == arg, f"{x} != {arg} " + f"| {msg}" if msg else "" + assert x == arg, f"{x} != {arg} " + (f"| {msg}" if msg else "") @staticmethod def is_(x, y): diff --git a/tests/config/common.py b/tests/config/common.py index 3109175a..143d770c 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -4,14 +4,14 @@ from fast_llm.config import Config, Field, FieldHint, config_class -class TestEnum(str, enum.Enum): +class ExampleEnum(enum.StrEnum): a = "a" b = "b" c = "c" @config_class -class TestConfig(Config): +class ExampleConfig(Config): int_field: int = Field(default=0, hint=FieldHint.optional) bool_field: bool = Field(default=False, hint=FieldHint.optional) str_field: str = Field(default="", hint=FieldHint.optional) @@ -26,7 +26,7 @@ class TestConfig(Config): set_field: set[int] = Field(default_factory=set, hint=FieldHint.optional) dict_field: dict[int, int] = Field(default_factory=dict, hint=FieldHint.optional) type_field: type[int] = Field(default=int, hint=FieldHint.optional) - enum_field: TestEnum = Field(default=TestEnum.a, hint=FieldHint.optional) + enum_field: ExampleEnum = Field(default=ExampleEnum.a, hint=FieldHint.optional) core_field: int = Field(default=4, hint=FieldHint.core) complex_field: dict[str | int, list[tuple[str, int]] | None] = Field(default_factory=dict, hint=FieldHint.optional) diff --git a/tests/config/test_field.py b/tests/config/test_field.py index 27e7c8b5..4f39f741 100644 --- a/tests/config/test_field.py +++ b/tests/config/test_field.py @@ -6,32 +6,59 @@ from fast_llm.config import FieldVerboseLevel, ValidationError from fast_llm.utils import Assert -from tests.config.common import TestConfig, TestEnum +from tests.config.common import ExampleConfig, ExampleEnum + + +def _check_equal(config_a, config_b): + # Check for equality of both values and types. 
+ for key in config_a.keys() | config_b.keys(): + assert key in config_a and key in config_b, key + Assert.eq(type(config_a[key]), type(config_b[key])) + if isinstance(config_a[key], (list, tuple, set)): + Assert.eq(len(config_a[key]), len(config_b[key])) + for i in range(len(config_a[key])): + _check_equal({"": config_a[key][i]}, {"": config_b[key][i]}) + elif isinstance(config_a[key], dict): + _check_equal(config_a[key], config_b[key]) + else: + try: + Assert.eq(config_a[key], config_b[key]) + except AssertionError as e: + # Special case for `math.nan` + if config_a[key] is not config_b[key]: + raise e + + +def check_equal(config_a, config_b): + try: + _check_equal(config_a, config_b) + except AssertionError as e: + raise AssertionError(config_a, config_b, *e.args) def check_config(internal_config, *alternate, serialized_config=None): serialized_config = serialized_config if serialized_config else alternate[0] if alternate else internal_config for init_config in (internal_config, *alternate): - config = TestConfig.from_dict(init_config) - Assert.eq(config.to_serialized(), serialized_config) - Assert.eq(config._to_dict(), internal_config) + config = ExampleConfig.from_dict(init_config) + check_equal(config.to_serialized(), serialized_config) + check_equal(config._to_dict(), internal_config) def check_invalid_config(config): with pytest.raises(ValidationError): - TestConfig.from_dict(config) + ExampleConfig.from_dict(config) def test_create_and_serialize_config(): - Assert.eq(TestConfig.from_dict({}).to_serialized(), {}) + Assert.eq(ExampleConfig.from_dict({}).to_serialized(), {}) -@pytest.mark.parametrize("value", (0, -6, 3, True)) +@pytest.mark.parametrize("value", (0, -6, 3)) def test_int_field(value): check_config({"int_field": value}) -@pytest.mark.parametrize("value", (4.0, math.inf, "1", None, [4])) +@pytest.mark.parametrize("value", (4.0, math.inf, "1", None, [4], True)) def test_int_field_invalid(value): check_invalid_config({"int_field": value}) @@ -46,12 +73,12 @@ def test_bool_field_invalid(value): check_invalid_config({"bool_field": value}) -@pytest.mark.parametrize("value", ("", "text", "1", TestEnum.a)) +@pytest.mark.parametrize("value", ("", "text", "1")) def test_str_field(value): - check_config({"str_field": value}) + check_config({"str_field": str(value)}, {"str_field": value}) -@pytest.mark.parametrize("value", (1, True, None, ["text"], pathlib.Path("a"))) +@pytest.mark.parametrize("value", (1, True, None, ["text"], pathlib.Path("a"), ExampleEnum.a)) def test_str_field_invalid(value): check_invalid_config({"str_field": value}) @@ -66,12 +93,14 @@ def test_path_field_invalid(value): check_invalid_config({"path_field": value}) -@pytest.mark.parametrize("value", (4.0, math.pi, math.inf, 3, True, numpy.float64(3), math.nan)) +@pytest.mark.parametrize("value", (4.0, math.pi, math.inf, 3, math.nan)) def test_float_field(value): - check_config({"float_field": value}) + check_config( + {"float_field": float(value)}, {"float_field": value}, serialized_config={"float_field": float(value)} + ) -@pytest.mark.parametrize("value", (None, [4.7], "0.0")) +@pytest.mark.parametrize("value", (None, [4.7], "0.0", True, numpy.float64(3))) def test_float_field_invalid(value): check_invalid_config({"float_field": value}) @@ -86,16 +115,20 @@ def test_optional_field_invalid(value): check_invalid_config({"optional": value}) -@pytest.mark.parametrize("value", ("", 0, True, "text", 7)) +@pytest.mark.parametrize("value", ("", 0, "text", 7)) def test_union_field(value): check_config({"union_field": 
value}) -@pytest.mark.parametrize("value", (6.0, [""])) +@pytest.mark.parametrize("value", (6.0, [""], True)) def test_union_field_invalid(value): check_invalid_config({"optional": value}) +def test_implicit_field_value(): + Assert.eq(ExampleConfig.from_dict({}).implicit_field, "implicit") + + @pytest.mark.parametrize("value", ("implicit", "", "text")) def test_implicit_field(value): check_config({"implicit_field": value}) @@ -141,11 +174,16 @@ def test_set_field(value): # Assert.eq(TestConfig.from_dict({"tuple_variable_length_field": set(value)}).to_serialized(), {"tuple_variable_length_field": tuple(set(value))}) -@pytest.mark.parametrize("value", ({}, {True: 2}, {1: 2, 3: 4})) +@pytest.mark.parametrize("value", ({}, {1: 2, 3: 4})) def test_dict_field(value): check_config({"dict_field": value}) +@pytest.mark.parametrize("value", ({True: 2}, {4: "3"}, {4: {1: 4}}, None, 4, {1}, [5, 7], "text")) +def test_dict_field_invalid(value): + check_invalid_config({"dict_field": value}) + + class IntClass(int): pass @@ -155,22 +193,30 @@ def test_type_field(value): check_config({"type_field": value}, serialized_config={"type_field": str(value)}) -@pytest.mark.parametrize("value", (TestEnum.a, TestEnum.b, TestEnum.c)) +@pytest.mark.parametrize("value", (ExampleEnum.a, ExampleEnum.b, ExampleEnum.c)) def test_enum_field(value): check_config({"enum_field": value}, {"enum_field": value.value}) def test_core_field(): - Assert.eq(TestConfig.from_dict({}).to_serialized(verbose=FieldVerboseLevel.core), {"core_field": 4}) + Assert.eq(ExampleConfig.from_dict({}).to_serialized(verbose=FieldVerboseLevel.core), {"core_field": 4}) @pytest.mark.parametrize( "value", ( {}, - {3: None, "text": [], False: [["", 3], ["a", -7]]}, + {3: None, "text": [], 0: [["", 3], ["a", -7]]}, {0: [[".", 8]]}, ), ) def test_complex_field(value): check_config({"complex_field": value}) + + +@pytest.mark.parametrize( + "value", + ({3: None, "text": [], False: [["", 3], ["a", -7]]},), +) +def test_complex_field_invalid(value): + check_invalid_config({"complex_field": value}) From a20fcecfb870fb076bfa067b8622c6a31aa4d928 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 31 Mar 2025 20:23:42 -0400 Subject: [PATCH 10/18] tests --- tests/config/common.py | 21 ++++++ tests/config/test_field.py | 133 +++++++++++++++++++++++++++---------- 2 files changed, 120 insertions(+), 34 deletions(-) diff --git a/tests/config/common.py b/tests/config/common.py index 143d770c..f9449507 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -35,3 +35,24 @@ def _validate(self) -> None: if self.implicit_field is None: self.implicit_field = "implicit" super()._validate() + + +@config_class +class ExampleVerboseConfig(Config): + # These fields will have non-empty default serialized values. + list_default_field: list[int] = Field(default_factory=lambda: [0], hint=FieldHint.optional) + tuple_default_field: tuple[int, ...] 
= Field(default=(0, 1), hint=FieldHint.optional) + tuple_fixed_length_field: tuple[int, str] = Field(default=(5, "text"), hint=FieldHint.optional) + set_default_field: set[int] = Field(default_factory=lambda: {0, 1, 2}, hint=FieldHint.optional) + dict_default_field: dict[str, int] = Field(default_factory=lambda: {"0": 0, "1": 1}, hint=FieldHint.optional) + explicit_field: str = Field(default=None, hint=FieldHint.optional) + + def _validate(self) -> None: + if self.explicit_field is None: + self.explicit_field = "explicit" + super()._validate() + + +@config_class +class ExampleNestedConfig(ExampleConfig): + nested_field: ExampleConfig = Field(default_factory=ExampleConfig, hint=FieldHint.core) diff --git a/tests/config/test_field.py b/tests/config/test_field.py index 4f39f741..bed9c181 100644 --- a/tests/config/test_field.py +++ b/tests/config/test_field.py @@ -4,29 +4,29 @@ import numpy import pytest -from fast_llm.config import FieldVerboseLevel, ValidationError +from fast_llm.config import Config, FieldVerboseLevel, ValidationError from fast_llm.utils import Assert -from tests.config.common import ExampleConfig, ExampleEnum +from tests.config.common import ExampleConfig, ExampleEnum, ExampleVerboseConfig def _check_equal(config_a, config_b): # Check for equality of both values and types. - for key in config_a.keys() | config_b.keys(): - assert key in config_a and key in config_b, key - Assert.eq(type(config_a[key]), type(config_b[key])) - if isinstance(config_a[key], (list, tuple, set)): - Assert.eq(len(config_a[key]), len(config_b[key])) - for i in range(len(config_a[key])): - _check_equal({"": config_a[key][i]}, {"": config_b[key][i]}) - elif isinstance(config_a[key], dict): + Assert.eq(type(config_a), type(config_b)) + if isinstance(config_a, dict): + for key in config_a.keys() | config_b.keys(): + assert key in config_a and key in config_b, key _check_equal(config_a[key], config_b[key]) - else: - try: - Assert.eq(config_a[key], config_b[key]) - except AssertionError as e: - # Special case for `math.nan` - if config_a[key] is not config_b[key]: - raise e + elif isinstance(config_a, (list, tuple, set)): + Assert.eq(len(config_a), len(config_b)) + for i in range(len(config_a)): + _check_equal(config_a[i], config_b[i]) + else: + try: + Assert.eq(config_a, config_b) + except AssertionError: + # Special case for `math.nan` + if config_a is not config_b: + raise def check_equal(config_a, config_b): @@ -36,17 +36,30 @@ def check_equal(config_a, config_b): raise AssertionError(config_a, config_b, *e.args) -def check_config(internal_config, *alternate, serialized_config=None): +def check_config( + internal_config, + *alternate, + serialized_config=None, + cls: type[Config] = ExampleConfig, + fields: list[str] | None = None, +): serialized_config = serialized_config if serialized_config else alternate[0] if alternate else internal_config for init_config in (internal_config, *alternate): - config = ExampleConfig.from_dict(init_config) - check_equal(config.to_serialized(), serialized_config) - check_equal(config._to_dict(), internal_config) + config = cls.from_dict(init_config) + serialized_config_ = config.to_serialized() + internal_config_ = config._to_dict() + if fields is None: + check_equal(serialized_config_, serialized_config) + check_equal(internal_config_, internal_config) + else: + for field in fields: + check_equal(serialized_config_[field], serialized_config[field]) + check_equal(internal_config_[field], internal_config[field]) -def check_invalid_config(config): +def 
check_invalid_config(config, cls: type[Config] = ExampleConfig): with pytest.raises(ValidationError): - ExampleConfig.from_dict(config) + cls.from_dict(config) def test_create_and_serialize_config(): @@ -134,10 +147,11 @@ def test_implicit_field(value): check_config({"implicit_field": value}) -TUPLE_VALUES = ((), (1,), (3, 4, 6), (4, 5, 4)) +ARRAY_VALUES = ((), (1,), (3, 4, 6), (4, 5, 4)) +ARRAY_VALUES_INVALID = (6.0, {}, True, "text") -@pytest.mark.parametrize("value", TUPLE_VALUES) +@pytest.mark.parametrize("value", ARRAY_VALUES) def test_list_field(value): check_config( {"list_field": list(value)}, @@ -146,7 +160,12 @@ def test_list_field(value): ) -@pytest.mark.parametrize("value", TUPLE_VALUES) +@pytest.mark.parametrize("value", ARRAY_VALUES_INVALID) +def test_list_field_invalid(value): + check_invalid_config({"list_field": value}) + + +@pytest.mark.parametrize("value", ARRAY_VALUES) def test_tuple_field(value): check_config( {"tuple_field": list(value)}, @@ -155,7 +174,12 @@ def test_tuple_field(value): ) -@pytest.mark.parametrize("value", TUPLE_VALUES) +@pytest.mark.parametrize("value", ARRAY_VALUES_INVALID) +def test_tuple_field_invalid(value): + check_invalid_config({"tuple_field": value}) + + +@pytest.mark.parametrize("value", ARRAY_VALUES) def test_set_field(value): check_config( {"set_field": list(set(value))}, @@ -166,12 +190,9 @@ def test_set_field(value): ) -# @pytest.mark.parametrize("value", ((0, ""), (5, "text"), (True, "True"))) -# def test_tuple_fixed_length_field(value): -# expected_config = {"tuple_variable_length_field": value} -# Assert.eq(TestConfig.from_dict(expected_config).to_serialized(), expected_config) -# Assert.eq(TestConfig.from_dict({"tuple_variable_length_field": list(value)}).to_serialized(), expected_config) -# Assert.eq(TestConfig.from_dict({"tuple_variable_length_field": set(value)}).to_serialized(), {"tuple_variable_length_field": tuple(set(value))}) +@pytest.mark.parametrize("value", ARRAY_VALUES_INVALID) +def test_tuple_field_invalid(value): + check_invalid_config({"set_field": value}) @pytest.mark.parametrize("value", ({}, {1: 2, 3: 4})) @@ -193,9 +214,19 @@ def test_type_field(value): check_config({"type_field": value}, serialized_config={"type_field": str(value)}) +@pytest.mark.parametrize("value", (5, None, [], "text")) +def test_type_field_invalid(value): + check_invalid_config({"type_field": value}) + + @pytest.mark.parametrize("value", (ExampleEnum.a, ExampleEnum.b, ExampleEnum.c)) def test_enum_field(value): - check_config({"enum_field": value}, {"enum_field": value.value}) + check_config({"enum_field": value}, {"enum_field": str(value)}) + + +@pytest.mark.parametrize("value", (5, None, [], "text")) +def test_enum_field_invalid(value): + check_invalid_config({"type_field": value}) def test_core_field(): @@ -220,3 +251,37 @@ def test_complex_field(value): ) def test_complex_field_invalid(value): check_invalid_config({"complex_field": value}) + + +def test_verbose_config_default(): + default_values = { + "list_default_field": [0], + "tuple_default_field": [0, 1], + "tuple_fixed_length_field": [5, "text"], + "set_default_field": [0, 1, 2], + "dict_default_field": {"0": 0, "1": 1}, + "explicit_field": "explicit", + } + config = ExampleVerboseConfig.from_dict({}) + check_equal(config.to_serialized(), default_values) + check_equal(config._to_dict(), default_values) + + +@pytest.mark.parametrize("value", ((0, ""), (5, "text"), (7, "True"))) +def test_tuple_fixed_length_field(value): + check_config( + {"tuple_fixed_length_field": list(value)}, + 
{"tuple_fixed_length_field": value}, + serialized_config={"tuple_fixed_length_field": list(value)}, + cls=ExampleVerboseConfig, + fields=["tuple_fixed_length_field"], + ) + + +@pytest.mark.parametrize("value", ((), (5,), ("", 0), ("0", "True"), (0, "", "text"))) +def test_tuple_fixed_length_field_invalid(value): + check_invalid_config({"tuple_fixed_length_field": value}, cls=ExampleVerboseConfig) + + +# TODO: Test other fields with defaults. +# TODO: Test nested fields. From 9af372df69e71e7a818bb52f4e7d26706d42e19c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 1 Apr 2025 18:07:02 -0400 Subject: [PATCH 11/18] Tests, fixes, remove tuple format --- fast_llm/config.py | 111 ++++++------------ fast_llm/data/dataset/gpt/config.py | 2 +- fast_llm/data/dataset/gpt/sampled.py | 2 +- .../data/preparator/gpt_memmap/prepare.py | 2 +- fast_llm/engine/checkpoint/distributed.py | 8 +- fast_llm/engine/checkpoint/external.py | 2 +- fast_llm/engine/checkpoint/huggingface.py | 2 +- fast_llm/engine/checkpoint/state_dict.py | 2 +- fast_llm/engine/config_utils/run.py | 4 +- fast_llm/engine/huggingface/config.py | 4 +- fast_llm/engine/training/wandb.py | 2 +- fast_llm/utils.py | 32 +++++ tests/config/common.py | 31 ++++- tests/config/test_field.py | 67 ++--------- tests/config/test_update.py | 52 ++++++++ tests/data/test_prepare_gpt_memmap.py | 20 ++-- tests/test_config.py | 2 +- tools/moe_add_experts.py | 2 +- 18 files changed, 185 insertions(+), 162 deletions(-) create mode 100644 tests/config/test_update.py diff --git a/fast_llm/config.py b/fast_llm/config.py index c311abf4..0abd9073 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -11,7 +11,7 @@ import yaml -from fast_llm.utils import Assert, Tag, get_type_name, header, log +from fast_llm.utils import Assert, Tag, compare_nested, get_type_name, header, log logger = logging.getLogger(__name__) @@ -38,13 +38,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): _AUTO_VALIDATE = self._old_value -class _ConfigDictFormat(str, enum.Enum): - # TODO v0.3: delete class - flat = "flat" - nested = "nested" - tuple = "tuple" - - class UpdateType(str, enum.Enum): # Override entries no matter what they contais. override = "override" @@ -578,33 +571,26 @@ def fields(cls) -> typing.Iterable[tuple[str, Field]]: def get_field(cls, name: str) -> Field: return cls.__dataclass_fields__[name] # noqa - def _to_dict( + def to_dict( self, verbose: int | None = FieldVerboseLevel.explicit, all_fields: bool = False, - format_: _ConfigDictFormat = _ConfigDictFormat.nested, - serializable: bool = False, + serialized: bool = True, ) -> dict[str, typing.Any]: """ Serialize the config to a dict that can (generally) be used to reconstruct an identical `Config`. - When not flat, the dict includes a `__class__` entry which allows support for derived classes. Args: all_fields: Include the derived fields, with `init=False`. - format_: The config format used to represent nested configs. Options: - * `ConfigDictFormat.nested`: Preserve the nested config structure by returning nested dicts. - Also save a `__class__` entry to support derived classes. Standard format. - * `ConfigDictFormat.tuple`: Preserve the nested config structure by returning tuples of keys. - Used for config updates. - serializable: Ensure the dict is serializable to json or yaml. Information may be lost. + serialized: Ensure the dict is serializable to json or yaml. Information may be lost. 
""" arg_dict = {} for name, field in self.fields(): value = getattr(self, name, MISSING) - self._add_field_to_args(arg_dict, name, field, value, verbose, all_fields, format_, serializable) + self._add_field_to_args(arg_dict, name, field, value, verbose, all_fields, serialized) if hasattr(self, "_unknown_fields"): for name, value in self._unknown_fields.items(): - self._add_field_to_args(arg_dict, f"!!! {name}", None, value, None, all_fields, format_, serializable) + self._add_field_to_args(arg_dict, f"!!! {name}", None, value, None, all_fields, serialized) return arg_dict @@ -616,13 +602,12 @@ def _add_field_to_args( value: typing.Any, verbose: int | None = None, all_fields: bool = False, - format_: _ConfigDictFormat = _ConfigDictFormat.nested, - serializable: bool = False, + serializable: bool = True, ) -> None: if ( field is not None and (not field.init or field._field_type == dataclasses._FIELD_CLASSVAR) - and not (all_fields) + and not all_fields ): # Exclude class variables and derived fields unless requested explicitly. return @@ -632,48 +617,36 @@ def _add_field_to_args( or (verbose is not None and verbose >= FieldHintImportance[field.hint]) ) if isinstance(value, Config): - field_value = value._to_dict( + field_value = value.to_dict( verbose=verbose, all_fields=all_fields, - format_=format_, - serializable=serializable, + serialized=serializable, ) # Empty configs can safely be trimmed. explicit_field = all_fields elif isinstance(value, (list, tuple, set)): - field_value = {} if format_ == _ConfigDictFormat.tuple else [] + field_value = [] for i, list_value in enumerate(value): - self._add_field_to_args( - field_value, str(i), None, list_value, verbose, all_fields, format_, serializable - ) + self._add_field_to_args(field_value, str(i), None, list_value, verbose, all_fields, serializable) elif isinstance(value, dict): field_value = {} for dict_name, dict_value in value.items(): - self._add_field_to_args( - field_value, dict_name, None, dict_value, verbose, all_fields, format_, serializable - ) + self._add_field_to_args(field_value, dict_name, None, dict_value, verbose, all_fields, serializable) elif explicit_field: field_value = value if serializable: field_value = self._serialize_value(value) - if format_ == _ConfigDictFormat.tuple: - field_value = {(): field_value} else: # Exclude unimportant (implicit or explicit) default values. 
return if serializable: name = self._serialize_value(name) - if format_ == _ConfigDictFormat.tuple: - args.update({(name,) + name_: value_ for name_, value_ in field_value.items()}) - elif format_ == _ConfigDictFormat.nested: - if not isinstance(field_value, (dict, list)) or len(field_value) > 0 or explicit_field or all_fields: - if isinstance(args, dict): - args[name] = field_value - else: - args.append(field_value) - else: - raise NotImplementedError(format_) + if not isinstance(field_value, (dict, list)) or len(field_value) > 0 or explicit_field or all_fields: + if isinstance(args, dict): + args[name] = field_value + else: + args.append(field_value) @classmethod def _serialize_value(cls, value: typing.Any) -> int | float | bool | str | None: @@ -689,12 +662,14 @@ def _serialize_value(cls, value: typing.Any) -> int | float | bool | str | None: return value def to_copy[ - T - ](self: T, *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True,) -> T: - return self.from_dict(self, *updates, strict=strict) - - def to_serialized(self, verbose: int | None = FieldVerboseLevel.explicit) -> dict[str, typing.Any]: - return self._to_dict(verbose=verbose, format_=_ConfigDictFormat.nested, serializable=True) + T: Config, + ]( + self: T, + *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], + strict: bool = True, + update_type: UpdateType = UpdateType.override, + ) -> T: + return self.from_dict(self, *updates, strict=strict, update_type=update_type) def to_logs[ T @@ -706,7 +681,7 @@ def to_logs[ width: int = 80, fill_char: str = "-", ) -> T: - arg_dict = self.to_serialized(verbose=verbose) + arg_dict = self.to_dict(verbose=verbose) if title is None: title = self._get_class_name() return log_fn( @@ -728,12 +703,14 @@ def from_dict( update_type: UpdateType = UpdateType.override, ) -> typing.Self: if isinstance(default, Config): - default = default._to_dict() + default = default.to_dict(serialized=False) else: default = copy.deepcopy(default) for update in updates: if isinstance(update, Config): - update = update._to_dict(format_=_ConfigDictFormat.tuple) + update = update.to_dict(serialized=False) + else: + update = copy.deepcopy(update) for keys, value in update.items(): set_nested_dict_value(default, keys, value, update_type) @@ -878,27 +855,15 @@ def _handle_renamed_field( def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typing.Callable] = ValueError): # TODO: Check classes? 
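        # Illustrative usage (not part of the patch), assuming the default behaviour of the `log` helper:
        #   config_a.compare(config_b)               # raises ValueError (the default log_fn) listing the differing keys
        #   config_a.compare(config_b, log_fn=bool)  # returns a truthy value when the configs differ,
        #                                            # as used by the distributed checkpoint loader below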
- self_dict = self._to_dict( - format_=_ConfigDictFormat.tuple, serializable=True, verbose=FieldVerboseLevel.everything - ) - other_dict = other._to_dict( - format_=_ConfigDictFormat.tuple, serializable=True, verbose=FieldVerboseLevel.everything - ) - compare = { - key: (self_dict.get(key, MISSING), other_dict.get(key, MISSING)) - for key in self_dict.keys() | other_dict.keys() - } - diff = { - key: (self_value, other_value) - for key, (self_value, other_value) in compare.items() - if self_value != other_value - } - if diff: - log( + self_dict = self.to_dict(verbose=FieldVerboseLevel.everything) + other_dict = other.to_dict(verbose=FieldVerboseLevel.everything) + errors = compare_nested(self_dict, other_dict) + if errors: + return log( f"Config diff:\n " + "\n ".join( f"{'.'.join(key)}`: `{self_value}` != `{other_value}`" - for key, (self_value, other_value) in diff.items() + for key, (self_value, other_value) in errors.items() ), log_fn=log_fn, ) @@ -1005,7 +970,7 @@ def set_nested_dict_value[ else: d[key] = {} for key_, value_ in value.items(): - set_nested_dict_value(d, key_, value_, update_type) + set_nested_dict_value(d[key], key_, value_, update_type) elif isinstance(d.get(key), dict): raise ValueError("Cannot replace a dict with a non-dict value.") elif ( diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 9c5e6f13..0958f118 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -505,7 +505,7 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset: dataset_config = { "type": "fim", "dataset": dataset_config, - **self.fim.to_serialized(), + **self.fim.to_dict(), } # Legacy sampling config dataset_config = { diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f5d23031..25529ef0 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -154,7 +154,7 @@ def _sample(self) -> None: "num_samples": self._num_samples, "unshuffled_epochs": unshuffled_epochs, "sequence_length": self._sequence_length, - "config": self._config.to_serialized(), + "config": self._config.to_dict(), } self._load_yaml_data(yaml_data) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index b3dae1df..23e497bf 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -281,7 +281,7 @@ def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDa def _save_dataset_config(cls, dataset_config: GPTIndexedDatasetConfig, output_path: pathlib.Path) -> None: logger.info(f"Saving config to {output_path}") yaml.safe_dump( - dataset_config.to_serialized(), + dataset_config.to_dict(), output_path.open("w"), ) diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index 503839f0..f27fff5d 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -32,7 +32,7 @@ def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetada return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").open("r"))) def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: - serialized_metadata = metadata.to_serialized() + serialized_metadata = metadata.to_dict() if self._model.config.distributed.rank == 0: yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w")) 
safetensors.torch.save_file( @@ -50,10 +50,8 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No Assert.leq(set(self.get_shard_names(config)), set(metadata.shards)) Assert.eq(metadata.shards[: len(shard_names)], list(shard_names)) - same_format = ( - loaded_config.to_serialized(verbose=None) == self._model.config.to_serialized(verbose=None) - and config.optimizer_state - ) + # Using `log_fn=bool` sets the output to true if the error list is non-empty. + same_format = config.optimizer_state and not loaded_config.compare(self._model.config, log_fn=bool) # Make sure all nodes agree on which loading scheme to use. # Note: they may not agree before the broadcast because of the rank comparison, but that's ok. same_format = broadcast_scalar(same_format, torch.uint8, self._model.distributed.world_group) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 98cab927..654ba21f 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -232,7 +232,7 @@ def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetada fast_llm_version=__version__, model=cls._model_class, format=config.format, - config=cls._model_class.from_dict({"base_model": imported_model_config.to_serialized()}), + config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}), shards=["weights"], ) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 87651dc4..f335015a 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -34,7 +34,7 @@ def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: Ch huggingface_config = self._export_config(self._model.config.base_model) self._save_config(config.path, huggingface_config) return { - "fast_llm_metadata": metadata.to_serialized(), + "fast_llm_metadata": metadata.to_dict(), "model_config": huggingface_config, "format": "pt", } diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 5288d49f..71c83ece 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -71,7 +71,7 @@ def _save_serialized_metadata(self, config: CheckpointSaveMetadataConfig, metada def _serialize_metadata( self, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata ) -> dict[str, typing.Any]: - return metadata.to_serialized() + return metadata.to_dict() def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: with SafeLoad(self._model, shard_names=self.get_shard_names(config), timeout=config.timeout) as context: diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 0ac46339..d6377409 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -147,8 +147,8 @@ def __init__( self._is_pipeline_parallel_main_rank = ( self._distributed_config.data_rank == 0 and self._distributed_config.tensor_rank == 0 ) - config_dict = config.to_serialized() - config_dict_verbose = config.to_serialized(verbose=FieldVerboseLevel.performance) + config_dict = config.to_dict() + config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.performance) if self._config.experiment_dir is not None: self._experiment_directory = self._config.experiment_dir.resolve() diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index 
2b240e4b..d4b46bcc 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -91,12 +91,12 @@ def __eq__(self, other) -> bool: def to_dict(self) -> dict[str, typing.Any]: out = super().to_dict() - out["fast_llm_config"] = self.fast_llm_config.to_serialized(verbose=FieldVerboseLevel.everything) + out["fast_llm_config"] = self.fast_llm_config.to_dict(verbose=FieldVerboseLevel.everything) return out def to_diff_dict(self) -> dict[str, typing.Any]: out = super().to_diff_dict() - out["fast_llm_config"] = self.fast_llm_config.to_serialized(verbose=FieldVerboseLevel.explicit) + out["fast_llm_config"] = self.fast_llm_config.to_dict(verbose=FieldVerboseLevel.explicit) return out def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True) -> None: diff --git a/fast_llm/engine/training/wandb.py b/fast_llm/engine/training/wandb.py index e3d421a3..185b89c2 100644 --- a/fast_llm/engine/training/wandb.py +++ b/fast_llm/engine/training/wandb.py @@ -40,7 +40,7 @@ def __init__(self, config: WandbConfig, run: Run, experiment_config: Config): if wandb_path is not None: yaml.safe_dump(wandb_config, wandb_path.open("w")) # TODO: Does wandb work with nested configs? - self._wandb = wandb.init(config=experiment_config.to_serialized(), **wandb_config) + self._wandb = wandb.init(config=experiment_config.to_dict(), **wandb_config) else: self._wandb = None diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 4edd8b98..da083eef 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -289,3 +289,35 @@ def new_decorator(*args, **kwargs): return out return new_decorator + + +def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple = ()): + if errors is None: + errors = [] + # Check for equality of both values and types. 
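    # Illustrative example: compare_nested({"a": 1, "b": [1, 2]}, {"a": 2, "b": [1, 2]})
    # returns a single-entry error list for key `a`, and an empty list when the structures match.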
+ if type(config_a) != type(config_b): + errors.append(f"Type mismatch for key `{".".join(prefix)}`: {type(config_a)} != {type(config_b)}") + if isinstance(config_a, dict): + for key in config_a.keys() | config_b.keys(): + key_ = prefix + (key,) + if key not in config_a: + errors.append(f"Key `{".".join(key_)}` missing in lhs.") + elif key not in config_b: + errors.append(f"Key `{".".join(key_)}` missing in rhs.") + else: + compare_nested(config_a[key], config_b[key], errors, key_) + elif isinstance(config_a, (list, tuple, set)): + if len(config_a) != len(config_b): + errors.append(f"Length mismatch for key `{".".join(prefix)}`: {len(config_a)} != {len(config_b)}.") + else: + for i in range(len(config_a)): + compare_nested(config_a[i], config_b[i], errors, prefix + (str(i),)) + elif config_a != config_b and config_a is not config_b: + # `is not` needed for special cases like `math.nan` + errors.append(f"Different value for key `{".".join(prefix)}`: {config_a} != {config_b}.") + return errors + + +def check_equal_nested(config_a, config_b): + if errors := compare_nested(config_a, config_b): + raise ValueError("\n".join(errors)) diff --git a/tests/config/common.py b/tests/config/common.py index f9449507..a2657926 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -1,7 +1,10 @@ import enum import pathlib -from fast_llm.config import Config, Field, FieldHint, config_class +import pytest + +from fast_llm.config import Config, Field, FieldHint, ValidationError, config_class +from fast_llm.utils import check_equal_nested class ExampleEnum(enum.StrEnum): @@ -56,3 +59,29 @@ def _validate(self) -> None: @config_class class ExampleNestedConfig(ExampleConfig): nested_field: ExampleConfig = Field(default_factory=ExampleConfig, hint=FieldHint.core) + + +def check_config( + internal_config, + *alternate, + serialized_config=None, + cls: type[Config] = ExampleConfig, + fields: list[str] | None = None, +): + serialized_config = serialized_config if serialized_config else alternate[0] if alternate else internal_config + for init_config in (internal_config, *alternate): + config = cls.from_dict(init_config) + serialized_config_ = config.to_dict() + internal_config_ = config.to_dict(serialized=False) + if fields is None: + check_equal_nested(serialized_config_, serialized_config) + check_equal_nested(internal_config_, internal_config) + else: + for field in fields: + check_equal_nested(serialized_config_[field], serialized_config[field]) + check_equal_nested(internal_config_[field], internal_config[field]) + + +def check_invalid_config(config, cls: type[Config] = ExampleConfig): + with pytest.raises(ValidationError): + cls.from_dict(config) diff --git a/tests/config/test_field.py b/tests/config/test_field.py index bed9c181..91b5c0d8 100644 --- a/tests/config/test_field.py +++ b/tests/config/test_field.py @@ -4,66 +4,13 @@ import numpy import pytest -from fast_llm.config import Config, FieldVerboseLevel, ValidationError -from fast_llm.utils import Assert -from tests.config.common import ExampleConfig, ExampleEnum, ExampleVerboseConfig - - -def _check_equal(config_a, config_b): - # Check for equality of both values and types. 
- Assert.eq(type(config_a), type(config_b)) - if isinstance(config_a, dict): - for key in config_a.keys() | config_b.keys(): - assert key in config_a and key in config_b, key - _check_equal(config_a[key], config_b[key]) - elif isinstance(config_a, (list, tuple, set)): - Assert.eq(len(config_a), len(config_b)) - for i in range(len(config_a)): - _check_equal(config_a[i], config_b[i]) - else: - try: - Assert.eq(config_a, config_b) - except AssertionError: - # Special case for `math.nan` - if config_a is not config_b: - raise - - -def check_equal(config_a, config_b): - try: - _check_equal(config_a, config_b) - except AssertionError as e: - raise AssertionError(config_a, config_b, *e.args) - - -def check_config( - internal_config, - *alternate, - serialized_config=None, - cls: type[Config] = ExampleConfig, - fields: list[str] | None = None, -): - serialized_config = serialized_config if serialized_config else alternate[0] if alternate else internal_config - for init_config in (internal_config, *alternate): - config = cls.from_dict(init_config) - serialized_config_ = config.to_serialized() - internal_config_ = config._to_dict() - if fields is None: - check_equal(serialized_config_, serialized_config) - check_equal(internal_config_, internal_config) - else: - for field in fields: - check_equal(serialized_config_[field], serialized_config[field]) - check_equal(internal_config_[field], internal_config[field]) - - -def check_invalid_config(config, cls: type[Config] = ExampleConfig): - with pytest.raises(ValidationError): - cls.from_dict(config) +from fast_llm.config import FieldVerboseLevel +from fast_llm.utils import Assert, check_equal_nested +from tests.config.common import ExampleConfig, ExampleEnum, ExampleVerboseConfig, check_config, check_invalid_config def test_create_and_serialize_config(): - Assert.eq(ExampleConfig.from_dict({}).to_serialized(), {}) + Assert.eq(ExampleConfig.from_dict({}).to_dict(), {}) @pytest.mark.parametrize("value", (0, -6, 3)) @@ -230,7 +177,7 @@ def test_enum_field_invalid(value): def test_core_field(): - Assert.eq(ExampleConfig.from_dict({}).to_serialized(verbose=FieldVerboseLevel.core), {"core_field": 4}) + Assert.eq(ExampleConfig.from_dict({}).to_dict(verbose=FieldVerboseLevel.core), {"core_field": 4}) @pytest.mark.parametrize( @@ -263,8 +210,8 @@ def test_verbose_config_default(): "explicit_field": "explicit", } config = ExampleVerboseConfig.from_dict({}) - check_equal(config.to_serialized(), default_values) - check_equal(config._to_dict(), default_values) + check_equal_nested(config.to_dict(), default_values) + check_equal_nested(config.to_dict(serialized=False), default_values) @pytest.mark.parametrize("value", ((0, ""), (5, "text"), (7, "True"))) diff --git a/tests/config/test_update.py b/tests/config/test_update.py new file mode 100644 index 00000000..ad534d49 --- /dev/null +++ b/tests/config/test_update.py @@ -0,0 +1,52 @@ +import pytest + +from fast_llm.config import UpdateType +from fast_llm.utils import check_equal_nested +from tests.config.common import ExampleNestedConfig + +TEST_CONFIGS = ( + ( + # Empty config + {}, + {}, + {}, + None, + ), + ( + # Update unset field; don't update set field; update + {"int_field": 4, "str_field": "text"}, + {"float_field": 3.0, "str_field": ""}, + {"int_field": 4, "float_field": 3.0, "str_field": ""}, + None, + ), + ( + # Update/override nested field. 
+ {"nested_field": {"int_field": 4, "str_field": "text"}}, + {"nested_field": {"float_field": 3.0, "str_field": ""}}, + {"nested_field": {"int_field": 4, "float_field": 3.0, "str_field": ""}}, + {"nested_field": {"float_field": 3.0, "str_field": ""}}, + ), + # TODO: Add more complex cases +) + + +@pytest.mark.parametrize(("base", "update", "updated", "overridden"), TEST_CONFIGS) +def test_update(base, update, updated, overridden) -> None: + if overridden is None: + overridden = updated + check_equal_nested(ExampleNestedConfig.from_dict(base, update, update_type=UpdateType.update).to_dict(), updated) + check_equal_nested( + ExampleNestedConfig.from_dict(base) + .to_copy(ExampleNestedConfig.from_dict(update), update_type=UpdateType.update) + .to_dict(), + updated, + ) + check_equal_nested( + ExampleNestedConfig.from_dict(base, update, update_type=UpdateType.override).to_dict(), overridden + ) + check_equal_nested( + ExampleNestedConfig.from_dict(base) + .to_copy(ExampleNestedConfig.from_dict(update), update_type=UpdateType.override) + .to_dict(), + overridden, + ) diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index a6fd3246..9dd7975c 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -95,20 +95,20 @@ def test_split_dataset(): {"training": 3, "validation": 1}, pathlib.Path("."), ) - config = {key: value.to_serialized() for key, value in config.items()} + config = {key: value.to_dict() for key, value in config.items()} Assert.eq( config, { "training": { "type": "slice", - "dataset": dataset_config_0.to_serialized(), + "dataset": dataset_config_0.to_dict(), "begin": 0, "end": 0.75, }, "validation": { "type": "slice", - "dataset": dataset_config_0.to_serialized(), + "dataset": dataset_config_0.to_dict(), "begin": 0.75, "end": 1, }, @@ -124,13 +124,13 @@ def test_split_datasets_0(): {"training": 1, "validation": 1}, pathlib.Path("."), ) - config = {key: value.to_serialized() for key, value in config.items()} + config = {key: value.to_dict() for key, value in config.items()} Assert.eq( config, { - "training": dataset_config_0.to_serialized(), - "validation": dataset_config_1.to_serialized(), + "training": dataset_config_0.to_dict(), + "validation": dataset_config_1.to_dict(), }, ) @@ -141,7 +141,7 @@ def test_split_datasets_1(): config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0, dataset_config_1], {"training": 3, "validation": 1}, pathlib.Path(".") ) - config = {key: value.to_serialized() for key, value in config.items()} + config = {key: value.to_dict() for key, value in config.items()} Assert.eq( config, @@ -149,10 +149,10 @@ def test_split_datasets_1(): "training": { "type": "blended", "datasets": [ - dataset_config_0.to_serialized(), + dataset_config_0.to_dict(), { "type": "slice", - "dataset": dataset_config_1.to_serialized(), + "dataset": dataset_config_1.to_dict(), "begin": 0, "end": 0.5, }, @@ -161,7 +161,7 @@ def test_split_datasets_1(): }, "validation": { "type": "slice", - "dataset": dataset_config_1.to_serialized(), + "dataset": dataset_config_1.to_dict(), "begin": 0.5, "end": 1, }, diff --git a/tests/test_config.py b/tests/test_config.py index 5c45db0b..79437e9d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -118,4 +118,4 @@ def test_add_attn_dense_bias(): def test_serialize_default_config_updates(cls): # Config classes used as config updates should have a default that serializes to an empty dict # so no value is incorrectly overridden. 
- assert cls.from_dict({}).to_serialized() == {} + assert cls.from_dict({}).to_dict() == {} diff --git a/tools/moe_add_experts.py b/tools/moe_add_experts.py index 975ece86..69311017 100644 --- a/tools/moe_add_experts.py +++ b/tools/moe_add_experts.py @@ -93,7 +93,7 @@ def run(self): model.save_pretrained(self.output_dir, state_dict=state_dict) # Save surgery config as yaml - yaml.safe_dump(self.to_serialized(), (self.output_dir / "surgery_config.yaml").open("w")) + yaml.safe_dump(self.to_dict(), (self.output_dir / "surgery_config.yaml").open("w")) logger.info("Done!") From dded00af39930f7cc57ade985dd65e314e3b62a4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 1 Apr 2025 20:19:15 -0400 Subject: [PATCH 12/18] fix --- fast_llm/config.py | 10 ++++------ fast_llm/utils.py | 5 +++++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 0abd9073..62db786d 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -391,9 +391,11 @@ def _validate(self) -> None: if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa continue value = getattr(self, name) - if value is DEFAULT: + if isinstance(value, Tag): + Assert.is_(value, DEFAULT) # Replace the value with its default. # We still need to validate because some fields have invalid defaults. + # TODO: Improve (still needed with new config update format? Do earlier to allow implicit defaults?) value = field.default new_value = self._validate_nested(value, field.type, field.name, field.valid, errors, False) setattr(self, name, new_value) @@ -860,11 +862,7 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ errors = compare_nested(self_dict, other_dict) if errors: return log( - f"Config diff:\n " - + "\n ".join( - f"{'.'.join(key)}`: `{self_value}` != `{other_value}`" - for key, (self_value, other_value) in errors.items() - ), + f"Config comparison errors:\n " + "\n".join(errors), log_fn=log_fn, ) diff --git a/fast_llm/utils.py b/fast_llm/utils.py index da083eef..a8c5eac6 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -71,12 +71,17 @@ def rms_diff(x: "torch.Tensor", y: "torch.Tensor") -> "torch.Tensor": class Tag: + __slots__ = ("value",) + def __init__(self, value: str): self.value = value def __repr__(self) -> str: return self.value + def __deepcopy__(self, memodict: dict[str, typing.Any]) -> typing.Self: + return self + class Assert: """ From 986f9f3c9a5ebdc40dd9879540449a0fdb2aa80f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 1 Apr 2025 20:27:32 -0400 Subject: [PATCH 13/18] fix --- tests/test_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index d5685a71..d446f414 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -409,6 +409,7 @@ def test_load_pretrained_distributed_with_config(): ) +@pytest.mark.skip(reason="Fails because of incorrect init config.") @pytest.mark.depends(on=["test_load_pretrained_distributed_in_dp2"]) def test_load_pretrained_in_dp2_match_checkpoint(): test_ckpt_path = TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoint" / "1" From 8e3e7957b759d17c194d78edf736af7136d0586d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 1 Apr 2025 21:21:25 -0400 Subject: [PATCH 14/18] fixes --- fast_llm/engine/checkpoint/distributed.py | 2 +- tests/common.py | 4 ++-- tests/test_checkpoint.py | 3 --- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git 
a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index e3cd7d16..4225a404 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -48,7 +48,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No Assert.eq(metadata.shards[: len(shard_names)], list(shard_names)) # Using `log_fn=bool` sets the output to true if the error list is non-empty. - same_format = config.optimizer_state and not loaded_config.compare(self._model.config, log_fn=bool) + same_format = config.optimizer_state and not metadata.config.compare(self._model.config, log_fn=bool) # Make sure all nodes agree on which loading scheme to use. # Note: they may not agree before the broadcast because of the rank comparison, but that's ok. same_format = broadcast_scalar(same_format, torch.uint8, self._model.distributed.world_group) diff --git a/tests/common.py b/tests/common.py index 14ec5c61..cc749901 100644 --- a/tests/common.py +++ b/tests/common.py @@ -54,7 +54,7 @@ "model.base_model.transformer.num_layers=2", "model.base_model.transformer.hidden_size=256", "model.base_model.transformer.num_attention_heads=8", - "model.base_model.transformer.init_method_std=0.022", + # "model.base_model.transformer.init_method_std=0.022", f"model.base_model.vocab_size={TEST_VOCAB_SIZE}", f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", @@ -101,7 +101,7 @@ "--global-batch-size=8", "--max-position-embeddings=512", "--seq-length=512", - "--init-method-std=0.022", + "--init-method-std=0.0625", "--lr=0.0001", "--num-workers=0", "--valid-num-workers=0", diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index d446f414..6793a670 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -259,7 +259,6 @@ def test_load_pretrained_distributed_checkpoint(): path=_CKPT_PATH, format=DistributedCheckpointFormat, optimizer_state=True, - load_config=ModelConfigType.fast_llm, ) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_ref) _compare_configs(config.base_model, model.config.base_model) @@ -409,7 +408,6 @@ def test_load_pretrained_distributed_with_config(): ) -@pytest.mark.skip(reason="Fails because of incorrect init config.") @pytest.mark.depends(on=["test_load_pretrained_distributed_in_dp2"]) def test_load_pretrained_in_dp2_match_checkpoint(): test_ckpt_path = TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoint" / "1" @@ -454,7 +452,6 @@ def test_load_pretrained_in_dp2_match_checkpoint(): assert (stage_shard_test[stage_shard_ref.numel() :] == 0).all() # noqa -@pytest.mark.skip(reason="Fails because of incorrect init config.") @pytest.mark.slow @pytest.mark.depends(on=["test_load_pretrained_in_dp2_match_checkpoint"]) def test_load_distributed_checkpoint_dp2(): From da6eb7bf7b16b709c81f06df50a5cac342ee7915 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 3 Apr 2025 01:17:37 -0400 Subject: [PATCH 15/18] fixes --- fast_llm/data/dataset/gpt/sampled.py | 4 +- fast_llm/engine/checkpoint/config.py | 30 +++++- fast_llm/engine/checkpoint/distributed.py | 24 +++-- fast_llm/engine/checkpoint/external.py | 2 +- fast_llm/engine/checkpoint/huggingface.py | 5 +- fast_llm/engine/checkpoint/state_dict.py | 4 +- fast_llm/engine/huggingface/config.py | 5 +- fast_llm/engine/multi_stage/fast_llm_model.py | 7 +- fast_llm/engine/training/trainer.py | 1 + tests/common.py | 6 +- tests/test_checkpoint.py | 95 +++++++++++-------- 
tests/test_config.py | 11 ++- 12 files changed, 124 insertions(+), 70 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index c96eb35f..fa486216 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -65,7 +65,7 @@ def __getitem__(self, item: typing.Any) -> np.ndarray: def _lazy_load(self): if self._array is None: - assert self.exists() + assert self.exists(), self._path self._array = np.load(self._path, mmap_mode="r") @@ -432,7 +432,7 @@ def _lazy_load(self): def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - if unshuffled_tokens := data.get("unshuffled_tokens") is not None: + if (unshuffled_tokens := data.get("unshuffled_tokens")) is not None: self._unshuffled_tokens = unshuffled_tokens else: self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"] diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 7dbd5ce7..55440a5c 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -202,6 +202,17 @@ class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateSaveConf class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): _abstract = False + load_config: ModelConfigType = Field( + default=ModelConfigType.model, + desc="Configuration to save/load.", + hint=FieldHint.core, + ) + + def _validate(self) -> None: + super()._validate() + if self.format.enforce_architecture_match: + assert self.load_config.load_architecture + @config_class() class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase): @@ -225,8 +236,23 @@ def __init__(self, model: "FastLLMModel"): # TODO: save_metadata? 
@classmethod - @abc.abstractmethod def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata": + updates = {} + metadata = cls._load_metadata(config) + if not config.load_config.load_fast_llm: + updates[("config", "multi_stage")] = {} + updates[("config", "distributed")] = {} + if not config.load_config.load_architecture: + updates[("config", "base_model")] = {} + elif not config.load_config.load_base_model: + updates[("config", "base_model")] = metadata.config.base_model.get_architecture() + if updates: + metadata = metadata.to_copy(updates) + return metadata + + @classmethod + @abc.abstractmethod + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata": pass @abc.abstractmethod @@ -234,7 +260,7 @@ def save(self, config: CheckpointSaveConfig, metadata: "CheckpointMetadata"): pass @abc.abstractmethod - def load(self, config: CheckpointLoadConfig, metadata: "CheckpointMetadata"): + def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: pass def get_shard_names(self, config: CheckpointStateConfigBase) -> tuple[str, ...]: diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index 4225a404..ac06df5c 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -13,6 +13,7 @@ CheckpointLoadMetadataConfig, CheckpointSaveConfig, DistributedCheckpointFormat, + ModelConfigType, export_safetensors_metadata, ) from fast_llm.engine.checkpoint.safe_load import SafeLoad @@ -27,7 +28,7 @@ class DistributedCheckpointHandler(CheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = DistributedCheckpointFormat @classmethod - def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").open("r"))) def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: @@ -40,15 +41,16 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No metadata=export_safetensors_metadata(serialized_metadata), ) - def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: + def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: # TODO: More safety checks + loaded_metadata = self._model.config.load_metadata(config.to_copy({"load_config": ModelConfigType.fast_llm})) shard_names = self.get_shard_names(config) # Make sure all shards to load are in the checkpoint. - Assert.leq(set(self.get_shard_names(config)), set(metadata.shards)) - Assert.eq(metadata.shards[: len(shard_names)], list(shard_names)) + Assert.leq(set(self.get_shard_names(config)), set(loaded_metadata.shards)) + Assert.eq(loaded_metadata.shards[: len(shard_names)], list(shard_names)) # Using `log_fn=bool` sets the output to true if the error list is non-empty. - same_format = config.optimizer_state and not metadata.config.compare(self._model.config, log_fn=bool) + same_format = config.optimizer_state and not loaded_metadata.config.compare(self._model.config, log_fn=bool) # Make sure all nodes agree on which loading scheme to use. # Note: they may not agree before the broadcast because of the rank comparison, but that's ok. 
same_format = broadcast_scalar(same_format, torch.uint8, self._model.distributed.world_group) @@ -67,7 +69,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning) for shard_name in shard_names: self._model.get_shard(shard_name).copy_( - f.get_slice("state_shard")[metadata.shards.index(shard_name)] + f.get_slice("state_shard")[loaded_metadata.shards.index(shard_name)] ) else: # TODO: Does this copy twice? @@ -76,11 +78,11 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No else: log_main_rank("Checkpoint format doesn't match, using safe load", log_fn=logger.info) - self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning) + self._model.config.base_model.compare_architecture(loaded_metadata.config.base_model, logger.warning) with SafeLoad(self._model, shard_names=shard_names, timeout=config.timeout) as context: - for rank in range(metadata.config.distributed.world_size): + for rank in range(loaded_metadata.config.distributed.world_size): loaded_model = self._model.__class__( - metadata.config.to_copy({("distributed", "rank"): rank}), + loaded_metadata.config.to_copy({("distributed", "rank"): rank}), optimizer_state_names=shard_names[1:], verbose=False, ) @@ -94,7 +96,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No # TODO v0.3: Use checkpoint version? Drop support? log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning) loaded_shards = { - shard_name: f.get_slice("state_shard")[metadata.shards.index(shard_name)] + shard_name: f.get_slice("state_shard")[loaded_metadata.shards.index(shard_name)] for shard_name in shard_names } else: @@ -119,3 +121,5 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No ) context.mark_as_loaded(counter.item()) + + return loaded_metadata.metadata diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 654ba21f..e3b6dcf2 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -226,7 +226,7 @@ def __init__(self, model: "FastLLMModel"): } @classmethod - def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: imported_model_config = cls._import_config(cls._load_config(config.path), True) return CheckpointMetadata( fast_llm_version=__version__, diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 7357b722..a5777d45 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -39,10 +39,11 @@ def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: Ch "format": "pt", } - def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: + def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: assert not config.optimizer_state + metadata = self._model.config.load_metadata(config) self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning) - super().load(config, metadata) + super().load(config) @classmethod def get_huggingface_model_type(self) -> str: diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 71c83ece..1bb47e5c 100644 --- 
a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -73,7 +73,7 @@ def _serialize_metadata( ) -> dict[str, typing.Any]: return metadata.to_dict() - def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: + def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: with SafeLoad(self._model, shard_names=self.get_shard_names(config), timeout=config.timeout) as context: # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from # `state_dict` that are ready for conversion, @@ -116,7 +116,7 @@ class FastLLMCheckpointHandler(StateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = FastLLMCheckpointFormat @classmethod - def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: path = config.path / f"metadata.yaml" logger.warning(f"Loading metadata from {path}") return CheckpointMetadata.from_dict(yaml.safe_load(path.open("r"))) diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index 08070804..d4b46bcc 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -74,7 +74,10 @@ def _get_config_dict( torch_dtype = kwargs.pop("torch_dtype", None) if torch_dtype is not None: updates[("distributed", "training_dtype")] = torch_dtype - fast_llm_config = cls.model_config_class.from_dict(metadata.config, kwargs.pop("fast_llm_config", {}), updates) + fast_llm_config = cls.model_config_class.from_metadata( + pretrained, metadata, default=kwargs.pop("fast_llm_config", None), updates=updates + ) + config_dict = {"fast_llm_config": fast_llm_config} return config_dict, kwargs diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index e2255faa..de26f9bf 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -31,16 +31,15 @@ def save_checkpoint( ) converter.save(config, fast_llm_metadata) - def load_checkpoint(self, config: CheckpointLoadConfig) -> dict[str, typing.Any]: + def load_checkpoint(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: # TODO: Simplify branching. # TODO: Test with more distributed configs. # TODO: Safety checks # TODO: Handle barriers, ok file, etc. here - metadata = self.config_class.load_metadata(config) converter = config.format.get_handler_class()(self) - converter.load(config, metadata) + metadata = converter.load(config) self._finalize_load(reset_optimizer=not config.optimizer_state) - return metadata.metadata + return metadata @classmethod def from_pretrained( diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index f2ed4a38..c6daa081 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -494,6 +494,7 @@ def _load_checkpoint(self, config: TrainingCheckpointConfig, iteration: int) -> metadata = self._multi_stage.load_checkpoint( config.get_load_config(checkpoint_directory, timeout=self._config.training.timeout) ) + assert metadata is not None self._optimizer.load(metadata["optimizer"]) if "schedules" in metadata: # Backward compatibility. 
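
For readers following PATCH 15 up to this point: the handler's `load()` now returns the checkpoint metadata instead of receiving it, and the trainer asserts on that return value before restoring the optimizer. Below is a minimal illustrative sketch of the resulting calling convention, not part of the patch itself; the helper name `load_optimizer_state_from_checkpoint`, the `model` argument (assumed to be an already set-up `FastLLMModel`), and the checkpoint path are hypothetical, while the `CheckpointLoadConfig` fields are the ones exercised elsewhere in this series.

import pathlib

from fast_llm.engine.checkpoint.config import (
    CheckpointLoadConfig,
    DistributedCheckpointFormat,
    ModelConfigType,
)


def load_optimizer_state_from_checkpoint(model, checkpoint_path: pathlib.Path):
    # The handler is resolved from `format` and, after this patch, returns the
    # checkpoint metadata dict (or None for formats that carry none).
    metadata = model.load_checkpoint(
        CheckpointLoadConfig(
            path=checkpoint_path,
            format=DistributedCheckpointFormat,
            optimizer_state=True,
            load_config=ModelConfigType.model,  # explicit config selection, matching the tests later in this series
        )
    )
    # Mirrors the trainer change above: metadata must be present before the
    # optimizer state can be restored from it.
    assert metadata is not None
    return metadata["optimizer"]
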
diff --git a/tests/common.py b/tests/common.py index cc749901..9ecb60ff 100644 --- a/tests/common.py +++ b/tests/common.py @@ -54,7 +54,7 @@ "model.base_model.transformer.num_layers=2", "model.base_model.transformer.hidden_size=256", "model.base_model.transformer.num_attention_heads=8", - # "model.base_model.transformer.init_method_std=0.022", + "model.base_model.transformer.init_method_std=0.022", f"model.base_model.vocab_size={TEST_VOCAB_SIZE}", f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", @@ -101,7 +101,7 @@ "--global-batch-size=8", "--max-position-embeddings=512", "--seq-length=512", - "--init-method-std=0.0625", + "--init-method-std=0.022", "--lr=0.0001", "--num-workers=0", "--valid-num-workers=0", @@ -394,7 +394,7 @@ def run_test_script( if num_gpus == 1 and not is_megatron: CliTrainingConfig.parse_and_run(script) else: - completed_proc = subprocess.run(command, env=env) + completed_proc = subprocess.run(command, env=env, timeout=30) if completed_proc.returncode: raise RuntimeError(f"Process failed with return code {completed_proc.returncode}") if compare: diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 6793a670..4171581a 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -14,7 +14,7 @@ FastLLMCheckpointFormat, ModelConfigType, ) -from fast_llm.engine.multi_stage.config import StageMode +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.engine.multi_stage.multi_stage import ShardName from fast_llm.models.auto import model_registry from fast_llm.tools.convert import ConversionConfig @@ -246,8 +246,12 @@ def test_converted_huggingface(): assert (h0[key] == h1[key]).all() -def _compare_configs(config_ref, config_test): - config_ref.compare(config_test) +def _compare_model_configs(config_ref: FastLLMModelConfig, config_test: FastLLMModelConfig): + config_ref.base_model.compare(config_test.base_model) + + +def _compare_architectures(config_ref: FastLLMModelConfig, config_test: FastLLMModelConfig): + config_ref.base_model.get_architecture().compare(config_test.base_model.get_architecture()) @pytest.mark.depends(on=["test_converted_distributed"]) @@ -261,7 +265,7 @@ def test_load_pretrained_distributed_checkpoint(): optimizer_state=True, ) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_ref) - _compare_configs(config.base_model, model.config.base_model) + _compare_model_configs(config, model.config) state_shards = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) ) @@ -271,20 +275,24 @@ def test_load_pretrained_distributed_checkpoint(): @pytest.mark.depends(on=["test_load_pretrained_distributed_checkpoint"]) def test_load_converted_distributed_checkpoint(): - pretrained_config_ref = CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) - pretrained_config_0 = CheckpointLoadConfig( - path=_CONVERT_PATH / "distributed_0", - format=DistributedCheckpointFormat, + config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) ) - pretrained_config_1 = CheckpointLoadConfig( - path=_CONVERT_PATH / "distributed_1", - format=DistributedCheckpointFormat, + + model = TEST_MODEL_CLS.from_pretrained( + CheckpointLoadConfig( + path=_CONVERT_PATH / "distributed_0", + format=DistributedCheckpointFormat, + ) ) - config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) - model = 
TEST_MODEL_CLS.from_pretrained(pretrained_config_0) - config_1 = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_1) - _compare_configs(config.base_model, model.config.base_model) - _compare_configs(config.base_model, config_1.base_model) + config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig( + path=_CONVERT_PATH / "distributed_1", + format=DistributedCheckpointFormat, + ) + ) + _compare_architectures(config_ref, model.config) + _compare_model_configs(model.config, config_alt) weight_shard = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) )[WEIGHT_SHARD_SAVE_NAME] @@ -293,14 +301,17 @@ def test_load_converted_distributed_checkpoint(): @pytest.mark.depends(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_fast_llm_checkpoint(): - pretrained_config_ref = CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) - pretrained_config_0 = CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_0", format=FastLLMCheckpointFormat) - pretrained_config_1 = CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_1", format=FastLLMCheckpointFormat) - config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) - model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0) - config_1 = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_1) - _compare_configs(config.base_model, model.config.base_model) - _compare_configs(config.base_model, config_1.base_model) + config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) + ) + model = TEST_MODEL_CLS.from_pretrained( + CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_0", format=FastLLMCheckpointFormat) + ) + config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_1", format=FastLLMCheckpointFormat) + ) + _compare_architectures(config_ref, model.config) + _compare_architectures(config_ref, config_alt) weight_shard = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) )[WEIGHT_SHARD_SAVE_NAME] @@ -309,23 +320,27 @@ def test_load_converted_fast_llm_checkpoint(): @pytest.mark.depends(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_huggingface_checkpoint(): - pretrained_config_ref = CheckpointLoadConfig( - path=_CKPT_PATH, - format=DistributedCheckpointFormat, + config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig( + path=_CKPT_PATH, + format=DistributedCheckpointFormat, + ) ) - pretrained_config_0 = CheckpointLoadConfig( - path=_CONVERT_PATH / "huggingface_0", - format=HUGGINGFACE_CHECKPOINT_FORMAT, + model = TEST_MODEL_CLS.from_pretrained( + CheckpointLoadConfig( + path=_CONVERT_PATH / "huggingface_1", + format=HUGGINGFACE_CHECKPOINT_FORMAT, + ), + mode=StageMode.weights, ) - pretrained_config_1 = CheckpointLoadConfig( - path=_CONVERT_PATH / "huggingface_1", - format=HUGGINGFACE_CHECKPOINT_FORMAT, + config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig( + path=_CONVERT_PATH / "huggingface_0", + format=HUGGINGFACE_CHECKPOINT_FORMAT, + ) ) - config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) - model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0, mode=StageMode.weights) - config_1 = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_1) - _compare_configs(config.base_model, model.config.base_model) - 
_compare_configs(config.base_model, config_1.base_model) + _compare_architectures(config_ref, model.config) + _compare_model_configs(model.config, config_alt) weight_shard = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) )[WEIGHT_SHARD_SAVE_NAME] @@ -423,7 +438,7 @@ def test_load_pretrained_in_dp2_match_checkpoint(): ) config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) config_test = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_test) - _compare_configs(config_ref.base_model, config_test.base_model) + _compare_model_configs(config_ref, config_test) shards_ref = safetensors.torch.load_file(_CKPT_PATH / "rank_0.safetensors") shards_test = [safetensors.torch.load_file(test_ckpt_path / f"rank_{i}.safetensors") for i in range(2)] ref_model = TEST_MODEL_CLS(config_ref) @@ -467,7 +482,7 @@ def test_load_distributed_checkpoint_dp2(): ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_test, mode=StageMode.weights) - _compare_configs(config.base_model, model.config.base_model) + _compare_model_configs(config, model.config) weight_shard = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) )[WEIGHT_SHARD_SAVE_NAME] diff --git a/tests/test_config.py b/tests/test_config.py index 79437e9d..ed758965 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -10,6 +10,8 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.transformer.config import AddLinearBiasChoices, TransformerArchitectureConfig, TransformerConfig from fast_llm.models.auto import trainer_registry +from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.utils import check_equal_nested def run_without_import(cmd: str): @@ -114,8 +116,11 @@ def test_add_attn_dense_bias(): ) -@pytest.mark.parametrize("cls", (GPTSamplingConfig,)) -def test_serialize_default_config_updates(cls): +@pytest.mark.parametrize( + ("cls", "default"), + ((GPTSamplingConfig, {}), (GPTModelConfig, {"distributed": {"world_size": 1, "rank": 0, "local_world_size": 1}})), +) +def test_serialize_default_config_updates(cls, default): # Config classes used as config updates should have a default that serializes to an empty dict # so no value is incorrectly overridden. 
-    assert cls.from_dict({}).to_dict() == {}
+    check_equal_nested(cls.from_dict({}).to_dict(), default)

From baad705d6960d9578a2f5e29664284250d569980 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Thu, 3 Apr 2025 19:16:01 -0400
Subject: [PATCH 16/18] fix

---
 fast_llm/layers/transformer/config.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py
index a1cb658e..cf409e77 100644
--- a/fast_llm/layers/transformer/config.py
+++ b/fast_llm/layers/transformer/config.py
@@ -186,7 +186,7 @@ class TransformerSubLayerName(str, enum.Enum):
 @config_class()
 class TransformerPeftConfig(PeftConfig):
     layers: list[TransformerSubLayerName] = Field(
-        default_factory=lambda: [TransformerSubLayerName.query, TransformerSubLayerName.value_],
+        default=None,
         desc="The layers on which to apply LoRA.",
         hint=FieldHint.feature,
     )
@@ -220,6 +220,15 @@ def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta":
         return parameter
 
     def _validate(self) -> None:
+        if self.layers is None:
+            with self._set_implicit_default():
+                # Setting the default layers only when PeFT is enabled
+                # so they don't appear when serializing the default transformer config.
+                self.layers = (
+                    [TransformerSubLayerName.query, TransformerSubLayerName.value_]
+                    if self.type == PeftType.lora
+                    else []
+                )
         if self.type != PeftType.none:
             if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers:
                 # TODO: Add MLP support.

From b7028378a2f8cb4e6e863ac55af69b0f11f71cff Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Fri, 4 Apr 2025 21:48:10 -0400
Subject: [PATCH 17/18] Test, fixes

---
 fast_llm/engine/checkpoint/config.py      | 11 +--
 fast_llm/engine/checkpoint/distributed.py |  7 ++
 fast_llm/engine/checkpoint/huggingface.py |  6 +-
 fast_llm/engine/checkpoint/state_dict.py  | 17 ++++-
 fast_llm/engine/multi_stage/config.py     | 14 ++--
 tests/test_config.py                      | 84 ++++++++++++++++++++++-
 6 files changed, 123 insertions(+), 16 deletions(-)

diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py
index 55440a5c..62928ed0 100644
--- a/fast_llm/engine/checkpoint/config.py
+++ b/fast_llm/engine/checkpoint/config.py
@@ -201,9 +201,9 @@ class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateSaveConf
 @config_class()
 class CheckpointLoadMetadataConfig(CheckpointPathConfigBase):
     _abstract = False
-
+    # TODO: Set default to model? (Not backward compatible)
     load_config: ModelConfigType = Field(
-        default=ModelConfigType.model,
+        default=ModelConfigType.architecture,
         desc="Configuration to save/load.",
         hint=FieldHint.core,
     )
@@ -233,7 +233,10 @@ class CheckpointHandler(abc.ABC):
     def __init__(self, model: "FastLLMModel"):
         self._model = model
 
-    # TODO: save_metadata?
+ @classmethod + @abc.abstractmethod + def save_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: "CheckpointMetadata"): + pass @classmethod def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata": @@ -245,7 +248,7 @@ def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetad if not config.load_config.load_architecture: updates[("config", "base_model")] = {} elif not config.load_config.load_base_model: - updates[("config", "base_model")] = metadata.config.base_model.get_architecture() + updates[("config", "base_model")] = metadata.config.base_model.get_architecture().to_dict() if updates: metadata = metadata.to_copy(updates) return metadata diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index ac06df5c..de1625f6 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -12,6 +12,7 @@ CheckpointLoadConfig, CheckpointLoadMetadataConfig, CheckpointSaveConfig, + CheckpointSaveMetadataConfig, DistributedCheckpointFormat, ModelConfigType, export_safetensors_metadata, @@ -27,6 +28,12 @@ class DistributedCheckpointHandler(CheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = DistributedCheckpointFormat + @classmethod + def save_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata): + config.path.mkdir(parents=True, exist_ok=True) + serialized_metadata = metadata.to_dict() + yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w")) + @classmethod def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").open("r"))) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index a5777d45..2972a4fa 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -20,8 +20,10 @@ class HuggingfaceStateDictCheckpointHandler(ExternalStateDictCheckpointHandler, abc.ABC): - def _save_serialized_metadata(self, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None: - path = config.path / f"{self.base_file_name}.safetensors.index.json" + @classmethod + def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None: + config.path.mkdir(parents=True, exist_ok=True) + path = config.path / f"{cls.base_file_name}.safetensors.index.json" logger.info(f"Saving index to {path}") # Save the index. 
json.dump( diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 1bb47e5c..556e97be 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -30,6 +30,13 @@ class StateDictCheckpointHandler(CheckpointHandler): base_file_name: typing.ClassVar[str] = "model" + @classmethod + def save_metadata( + cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata, index: dict | None = None + ): + serialized_metadata = cls._serialize_metadata(config, metadata) + cls._save_serialized_metadata(config, serialized_metadata, {} if index is None else index) + def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: serialized_metadata = self._serialize_metadata(config, metadata) saver = StateDictSaver( @@ -64,12 +71,14 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No if self._model.config.distributed.rank == 0: self._save_serialized_metadata(config, serialized_metadata, index) + @classmethod @abc.abstractmethod - def _save_serialized_metadata(self, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None: + def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None: pass + @classmethod def _serialize_metadata( - self, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata + cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata ) -> dict[str, typing.Any]: return metadata.to_dict() @@ -121,9 +130,11 @@ def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetad logger.warning(f"Loading metadata from {path}") return CheckpointMetadata.from_dict(yaml.safe_load(path.open("r"))) + @classmethod def _save_serialized_metadata( - self, config: CheckpointSaveMetadataConfig, serialized_metadata: dict, index: dict + cls, config: CheckpointSaveMetadataConfig, serialized_metadata: dict, index: dict ) -> None: + config.path.mkdir(parents=True, exist_ok=True) path = config.path / f"metadata.yaml" logger.info(f"Saving metadata to {path}") if "metadata" not in serialized_metadata: diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 43b412fb..6a0c8813 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -187,11 +187,12 @@ class MultiStageConfig(StageConfig): def _validate(self) -> None: super()._validate() if self.zero_stage is not None: - Assert.in_range_incl(self.zero_stage, 1, 3) - if self.zero_stage >= 2: - self.num_grad_buffers = 2 - if self.zero_stage >= 3: - self.num_weight_buffers = 2 + with self._set_implicit_default(): + Assert.in_range_incl(self.zero_stage, 1, 3) + if self.zero_stage >= 2: + self.num_grad_buffers = 2 + if self.zero_stage >= 3: + self.num_weight_buffers = 2 if self.num_grad_buffers is not None: Assert.geq(self.num_grad_buffers, 1) if self.num_weight_buffers is not None: @@ -281,6 +282,9 @@ def to_metadata(self, config: CheckpointSaveMetadataConfig, **kwargs) -> "Checkp **kwargs, ) + def save_metadata(self, config: CheckpointSaveMetadataConfig, **kwargs) -> None: + self.get_checkpoint_handler_class(config.format).save_metadata(config, self.to_metadata(config, **kwargs)) + @config_class() class PretrainedFastLLMModelConfig(Config): diff --git a/tests/test_config.py b/tests/test_config.py index ed758965..79c6738d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,13 +5,16 @@ import pytest import yaml +from 
fast_llm.config import NoAutoValidate from fast_llm.data.dataset.gpt.config import GPTSamplingConfig +from fast_llm.engine.checkpoint.config import CheckpointSaveMetadataConfig, ModelConfigType from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.transformer.config import AddLinearBiasChoices, TransformerArchitectureConfig, TransformerConfig from fast_llm.models.auto import trainer_registry -from fast_llm.models.gpt.config import GPTModelConfig -from fast_llm.utils import check_equal_nested +from fast_llm.models.gpt.config import GPTModelConfig, PretrainedGPTModelConfig +from fast_llm.utils import Assert, check_equal_nested +from tests.common import TEST_RESULTS_PATH def run_without_import(cmd: str): @@ -124,3 +127,80 @@ def test_serialize_default_config_updates(cls, default): # Config classes used as config updates should have a default that serializes to an empty dict # so no value is incorrectly overridden. check_equal_nested(cls.from_dict({}).to_dict(), default) + + +@pytest.mark.parametrize("load_config", tuple(ModelConfigType)) +def test_pretrained_config(load_config: ModelConfigType): + config_path = TEST_RESULTS_PATH / "pretrained_config" + pretrained_model_config = GPTModelConfig.from_dict( + { + "base_model": { + "transformer": { + "normalization": {"type": "rms_norm"}, # Nested + "rotary": {"type": "default"}, + "num_layers": 12, # Default + "hidden_size": 1024, # Default + "window_size": 32, # Non-architecture + "ffn_hidden_size": 4096, # Implicit default, default value + "activation_type": "silu", # Implicit default, non-default value + "head_groups": 4, + }, + "tie_word_embeddings": False, + }, + "multi_stage": {"zero_stage": 3}, + "distributed": {"training_dtype": "bfloat16"}, + } + ) + with NoAutoValidate(): + save_config = CheckpointSaveMetadataConfig.from_dict({"format": "fast_llm", "path": config_path}) + save_config.setup(GPTModelConfig) + save_config.validate() + pretrained_model_config.save_metadata(save_config) + + base_model_update = { + "transformer": { + # rotary: Don't override nested. 
+ "normalization": {"implementation": "triton"}, # Update non-default nested + "peft": {"freeze_others": False}, # Update default nested, non-architecture + "hidden_size": 512, # Override, affects derived value (kv channels) + "head_groups": 1, # Override to default + }, + "vocab_size": 1000, + } + pretrained_config = PretrainedGPTModelConfig.from_dict( + { + "model": { + "base_model": base_model_update, + "distributed": {"seed": 1234, "training_dtype": "float16"}, + }, + "pretrained": {"format": "fast_llm", "path": config_path, "load_config": load_config}, + } + ) + Assert.eq(pretrained_config.model.base_model.transformer.kv_channels, 64) + serialized_config = pretrained_config.model.to_dict() + expected_config = {"distributed": DistributedConfig().to_dict()} + + if load_config == ModelConfigType.fast_llm: + expected_config["multi_stage"] = {"zero_stage": 3} + expected_config["distributed"].update({"seed": 1234, "training_dtype": "float16"}) + if load_config in (ModelConfigType.architecture, ModelConfigType.fast_llm, ModelConfigType.model): + expected_config["base_model"] = { + "transformer": { + "normalization": {"type": "rms_norm", "implementation": "triton"}, + "rotary": {"type": "default"}, + "peft": {"freeze_others": False}, + "num_layers": 12, + "hidden_size": 512, + "ffn_hidden_size": 4096, + "activation_type": "silu", + "head_groups": 1, + }, + "tie_word_embeddings": False, + "vocab_size": 1000, + } + if load_config != ModelConfigType.architecture: + expected_config["base_model"]["transformer"]["window_size"] = 32 + else: + expected_config["base_model"] = base_model_update + + check_equal_nested(serialized_config, expected_config) From cff9892d44a9380a992f33692500ed7e08191824 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 14 Apr 2025 18:11:05 -0400 Subject: [PATCH 18/18] fixes --- fast_llm/engine/inference/huggingface.py | 4 ++- tests/test_checkpoint.py | 33 +++++++++++++++++++++--- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 75aea9dd..196310b4 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -60,7 +60,9 @@ def from_pretrained( updates[("distributed", "training_dtype")] = torch_dtype # Create the model - fast_llm_model = cls.model_class.from_pretrained(pretrained_model_name_or_path, updates, mode=mode) + fast_llm_model = cls.runner_class.model_class.from_pretrained( + pretrained_model_name_or_path, updates, mode=mode + ) config = cls.config_class(fast_llm_model.config) return cls(config, fast_llm_model, **kwargs) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 4171581a..0c5e177d 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -263,6 +263,7 @@ def test_load_pretrained_distributed_checkpoint(): path=_CKPT_PATH, format=DistributedCheckpointFormat, optimizer_state=True, + load_config=ModelConfigType.model, ) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_ref) _compare_model_configs(config, model.config) @@ -276,19 +277,25 @@ def test_load_pretrained_distributed_checkpoint(): @pytest.mark.depends(on=["test_load_pretrained_distributed_checkpoint"]) def test_load_converted_distributed_checkpoint(): config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( - CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) + CheckpointLoadConfig( + path=_CKPT_PATH, + format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, + ) ) model = 
TEST_MODEL_CLS.from_pretrained( CheckpointLoadConfig( path=_CONVERT_PATH / "distributed_0", format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, ) ) config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( CheckpointLoadConfig( path=_CONVERT_PATH / "distributed_1", format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, ) ) _compare_architectures(config_ref, model.config) @@ -302,13 +309,25 @@ def test_load_converted_distributed_checkpoint(): @pytest.mark.depends(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_fast_llm_checkpoint(): config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( - CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) + CheckpointLoadConfig( + path=_CKPT_PATH, + format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, + ) ) model = TEST_MODEL_CLS.from_pretrained( - CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_0", format=FastLLMCheckpointFormat) + CheckpointLoadConfig( + path=_CONVERT_PATH / "fast_llm_0", + format=FastLLMCheckpointFormat, + load_config=ModelConfigType.model, + ) ) config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( - CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_1", format=FastLLMCheckpointFormat) + CheckpointLoadConfig( + path=_CONVERT_PATH / "fast_llm_1", + format=FastLLMCheckpointFormat, + load_config=ModelConfigType.model, + ) ) _compare_architectures(config_ref, model.config) _compare_architectures(config_ref, config_alt) @@ -324,12 +343,14 @@ def test_load_converted_huggingface_checkpoint(): CheckpointLoadConfig( path=_CKPT_PATH, format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, ) ) model = TEST_MODEL_CLS.from_pretrained( CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_1", format=HUGGINGFACE_CHECKPOINT_FORMAT, + load_config=ModelConfigType.model, ), mode=StageMode.weights, ) @@ -337,6 +358,7 @@ def test_load_converted_huggingface_checkpoint(): CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", format=HUGGINGFACE_CHECKPOINT_FORMAT, + load_config=ModelConfigType.model, ) ) _compare_architectures(config_ref, model.config) @@ -353,6 +375,7 @@ def test_run_converted_model(): CheckpointLoadConfig( path=_CKPT_PATH, format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, ) ) test_input = torch.randint( @@ -364,6 +387,7 @@ def test_run_converted_model(): CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", format=HUGGINGFACE_CHECKPOINT_FORMAT, + load_config=ModelConfigType.model, ) ) errors = [] @@ -479,6 +503,7 @@ def test_load_distributed_checkpoint_dp2(): pretrained_config_test = CheckpointLoadConfig( path=TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoint" / "1", format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_test, mode=StageMode.weights)
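
This final patch passes `load_config=ModelConfigType.model` explicitly throughout the checkpoint tests because PATCH 17 changed the default of `CheckpointLoadMetadataConfig.load_config` to `ModelConfigType.architecture`. As a hedged illustration of the difference, not part of the patch, here is a small sketch; `config_class` stands for a concrete `FastLLMModelConfig` subclass (for example the GPT model config used in the tests) and the checkpoint path is a placeholder:

import pathlib

from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, DistributedCheckpointFormat, ModelConfigType


def load_pretrained_configs(config_class, checkpoint_path: pathlib.Path):
    # New default (ModelConfigType.architecture): only the architecture part of the
    # saved base-model config is kept, so non-architecture fields such as
    # `window_size` (see `test_pretrained_config` above) are dropped.
    architecture_only = config_class.from_pretrained(
        CheckpointLoadConfig(path=checkpoint_path, format=DistributedCheckpointFormat)
    )
    # Explicit `model`: also restores the non-architecture base-model fields,
    # which is the behaviour these tests relied on before the default changed.
    full_model = config_class.from_pretrained(
        CheckpointLoadConfig(
            path=checkpoint_path,
            format=DistributedCheckpointFormat,
            load_config=ModelConfigType.model,
        )
    )
    return architecture_only, full_model
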