Skip to content

Commit 631332c

Browse files
authored
Merge pull request #1158 from vloncar/serialization
Infrastructure for saving/loading hls4ml models
2 parents 0c289ea + 2cd79b9 commit 631332c

File tree

18 files changed

+1045
-77
lines changed

18 files changed

+1045
-77
lines changed

docs/api/serialization.rst

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
============================
2+
Saving/Loading hls4ml models
3+
============================
4+
5+
``hls4ml`` model objects (instances of ``ModelGraph`` class) can be saved to disk and loaded at a later stage. The saved model doesn't require the original Keras/PyTorch/ONNX model for loading.
6+
7+
To save/load a model use the following API:
8+
9+
.. code-block:: python
10+
11+
from hls4ml.converters import convert_from_keras_model, load_saved_model
12+
13+
model = convert_from_keras_model(keras_model, ...)
14+
15+
# Save a model to some path
16+
model.save('some/path/my_hls4ml_model.fml')
17+
18+
# Load a model from a file
19+
loaded_model = load_saved_model('some/path/my_hls4ml_model.fml')
20+
21+
22+
The saved model will have a ``.fml`` extension, but is in fact a gzipped tar archive. The loaded model can be used in the same way as the original one. This includes modification of certain config parameters, for example the output directory, layer reuse factor, etc.
23+
24+
Linking with existing project
25+
=============================
26+
27+
Once the project has been written to disk with ``ModelGraph.write()``, it can also be linked with at a later stage. Similarly to loading a saved model, this feature allows skipping the conversion step. Additionally, it may be used to test manual changes to the generated project.
28+
29+
The linking function will create a special instance of ``ModelGraph`` that only allows calls to ``compile()``, ``predict()`` and ``build()``. Other calls to the ``ModelGraph`` instance are disabled.
30+
31+
To link a model use the following API:
32+
33+
.. code-block:: python
34+
35+
from hls4ml.converters import convert_from_keras_model, link_existing_project
36+
37+
model = convert_from_keras_model(keras_model, output_dir='/some/path/', ...)
38+
39+
# Generate the project files and write them to some path
40+
model.write()
41+
42+
# Later on, link this path to the Python runtime
43+
linked_model = link_existing_project('/some/path/')
44+
linked_model.compile()
45+
linked_model.predict(...)
46+
linked_model.build(...)

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
api/concepts
1818
api/configuration
1919
api/command
20+
api/serialization
2021

2122
.. toctree::
2223
:hidden:

hls4ml/backends/fpga/fpga_backend.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,12 @@ def create_layer_class(self, layer_class):
143143
if issubclass(layer_class, cls):
144144
new_attrubutes.extend(attributes)
145145

146+
layer_cls_fqn = layer_class.__module__ + '.' + layer_class.__qualname__
147+
146148
return type(
147-
self.name + layer_class.__name__, (layer_class,), {'_expected_attributes': new_attrubutes, '_wrapped': True}
149+
self.name + layer_class.__name__,
150+
(layer_class,),
151+
{'_expected_attributes': new_attrubutes, '_wrapped': layer_cls_fqn},
148152
)
149153

150154
def compile(self, model):

hls4ml/backends/fpga/fpga_types.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(self, type_map, prefix):
103103
def convert(self, precision_type):
104104
type_cls = type(precision_type)
105105
type_cls_name = type_cls.__name__
106+
type_cls_fqn = type_cls.__module__ + '.' + type_cls.__qualname__
106107

107108
# If the type is already converted, do nothing
108109
if type_cls_name.startswith(self.prefix):
@@ -111,7 +112,9 @@ def convert(self, precision_type):
111112
definition_cls = self.type_map.get(type_cls, None)
112113

113114
if definition_cls is not None:
114-
precision_type.__class__ = type(self.prefix + type_cls_name, (type_cls, definition_cls), {})
115+
precision_type.__class__ = type(
116+
self.prefix + type_cls_name, (type_cls, definition_cls), {'_wrapped': type_cls_fqn}
117+
)
115118
return precision_type
116119
else:
117120
raise Exception(f'Cannot convert precision type to {self.prefix}: {precision_type.__class__.__name__}')
@@ -206,6 +209,7 @@ def __init__(self, precision_converter):
206209
def convert(self, atype):
207210
type_cls = type(atype)
208211
type_cls_name = type_cls.__name__
212+
type_cls_fqn = type_cls.__module__ + '.' + type_cls.__qualname__
209213

210214
# If the type is already converted, do nothing
211215
if type_cls_name.startswith('HLS'):
@@ -214,7 +218,7 @@ def convert(self, atype):
214218
conversion_cls = self.type_map.get(type_cls, None)
215219

216220
if conversion_cls is not None:
217-
atype.__class__ = type('HLS' + type_cls_name, (type_cls, conversion_cls), {})
221+
atype.__class__ = type('HLS' + type_cls_name, (type_cls, conversion_cls), {'_wrapped': type_cls_fqn})
218222
atype.convert_precision(self.precision_converter)
219223
return atype
220224
else:
@@ -246,8 +250,11 @@ def convert(self, tensor_var, pragma='partition'):
246250

247251
tensor_var.pragma = pragma
248252
tensor_var.type = self.type_converter.convert(tensor_var.type)
253+
tensor_cls_fqn = tensor_var.__class__.__module__ + '.' + tensor_var.__class__.__qualname__
249254

250-
tensor_var.__class__ = type(self.prefix + 'ArrayVariable', (type(tensor_var), self.definition_cls), {})
255+
tensor_var.__class__ = type(
256+
self.prefix + 'ArrayVariable', (type(tensor_var), self.definition_cls), {'_wrapped': tensor_cls_fqn}
257+
)
251258
return tensor_var
252259

253260

@@ -273,8 +280,11 @@ def convert(self, tensor_var, pragma='partition', struct_name=None):
273280
tensor_var.struct_name = str(struct_name)
274281
tensor_var.member_name = tensor_var.name
275282
tensor_var.name = tensor_var.struct_name + '.' + tensor_var.member_name
283+
type_cls_fqn = tensor_var.__class__.__module__ + '.' + tensor_var.__class__.__qualname__
276284

277-
tensor_var.__class__ = type(self.prefix + 'StructMemberVariable', (type(tensor_var), self.definition_cls), {})
285+
tensor_var.__class__ = type(
286+
self.prefix + 'StructMemberVariable', (type(tensor_var), self.definition_cls), {'_wrapped': type_cls_fqn}
287+
)
278288
return tensor_var
279289

280290

@@ -299,8 +309,11 @@ def convert(self, tensor_var, n_pack=1, depth=0):
299309
tensor_var.type = self.type_converter.convert(
300310
PackedType(tensor_var.type.name, tensor_var.type.precision, tensor_var.shape[-1], n_pack)
301311
)
312+
tensor_cls_fqn = tensor_var.__class__.__module__ + '.' + tensor_var.__class__.__qualname__
302313

303-
tensor_var.__class__ = type(self.prefix + 'StreamVariable', (type(tensor_var), self.definition_cls), {})
314+
tensor_var.__class__ = type(
315+
self.prefix + 'StreamVariable', (type(tensor_var), self.definition_cls), {'_wrapped': tensor_cls_fqn}
316+
)
304317
return tensor_var
305318

306319

@@ -318,8 +331,11 @@ def convert(self, tensor_var, n_pack=1, depth=0):
318331
tensor_var.type = self.type_converter.convert(
319332
PackedType(tensor_var.type.name, tensor_var.type.precision, tensor_var.input_var.shape[-1], n_pack)
320333
)
334+
tensor_cls_fqn = tensor_var.__class__.__module__ + '.' + tensor_var.__class__.__qualname__
321335

322-
tensor_var.__class__ = type(self.prefix + 'StreamVariable', (type(tensor_var), self.definition_cls), {})
336+
tensor_var.__class__ = type(
337+
self.prefix + 'StreamVariable', (type(tensor_var), self.definition_cls), {'_wrapped': tensor_cls_fqn}
338+
)
323339
return tensor_var
324340

325341

@@ -344,8 +360,11 @@ def convert(self, weight_var):
344360
weight_var.weight_class = weight_var.__class__.__name__
345361
weight_var.storage = 'register'
346362
weight_var.type = self.type_converter.convert(weight_var.type)
363+
tensor_cls_fqn = weight_var.__class__.__module__ + '.' + weight_var.__class__.__qualname__
347364

348-
weight_var.__class__ = type('StaticWeightVariable', (type(weight_var), StaticWeightVariableDefinition), {})
365+
weight_var.__class__ = type(
366+
'StaticWeightVariable', (type(weight_var), StaticWeightVariableDefinition), {'_wrapped': tensor_cls_fqn}
367+
)
349368
return weight_var
350369

351370

hls4ml/backends/oneapi/oneapi_types.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,11 @@ def convert(self, tensor_var, pragma='', depth=0, n_pack=1):
143143
# pipe_name and pipe_id are only used for io_stream and interface variables in io_parallel
144144
tensor_var.pipe_name = f'{convert_to_pascal_case(tensor_var.name)}Pipe'
145145
tensor_var.pipe_id = f'{convert_to_pascal_case(tensor_var.name)}PipeID'
146+
tensor_cls_fqn = tensor_var.__class__.__module__ + '.' + tensor_var.__class__.__qualname__
146147

147-
tensor_var.__class__ = type(self.prefix + 'AggregateArrayVariable', (type(tensor_var), self.definition_cls), {})
148+
tensor_var.__class__ = type(
149+
self.prefix + 'AggregateArrayVariable', (type(tensor_var), self.definition_cls), {'_wrapped': tensor_cls_fqn}
150+
)
148151
return tensor_var
149152

150153

@@ -255,9 +258,12 @@ def convert(self, weight_var):
255258
weight_var.type = self.type_converter.convert(
256259
PackedType(weight_var.name + '_t', weight_var.type.precision, weight_var.data_length, 1)
257260
)
261+
weight_cls_fqn = weight_var.__class__.__module__ + '.' + weight_var.__class__.__qualname__
258262

259263
weight_var.__class__ = type(
260-
'OneAPIStaticWeightVariable', (type(weight_var), OneAPIStaticWeightVariableDefinition), {}
264+
'OneAPIStaticWeightVariable',
265+
(type(weight_var), OneAPIStaticWeightVariableDefinition),
266+
{'_wrapped': weight_cls_fqn},
261267
)
262268
return weight_var
263269

hls4ml/converters/__init__.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from hls4ml.model import ModelGraph
2121
from hls4ml.utils.config import create_config
2222
from hls4ml.utils.dependency import requires
23+
from hls4ml.utils.link import FilesystemModelGraph
24+
from hls4ml.utils.serialization import deserialize_model
2325
from hls4ml.utils.symbolic_utils import LUTFunction
2426

2527
# ----------Layer handling register----------#
@@ -464,6 +466,42 @@ def convert_from_symbolic_expression(
464466

465467
config['HLSConfig'] = {'Model': {'Precision': precision, 'ReuseFactor': 1}}
466468

467-
hls_model = ModelGraph(config, layer_list)
469+
hls_model = ModelGraph.from_layer_list(config, layer_list)
468470

469471
return hls_model
472+
473+
474+
def link_existing_project(project_dir):
475+
"""Create a stripped-down ModelGraph from an existing project previously generated by hls4ml.
476+
477+
The returned ModelGraph will only allow compile(), predict() and build() functions to be invoked.
478+
479+
Args:
480+
project_dir (str): Path to the existing HLS project previously generated with hls4ml.
481+
482+
Returns:
483+
FilesystemModelGraph: hls4ml model.
484+
"""
485+
return FilesystemModelGraph(project_dir)
486+
487+
488+
def load_saved_model(file_path, output_dir=None):
489+
"""
490+
Loads an hls4ml model from a compressed file format (.fml).
491+
492+
See `hls4ml.utils.serialization.deserialize_model` for more details.
493+
494+
Args:
495+
file_path (str or pathlib.Path): The path to the serialized model file (.fml).
496+
output_dir (str or pathlib.Path, optional): The directory where extracted
497+
testbench data files will be saved. If not specified, the files will
498+
be restored to the same directory as the `.fml` file.
499+
500+
Returns:
501+
ModelGraph: The deserialized hls4ml model.
502+
503+
Raises:
504+
FileNotFoundError: If the specified `.fml` file does not exist.
505+
OSError: If an I/O error occurs during extraction or file operations.
506+
"""
507+
return deserialize_model(file_path, output_dir=output_dir)

hls4ml/converters/keras_to_hls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,5 +360,5 @@ def keras_to_hls(config):
360360
model_arch, reader = get_model_arch(config)
361361
layer_list, input_layers, output_layers, _ = parse_keras_model(model_arch, reader)
362362
print('Creating HLS model')
363-
hls_model = ModelGraph(config, layer_list, input_layers, output_layers)
363+
hls_model = ModelGraph.from_layer_list(config, layer_list, input_layers, output_layers)
364364
return hls_model

hls4ml/converters/onnx_to_hls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,5 +292,5 @@ def onnx_to_hls(config):
292292
#################
293293

294294
print('Creating HLS model')
295-
hls_model = ModelGraph(config, layer_list, input_layers, output_layers)
295+
hls_model = ModelGraph.from_layer_list(config, layer_list, input_layers, output_layers)
296296
return hls_model

hls4ml/converters/pytorch_to_hls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,5 +426,5 @@ def parse_pytorch_model(config, verbose=True):
426426
def pytorch_to_hls(config):
427427
layer_list, input_layers, output_layers = parse_pytorch_model(config)
428428
print('Creating HLS model')
429-
hls_model = ModelGraph(config, layer_list, inputs=input_layers, outputs=output_layers)
429+
hls_model = ModelGraph.from_layer_list(config, layer_list, inputs=input_layers, outputs=output_layers)
430430
return hls_model

0 commit comments

Comments
 (0)