Infrastructure for saving/loading hls4ml models #1158
base: main
Conversation
Even though the running time exceeded the limits, there were failures in test_serialization beforehand.
Indeed. I am just rerunning the tests to see which part is taking so long, and the QONNX test runs really fast, so it is the serialization test itself that is very slow.
Running locally on my Linux machine, the serialization tests are pretty quick (1-2 min), but I get:
The exact failure is:
I will investigate. One thing that is kind of annoying is that I get more failures on my Mac, since ap_math doesn't really support clang:
But that's unrelated to this PR.
First round of reviews of the saving/loading infrastructure of hls4ml models. Generally, it looks very good - most of the comments are minor, with some questions for my better understanding. Two high-level questions/comments, for here as well:
- Would it be better to name the methods save_model and load_model? This is closer to what Keras does and may be a bit clearer to most users.
- At what stage can a model be saved? After conversion but before calling .compile(...), or also after compilation? Most of the examples/tests I've seen in this PR save before compilation. Generally, it would be useful to be able to save/load the model after compilation (e.g. I convert the model on my local machine, compile it to verify accuracy with predict(...), then save it and synthesize it on a remote node). This is trickier, though, and can open up unclear issues, e.g. if the CPU ISA differs between nodes, the compiled hls4ml library is useless. But it would also be useful, since compilation can sometimes take a long time (especially with lots of code generation). A sketch of that workflow follows below.
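For concreteness, a minimal sketch of that save-after-compile workflow. The import location of the serialize/deserialize helpers is an assumption here, not something pinned down in this thread:

```python
import numpy as np
from tensorflow import keras
import hls4ml

# On the local machine: convert, compile and verify
keras_model = keras.Sequential([keras.layers.Input(shape=(8,)), keras.layers.Dense(4, activation='relu')])
config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name')
hls_model = hls4ml.converters.convert_from_keras_model(
    keras_model, hls_config=config, backend='Vitis', output_dir='prj_local'
)
hls_model.compile()
y_local = hls_model.predict(np.random.rand(16, 8).astype(np.float32))

# Save the (already compiled) model and copy model.fml to the remote node...
hls4ml.utils.serialization.serialize_model(hls_model, 'model.fml')    # module path assumed

# ...then, on the remote node, load it and run synthesis there
# loaded = hls4ml.utils.serialization.deserialize_model('model.fml')  # loader name assumed
# loaded.build(synth=True)
```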
np.testing.assert_equal(y_original, y_clone)
@pytest.mark.parametrize('backend', ['Vitis'])  # Disabling OneAPI for now, excessive run time
OneAPI is enabled in the QKeras test above. Is it faster with QKeras than with QONNX, or should the above one also be commented out?
This is actually due to a bug in oneAPI that causes very long run times at 100% CPU usage because of context switches between threads. Reported to Intel.
test_root_path = Path(__file__).parent
example_model_path = (test_root_path / '../../example-models').resolve()
Could we also have a PyTorch test for serialization?
If you wish, but it will be the same result. Ultimately we're serializing the ModelGraph; where it originated from should not matter much. I added QONNX just to be sure more advanced models work.
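For reference, such a PyTorch round-trip test would look roughly like the existing ones. The sketch below assumes the serialization helpers live under hls4ml.utils.serialization with a deserialize_model counterpart, and the converter/config argument details may differ by hls4ml version:

```python
import numpy as np
import torch.nn as nn
import hls4ml

def test_serialization_pytorch(tmp_path):
    # a tiny stand-in model; any PyTorch model supported by the converter would do
    model = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
    config = hls4ml.utils.config_from_pytorch_model(model, input_shape=(8,))  # shape convention may vary
    hls_model = hls4ml.converters.convert_from_pytorch_model(
        model, hls_config=config, output_dir=str(tmp_path / 'prj'), backend='Vitis'
    )
    hls_model.compile()

    # save and reload, then check the clone produces identical outputs
    file_path = tmp_path / 'model.fml'
    hls4ml.utils.serialization.serialize_model(hls_model, file_path)   # module path assumed
    clone = hls4ml.utils.serialization.deserialize_model(file_path)    # loader name assumed
    clone.compile()

    x = np.random.rand(16, 8).astype(np.float32)
    np.testing.assert_equal(hls_model.predict(x), clone.predict(x))
```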
config.pop('OnnxModel', None)
config.pop('PytorchModel', None)
# Much of this may not be needed and is already in 'config' dict but is kept here to be sure
This seems a bit error-prone, because if a new attribute is added to the HLS config and it's not reflected in these two functions, it could fail later on. I guess there are three ways to handle this:
- As the comment suggests, assume most (all?) of the variables are reflected in config, so the others are simply fail-safe mechanisms... do we know which of these variables aren't directly stored in config?
- Find a way to implicitly iterate through all the internal member variables of HLSConfig and store them to the state; this seems like something that's doable in Python (a rough sketch follows below).
- Ignore for now, as it doesn't seem the HLS config will change a lot in the (near) future.
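For concreteness, the second option could look roughly like this. This is a naive sketch; the names and the fallback behavior are illustrative and not taken from HLSConfig:

```python
import json

def naive_serialize_state(obj):
    """Collect the JSON-representable members of obj; fall back to repr() for the rest."""
    state = {}
    for name, value in vars(obj).items():
        try:
            json.dumps(value)           # keep only values JSON can represent
            state[name] = value
        except TypeError:
            state[name] = repr(value)   # lossy fallback; this is where it gets error-prone
    return state
```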
Another one of those cases of "if only we had better config infrastructure" 😄
The problem we have is that there's the initial config dict and the current config class. Then there may be code that uses one or the other, so we need to save both. Your second solution is error-prone too, as it would assume we know how to handle all possible values of internal members so we can save them.
In another project we're investigating Pydantic for config schemas; maybe we can apply that to hls4ml in the future. For now I feel we could make it clear to future developers that if they add new internal members they also need to follow up with the serialization functions.
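For context, the kind of thing a Pydantic schema would give us, with made-up field names (serialization and validation then come for free from the schema):

```python
from pydantic import BaseModel

class HLSConfigSchema(BaseModel):
    # field names are purely illustrative, not the actual HLSConfig members
    default_precision: str = 'fixed<16,6>'
    default_reuse_factor: int = 1
    strategy: str = 'Latency'

cfg = HLSConfigSchema(default_reuse_factor=4)
as_json = cfg.model_dump_json()                           # serialize
restored = HLSConfigSchema.model_validate_json(as_json)   # deserialize + validate
```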
self._applied_flows = []

@classmethod
def from_layer_list(cls, config_dict, layer_list, inputs=None, outputs=None):
    def _find_output_variable_names(layer_list, layer_names):
Just a personal preference - but I am not the biggest fan of functions inside of functions unless absolutely necessary. Is there a reason this function got moved?
No other function uses it and it doesn't depend on self. If it were separate it would also have to be a classmethod, but it felt weird to have a class method intended only for internal use. I actually considered moving it to a separate utility file, but that also felt like too much given that it really is used only once.
hls4ml/model/graph.py
@classmethod
def deserialize(cls, state):
    raise Exception(
        '{cls.__name__} is not intended to be deserialized directly. Use {cls.__name__}.from_saved_state instead.'
Is this missing a string formatter? f'...{}...'
Indeed it does.
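For reference, the corrected message with the f-string prefix:

```python
raise Exception(
    f'{cls.__name__} is not intended to be deserialized directly. '
    f'Use {cls.__name__}.from_saved_state instead.'
)
```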
@@ -76,6 +90,10 @@ def __call__(self, data):
        ones = np.ones_like(data)
        return np.where(data > 0.5, ones, np.where(data <= -0.5, -ones, zeros))

    def serialize_state(self):
        state = {}
Minor, but not sure if I am missing something obvious here: the BinaryQuantizer has a state but the TernaryQuantizer doesn't? In my understanding of the serialization, neither should have a state? Or if yes, both?
This is a weird one... The "state" here is simply the choice of the number of bits used (1 or 2). For ternary we will always use 2 bits, so there's no state to save. State is something you will need to pass to __init__(...) to recreate the object.
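In other words, simplified (the actual classes differ in details):

```python
class BinaryQuantizer:
    def __init__(self, bits=2):
        self.bits = bits            # caller-chosen (1 or 2), so it must be saved

    def serialize_state(self):
        return {'bits': self.bits}  # enough to call __init__(**state) on load


class TernaryQuantizer:
    def __init__(self):
        self.bits = 2               # always 2 bits, so there is no state to save
```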
@@ -588,5 +731,16 @@ def __init__(self, code):
    def __str__(self):
        return str(self.code)

    def serialize_class_name(self):
        cls = self.__class__
Why is this function re-implemented here? To me it seems the same as the one in the Serializable class?
It's different. The wrangling of the name is needed because types are created dynamically (belonging to a backend); Source isn't one of them.
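To illustrate the problem with dynamically created classes (all names below are made up for illustration, not the actual hls4ml types or implementation):

```python
class FixedPointType:
    def __init__(self, width=16, integer=6):
        self.width, self.integer = width, integer

# a backend generates a wrapper class at runtime
VitisFixedPointType = type('VitisFixedPointType', (FixedPointType,), {})

t = VitisFixedPointType()
# type(t).__name__ is 'VitisFixedPointType', which is not an importable symbol,
# so serialization has to record the first "real" base class instead
base = next(b for b in type(t).__mro__[1:] if b is not object)
print(f'{base.__module__}.{base.__name__}')   # -> __main__.FixedPointType
```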
from .._version import version

def serialize_model(model, file_path):
Could we add a short docstring to this function? Could be a copy of the docs, but still nice to have since it's a user-facing function.
Sure, I'll do that.
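Something along these lines, for example (wording is only a suggestion):

```python
def serialize_model(model, file_path):
    """Save an hls4ml model to a single archive file.

    Serializes the given ModelGraph (its layers, weights and configuration)
    into a .fml file that can later be loaded without the original
    Keras/PyTorch/ONNX model.

    Args:
        model (ModelGraph): Model to serialize.
        file_path (str or Path): Destination path of the .fml archive.
    """
    ...
```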
For the general comments:
Description
Adds the ability to save and load hls4ml models (serialize/deserialize them). Given a ModelGraph, this will serialize it into a single file which can be loaded at a later stage. The saved model doesn't depend on the original Keras/PyTorch/ONNX model in any way.

The feature is in part inspired by Keras' model saving feature. The main format used for serialization is JSON: all objects save their state in dictionaries, which are serialized into JSON. Assuming disk space is not a problem, the generated JSON is nicely formatted when writing to file. No objects are pickled, as this is way too unsafe. The numpy arrays (weights) are saved in npz format. We save the model graph (list of layers), the model information and the config into separate files. This (along with some versioning information) is packaged into a .fml file, which is just a .tar.gz with a different name.

Internally, this works by adding a few methods to types, quantizers, layers and the model graph itself. The interface is defined by the Serializable class. Classes typically implement the serialize_state() method, which should return a dictionary of the current state of the object. Additionally, there is a serialize_class_name(), which is needed to know which instance we are saving, but most classes won't need to deal with this. Deserialization is done with the class method deserialize().
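A simplified illustration of this contract (the classes below are stand-ins, not the actual implementation):

```python
class Serializable:
    def serialize_state(self):
        # return a JSON-friendly dict describing the object's current state
        raise NotImplementedError

    def serialize_class_name(self):
        cls = self.__class__
        return f'{cls.__module__}.{cls.__name__}'

    @classmethod
    def deserialize(cls, state):
        # recreate the object from the dict produced by serialize_state()
        return cls(**state)


class ExamplePrecisionType(Serializable):
    def __init__(self, width=16, integer=6, signed=True):
        self.width, self.integer, self.signed = width, integer, signed

    def serialize_state(self):
        return {'width': self.width, 'integer': self.integer, 'signed': self.signed}
```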
To support this feature, some restructuring had to be done. ModelGraph had been intended to be created only with a layer list from a converter, which is not compatible with (de)serialization, so it was split into the initialization of an empty ModelGraph and the conversion of the layer list from converters into Layer objects. Furthermore, Layer's initialization has to be skipped, as we're basically restoring a state post-initialization. Types and quantizers are more straightforward to save/load. A loaded model should be indistinguishable from the original, but there may be some corner cases of hacks to the internal state of layers (or partially optimized models) not working on loaded models; we can catch these over time. But for "final" models (ones you're happy enough with to call write()/compile()/build() on), saving/loading should always work.

One somewhat ugly part of the current implementation is that, due to the creation of dynamic wrapper classes, we cannot directly deserialize to them; instead we create the original types and have to run the <backend>:specific_types optimizer to truly get an object that is identical to the original one. Running that optimizer for a given backend looks a bit hacky, but is OK for now since all backends have an optimizer by that name.

Type of change
Tests
Included is a test in test_serialization.py that tests saving/loading QKeras and QONNX models. These cover serialization of most types and quantizers that can appear in a model, but obviously not all possible layers. A more thorough test would be to extend most existing tests to save and load the model and then continue working with the loaded model, but I'll leave that to a future PR.

Checklist
I've done all the usual checks prior to opening this PR.