Skip to content

Commit d8f3390

Browse files
authored
Fix llama conversion, improve parameter conversion (#94)
1 parent a19d40b commit d8f3390

File tree

6 files changed: +275 additions, -166 deletions

fast_llm/config.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
_AUTO_VALIDATE = True
1818

1919
MISSING = Tag("<MISSING>")
20+
DEFAULT = Tag("<DEFAULT>")
2021

2122

2223
class NoAutoValidate:
@@ -347,6 +348,10 @@ def _validate(self):
347348
if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa
348349
continue
349350
value = getattr(self, name)
351+
if value is DEFAULT:
352+
# Replace the value with its default.
353+
# We still need to validate because some fields have invalid defaults.
354+
value = field.default
350355
new_value = self._validate_nested(value, field.type, field.name, field.valid, errors, False)
351356
setattr(self, name, new_value)
352357
for name in getattr(self, "_unknown_fields", {}):
@@ -603,7 +608,9 @@ def _add_field_to_args(
603608
field_value = field_value.__fast_llm_serialize__()
604609
if isinstance(value, enum.Enum):
605610
field_value = field_value.value
606-
elif not isinstance(value, int | float | bool | str | None):
611+
# Tag is not actually serializable, but needs to be kept as-is for config processing,
612+
# and should be absent for valid configs.
613+
elif not isinstance(value, int | float | bool | str | Tag | None):
607614
field_value = str(field_value)
608615
if format_ == _ConfigDictFormat.tuple:
609616
field_value = {(): field_value}

fast_llm/engine/checkpoint/external.py

+110-48
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010

1111
from fast_llm import __version__
12+
from fast_llm.config import MISSING
1213
from fast_llm.engine.base_model.config import BaseModelArchitectureConfig
1314
from fast_llm.engine.checkpoint.config import (
1415
CheckpointLoadConfig,
@@ -24,65 +25,104 @@
2425
logger = logging.getLogger(__name__)
2526

2627

27-
@dataclasses.dataclass
28-
class ParamConverter:
29-
fast_llm_name: tuple[str, ...] | None
30-
export_name: tuple[str, ...] | str | None
28+
@dataclasses.dataclass(kw_only=True)
29+
class ParamConverter(abc.ABC):
30+
fast_llm_names: tuple[tuple[str, ...], ...] = () # Array of fast-llm names, in nested (tuple) format.
31+
export_names: tuple[tuple[str, ...], ...] = () # Array of export names, in nested (tuple) format.
3132

32-
def export_param(self, fast_llm_value):
33-
return fast_llm_value
33+
@abc.abstractmethod
34+
def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
35+
pass
36+
37+
@abc.abstractmethod
38+
def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
39+
pass
40+
41+
42+
@dataclasses.dataclass(kw_only=True)
43+
class RenameParamConverter(ParamConverter):
3444

35-
def import_param(self, export_value):
36-
return export_value
45+
def __post_init__(self):
46+
Assert.eq(len(self.fast_llm_names), 1)
47+
Assert.eq(len(self.export_names), 1)
3748

49+
def export_params(self, fast_llm_values):
50+
return fast_llm_values
3851

39-
@dataclasses.dataclass
52+
def import_params(self, export_values):
53+
return export_values
54+
55+
56+
# def __repr__(self):
57+
# return f"RenameParamConverter({'.'.join(self.fast_llm_names[0])} <--> {'.'.join(self.export_names[0])})"
58+
59+
60+
@dataclasses.dataclass(kw_only=True)
4061
class ConstantImportParamConverter(ParamConverter):
41-
fast_llm_value: typing.Any
62+
fast_llm_value: typing.Any = MISSING
63+
64+
def __post_init__(self):
65+
Assert.eq(len(self.fast_llm_names), 1)
66+
Assert.eq(len(self.export_names), 0)
4267

43-
def export_param(self, fast_llm_value):
44-
Assert.eq(fast_llm_value, self.fast_llm_value)
68+
def export_params(self, fast_llm_values):
69+
Assert.eq(fast_llm_values[0], self.fast_llm_value)
70+
return ()
4571

46-
def import_param(self, export_value):
47-
return self.fast_llm_value
72+
def import_params(self, export_values):
73+
return (self.fast_llm_value,)
4874

4975

50-
@dataclasses.dataclass
76+
@dataclasses.dataclass(kw_only=True)
5177
class ConstantExportParamConverter(ParamConverter):
52-
export_value: typing.Any
78+
export_value: typing.Any = MISSING
5379

54-
def export_param(self, fast_llm_value):
55-
return self.export_value
80+
def __post_init__(self):
81+
Assert.eq(len(self.fast_llm_names), 0)
82+
Assert.eq(len(self.export_names), 1)
5683

57-
def import_param(self, export_value):
58-
Assert.eq(export_value, self.export_value)
84+
def export_params(self, fast_llm_values):
85+
return (self.export_value,)
86+
87+
def import_params(self, export_values):
88+
Assert.eq(export_values[0], self.export_value)
89+
return ()
5990

6091

61-
@dataclasses.dataclass
92+
@dataclasses.dataclass(kw_only=True)
6293
class IgnoreImportParamConverter(ParamConverter):
63-
ignore_export_value: typing.Any
94+
ignore_export_value: typing.Any = MISSING
6495

65-
def export_param(self, fast_llm_value):
66-
pass
96+
def __post_init__(self):
97+
Assert.eq(len(self.fast_llm_names), 0)
98+
Assert.eq(len(self.export_names), 1)
6799

68-
def import_param(self, export_value):
69-
if export_value is not self.ignore_export_value:
100+
def export_params(self, fast_llm_values):
101+
return (MISSING,)
102+
103+
def import_params(self, export_values):
104+
if export_values[0] not in (self.ignore_export_value, MISSING):
70105
logger.warning(
71-
f"The configuration parameter `{self.export_name}={export_value}` is ignored during conversion."
106+
f"The configuration parameter `{self.export_names[0]}={export_values[0]}` is ignored during conversion."
72107
f" If you intend to use it in Fast-LLM, make sure to set it explicitly in the model configuration."
73108
)
109+
return ()
74110

75111

76-
@dataclasses.dataclass
112+
@dataclasses.dataclass(kw_only=True)
77113
class MappedConfigParamConverter(ParamConverter):
78-
fast_llm_value: typing.Callable
79-
export_value: typing.Callable
114+
fast_llm_value: typing.Callable = lambda x: x
115+
export_value: typing.Callable = lambda x: x
116+
117+
def __post_init__(self):
118+
Assert.eq(len(self.fast_llm_names), 1)
119+
Assert.eq(len(self.export_names), 1)
80120

81-
def export_param(self, fast_llm_value):
82-
return self.export_value(fast_llm_value)
121+
def export_params(self, fast_llm_values):
122+
return (self.export_value(fast_llm_values[0]),)
83123

84-
def import_param(self, export_value):
85-
return self.fast_llm_value(export_value)
124+
def import_params(self, export_values):
125+
return (self.fast_llm_value(export_values[0]),)
86126

87127

88128
class WeightConverter:
@@ -197,13 +237,18 @@ def _export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing
197237
# TODO v0.3: not used in this class
198238
exported_config = {}
199239
for converter in cls._get_config_converters():
200-
value = converter.export_param(
201-
None
202-
if converter.fast_llm_name is None
203-
else cls._get_fast_llm_attribute(config, converter.fast_llm_name) # Noqa
204-
)
205-
if converter.export_name is not None:
206-
set_nested_dict_value(exported_config, converter.export_name, value)
240+
try:
241+
values = converter.export_params(
242+
tuple(
243+
cls._get_fast_llm_attribute(config, fast_llm_name)
244+
for fast_llm_name in converter.fast_llm_names
245+
)
246+
)
247+
for export_name, value in zip(converter.export_names, values, strict=True):
248+
if value is not MISSING:
249+
set_nested_dict_value(exported_config, export_name, value)
250+
except Exception as e:
251+
raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args)
207252

208253
return exported_config # Noqa
209254

@@ -214,12 +259,25 @@ def _import_config(
214259
kwargs = {}
215260
for converter in cls._get_config_converters():
216261
try:
217-
value = None if converter.export_name is None else get_nested_dict_value(config, converter.export_name)
218-
except KeyError:
219-
value = None
220-
value = converter.import_param(value)
221-
if converter.fast_llm_name is not None:
222-
kwargs[converter.fast_llm_name] = value
262+
values = ()
263+
for export_name in converter.export_names:
264+
try:
265+
value = get_nested_dict_value(config, export_name)
266+
except KeyError:
267+
value = MISSING
268+
values = values + (value,)
269+
values = converter.import_params(values)
270+
for fast_llm_name, value in zip(converter.fast_llm_names, values, strict=True):
271+
if value is MISSING:
272+
# Missing values need to be handled in dedicated converters,
273+
# because implicit / default values may not match.
274+
# TODO: Different behavior from other uses of MISSING. Use different tag?
275+
raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}")
276+
if fast_llm_name in kwargs:
277+
raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}")
278+
kwargs[fast_llm_name] = value
279+
except Exception as e:
280+
raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args)
223281

224282
config_class = cls._model_class.get_base_model_config_class()
225283
if architecture_only:
@@ -335,7 +393,11 @@ def _get_key(cls, parameter_name: str, shard_name: str) -> str:
335393
@classmethod
336394
@abc.abstractmethod
337395
def _create_config_converters(cls) -> list[ParamConverter]:
338-
return [ConstantExportParamConverter(None, "model_type", cls.get_huggingface_model_type())]
396+
return [
397+
ConstantExportParamConverter(
398+
export_names=(("model_type",),), export_value=cls.get_huggingface_model_type()
399+
)
400+
]
339401

340402
@classmethod
341403
def _load_config(cls, directory: pathlib.Path | str) -> dict:

fast_llm/layers/transformer/config.py

-9
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,6 @@ def complex_format(self):
123123
return self.enabled and not self.triton
124124

125125
def _validate(self):
126-
# These happen during conversion.
127-
if self.scale_factor is None:
128-
self.scale_factor = 8.0
129-
if self.low_frequency_factor is None:
130-
self.low_frequency_factor = 1.0
131-
if self.high_frequency_factor is None:
132-
self.high_frequency_factor = 4.0
133-
if self.original_context_length is None:
134-
self.original_context_length = 8192
135126
super()._validate()
136127
if self.triton and not TritonConfig.TRITON_ENABLED:
137128
warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.")

Comments (0)