9
9
import torch
10
10
11
11
from fast_llm import __version__
12
+ from fast_llm .config import MISSING
12
13
from fast_llm .engine .base_model .config import BaseModelArchitectureConfig
13
14
from fast_llm .engine .checkpoint .config import (
14
15
CheckpointLoadConfig ,
24
25
logger = logging .getLogger (__name__ )
25
26
26
27
27
- @dataclasses .dataclass
28
- class ParamConverter :
29
- fast_llm_name : tuple [str , ...] | None
30
- export_name : tuple [str , ...] | str | None
28
+ @dataclasses .dataclass ( kw_only = True )
29
+ class ParamConverter ( abc . ABC ) :
30
+ fast_llm_names : tuple [tuple [ str , ...], ...] = () # Array of fast-llm names, in nested (tuple) format.
31
+ export_names : tuple [tuple [ str , ...], ...] = () # Array of export names, in nested (tuple) format.
31
32
32
- def export_param (self , fast_llm_value ):
33
- return fast_llm_value
33
+ @abc .abstractmethod
34
+ def export_params (self , fast_llm_values : tuple [typing .Any , ...]) -> tuple [typing .Any , ...]:
35
+ pass
36
+
37
+ @abc .abstractmethod
38
+ def import_params (self , export_values : tuple [typing .Any , ...]) -> tuple [typing .Any , ...]:
39
+ pass
40
+
41
+
42
@dataclasses.dataclass(kw_only=True)
class RenameParamConverter(ParamConverter):
    """One-to-one converter that passes the value through unchanged.

    Only the parameter name differs between the two formats; exactly one
    fast-llm name and one export name are expected.
    """

    def __post_init__(self):
        # A rename maps a single parameter to a single parameter.
        Assert.eq(len(self.fast_llm_names), 1)
        Assert.eq(len(self.export_names), 1)

    def export_params(self, fast_llm_values):
        """Return the fast-llm values as-is."""
        return fast_llm_values

    def import_params(self, export_values):
        """Return the export values as-is."""
        return export_values

    # def __repr__(self):
    #     return f"RenameParamConverter({'.'.join(self.fast_llm_names[0])} <--> {'.'.join(self.export_names[0])})"
58
+
59
+
60
@dataclasses.dataclass(kw_only=True)
class ConstantImportParamConverter(ParamConverter):
    """Converter for a fast-llm parameter that has no exported counterpart.

    On import the parameter is set to `fast_llm_value`; on export the actual
    value is checked against `fast_llm_value` and nothing is written.
    """

    # Constant assigned to the fast-llm parameter on import.
    fast_llm_value: typing.Any = MISSING

    def __post_init__(self):
        # One fast-llm parameter, no export parameter.
        Assert.eq(len(self.fast_llm_names), 1)
        Assert.eq(len(self.export_names), 0)

    def export_params(self, fast_llm_values):
        """Check the fast-llm value against the constant and export nothing."""
        Assert.eq(fast_llm_values[0], self.fast_llm_value)
        return ()

    def import_params(self, export_values):
        """Return the constant as the single fast-llm value; `export_values` is unused."""
        return (self.fast_llm_value,)
48
74
49
75
50
@dataclasses.dataclass(kw_only=True)
class ConstantExportParamConverter(ParamConverter):
    """Converter for an export parameter that has no fast-llm counterpart.

    On export the parameter is written as `export_value`; on import the
    incoming value is checked against `export_value` and nothing is returned.
    """

    # Constant written to the export parameter.
    export_value: typing.Any = MISSING

    def __post_init__(self):
        # No fast-llm parameter, one export parameter.
        Assert.eq(len(self.fast_llm_names), 0)
        Assert.eq(len(self.export_names), 1)

    def export_params(self, fast_llm_values):
        """Return the constant as the single export value; `fast_llm_values` is unused."""
        return (self.export_value,)

    def import_params(self, export_values):
        """Check the imported value against the constant and import nothing."""
        Assert.eq(export_values[0], self.export_value)
        return ()
59
90
60
91
61
@dataclasses.dataclass(kw_only=True)
class IgnoreImportParamConverter(ParamConverter):
    """Converter for an export parameter that is dropped on import.

    Nothing is imported into fast-llm. A warning is logged when the incoming
    value is neither `ignore_export_value` nor `MISSING`.
    """

    # Export value that can be dropped without a warning.
    ignore_export_value: typing.Any = MISSING

    def __post_init__(self):
        # No fast-llm parameter, one export parameter.
        Assert.eq(len(self.fast_llm_names), 0)
        Assert.eq(len(self.export_names), 1)

    def export_params(self, fast_llm_values):
        """Return `MISSING` for the single export parameter."""
        return (MISSING,)

    def import_params(self, export_values):
        """Warn if a non-ignorable value is being dropped, then import nothing."""
        if export_values[0] not in (self.ignore_export_value, MISSING):
            logger.warning(
                f"The configuration parameter `{self.export_names[0]}={export_values[0]}` is ignored during conversion."
                f" If you intend to use it in Fast-LLM, make sure to set it explicitly in the model configuration."
            )
        return ()
74
110
75
111
76
@dataclasses.dataclass(kw_only=True)
class MappedConfigParamConverter(ParamConverter):
    """One-to-one converter that transforms the value in each direction.

    Both mapping callables default to the identity.
    """

    # Applied to the export value to produce the fast-llm value (import direction).
    fast_llm_value: typing.Callable = lambda x: x
    # Applied to the fast-llm value to produce the export value (export direction).
    export_value: typing.Callable = lambda x: x

    def __post_init__(self):
        # Exactly one parameter on each side.
        Assert.eq(len(self.fast_llm_names), 1)
        Assert.eq(len(self.export_names), 1)

    def export_params(self, fast_llm_values):
        """Map the single fast-llm value through `export_value`."""
        return (self.export_value(fast_llm_values[0]),)

    def import_params(self, export_values):
        """Map the single export value through `fast_llm_value`."""
        return (self.fast_llm_value(export_values[0]),)
86
126
87
127
88
128
class WeightConverter :
@@ -197,13 +237,18 @@ def _export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing
197
237
# TODO v0.3: not used in this class
198
238
exported_config = {}
199
239
for converter in cls ._get_config_converters ():
200
- value = converter .export_param (
201
- None
202
- if converter .fast_llm_name is None
203
- else cls ._get_fast_llm_attribute (config , converter .fast_llm_name ) # Noqa
204
- )
205
- if converter .export_name is not None :
206
- set_nested_dict_value (exported_config , converter .export_name , value )
240
+ try :
241
+ values = converter .export_params (
242
+ tuple (
243
+ cls ._get_fast_llm_attribute (config , fast_llm_name )
244
+ for fast_llm_name in converter .fast_llm_names
245
+ )
246
+ )
247
+ for export_name , value in zip (converter .export_names , values , strict = True ):
248
+ if value is not MISSING :
249
+ set_nested_dict_value (exported_config , export_name , value )
250
+ except Exception as e :
251
+ raise RuntimeError (f"Config conversion failed for converter { converter } " , * e .args )
207
252
208
253
return exported_config # Noqa
209
254
@@ -214,12 +259,25 @@ def _import_config(
214
259
kwargs = {}
215
260
for converter in cls ._get_config_converters ():
216
261
try :
217
- value = None if converter .export_name is None else get_nested_dict_value (config , converter .export_name )
218
- except KeyError :
219
- value = None
220
- value = converter .import_param (value )
221
- if converter .fast_llm_name is not None :
222
- kwargs [converter .fast_llm_name ] = value
262
+ values = ()
263
+ for export_name in converter .export_names :
264
+ try :
265
+ value = get_nested_dict_value (config , export_name )
266
+ except KeyError :
267
+ value = MISSING
268
+ values = values + (value ,)
269
+ values = converter .import_params (values )
270
+ for fast_llm_name , value in zip (converter .fast_llm_names , values , strict = True ):
271
+ if value is MISSING :
272
+ # Missing values need to be handled in dedicated converters,
273
+ # because implicit / default values may not match.
274
+ # TODO: Different behavior from other uses of MISSING. Use different tag?
275
+ raise ValueError (f"Missing converted value for fast-llm parameter { fast_llm_name } " )
276
+ if fast_llm_name in kwargs :
277
+ raise ValueError (f"Duplicate converted value for fast-llm parameter { fast_llm_name } " )
278
+ kwargs [fast_llm_name ] = value
279
+ except Exception as e :
280
+ raise RuntimeError (f"Config conversion failed for converter { converter } " , * e .args )
223
281
224
282
config_class = cls ._model_class .get_base_model_config_class ()
225
283
if architecture_only :
@@ -335,7 +393,11 @@ def _get_key(cls, parameter_name: str, shard_name: str) -> str:
335
393
@classmethod
@abc.abstractmethod
def _create_config_converters(cls) -> list[ParamConverter]:
    """Return the base list of config converters.

    The list holds a single constant-export converter that writes
    `cls.get_huggingface_model_type()` to the `model_type` export key.

    NOTE(review): the method is abstract yet has a body; presumably subclasses
    extend the result of `super()._create_config_converters()` -- confirm.
    """
    return [
        ConstantExportParamConverter(
            export_names=(("model_type",),), export_value=cls.get_huggingface_model_type()
        )
    ]
339
401
340
402
@classmethod
341
403
def _load_config (cls , directory : pathlib .Path | str ) -> dict :
0 commit comments