Skip to content

Infrastructure for saving/loading hls4ml models #1158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

vloncar
Copy link
Contributor

@vloncar vloncar commented Dec 22, 2024

Description

Adds the ability to save and load hls4ml models (serialize/deserialize them). Given a ModelGraph, this will serialize it in a a single file which can be loaded at a later stage. The saved model doesn't depend on the original Keras/PyTorch/ONNX model in any way.

The feature is in part inspired by Keras' model saving feature. The main format used for serialization is JSON, all objects save their state in dictionaries which are serialized into JSON. Assuming disk space is not a problem, generated JSON is nicely formatted during writing to file. No objects are pickled, as this is way too unsafe. The numpy arrays (weights) are saved in npz format. We save model graph (list of layers), the model information and config into separate files. This (along with some versioning information is packaged into a .fml file, which is just a .tar.gz with a different name.

Internally, this works by adding a few methods to types, quantizers, layers and model graph itself. The interface is defined by the Serializable class. Classes would typically implement serialize_state() method, which should return a dictionary of current state of the object. Additionally, there's also a serialize_class_name() which is needed to know what instance are we saving, but most classes won't need to deal with this. Deserialization is done with a class method deserialize(). To support this feature some restructuring had to be done. ModelGraph has been intended to be created only with a layer list from a converter, which is not compatible with (de)serialization, so it was split into initialization of empty ModelGraph and conversion of layer list from converters to Layer objects. Furthermore, Layer's initialization has to be skipped, as we're basically restoring a state post-initialization. Types and quantizers are more straightforward to save/load. Loaded model should be indistinguishable from the original, but there may be some corner cases of some hacks of internal state of layers (or partially optimized models) not working on loaded models, we can catch these over time. But for "final" models (one you're happy enough with to call write()/compile()/build() on) saving/loading should always work.

One somewhat ugly part in the current implementation is that due to the creation of dynamic wrapper classes, we cannot directly deserialize to them, instead we create the original types, and have to run <backend>:specific_types optimizer to truly get an object that is identical to the original one. Running that optimizer for a given backend looks a bit hacky, but is ok for now since all backends have an optimizer by that name.

Type of change

  • New feature (non-breaking change which adds functionality)

Tests

Included is a test in test_serialization.py that tests saving/loading QKeras and QONNX models. These cover serialization of most types and quantizers that can appear in a model, but obviously not all possible layers. Maybe a more thorough test would be to extend most existing tests to save and load a model and then continue working with a loaded model. But I'll leave that to a future PR.

Checklist

I've done all the usual checks prior to opening this PR.

@vloncar vloncar added the please test Trigger testing by creating local PR branch label Dec 22, 2024
@bo3z bo3z added this to the v1.1.0 milestone Jan 7, 2025
@JanFSchulte JanFSchulte added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Jan 7, 2025
@jmitrevs
Copy link
Contributor

Even though the running time exceeded the limits, there were failures in test_serialization beforehand.

@JanFSchulte
Copy link
Contributor

Even though the running time exceeded the limits, there were failures in test_serialization beforehand.

Indeed. I am just rerunning the tests to see which part is taking so long, and the QONNX test runs really fast, so it is the serialization test itself that is very slow.

@jmitrevs
Copy link
Contributor

Running locally on my linux machine the serialization tests are pretty quick (1-2min), but I get:

FAILED test_serialization.py::test_qkeras_model[io_stream-oneAPI] - subprocess.CalledProcessError: Command 'make lib' returned non-zero exit status 2.
FAILED test_serialization.py::test_qonnx_model[oneAPI] - subprocess.CalledProcessError: Command 'make lib' returned non-zero exit status 2.
FAILED test_serialization.py::test_qkeras_model[io_parallel-oneAPI] - subprocess.CalledProcessError: Command 'make lib' returned non-zero exit status 2.

The exact failure is:

icpx: error: fpga compiler command failed with exit code 14 (use -v to see invocation)

I will investigate. One thing that is kind of annoying is that I get more failures on my mac since ap_math doesn't really support clang:

firmware/ac_math/include/ac_math/ac_pow_pwl.h:300:70: error: typedef 'pit_t' cannot be referenced with a class specifier
  300 |     typedef class comp_pii_exp<W, I, S, n_frac_bits + extra_f_bits>::pit_t input_inter_type;
      |                                                                      ^

But that's unrelated to this PR.

@vloncar vloncar added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Apr 1, 2025
@calad0i calad0i mentioned this pull request Apr 5, 2025
7 tasks
Copy link
Contributor

@bo3z bo3z left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First round of reviews of the saving/loading infrastructure of hls4ml models. Generally, looks very good - most of the comments are minor with some questions for my better understanding. Two high-level questions / comments, for here as well:

  1. Would it be better to name the methods save_model and load_model? This is closer to what Keras does and maybe a bit clearer to most users?
  2. At what stage can a model be saved? After conversion but before calling .compile(...) or also after compilation? Most of the examples/tests I've seen in this PR are before compilation. Generally, it would be useful to be able to save/load the model after compilation (e.g. I converted the model on my local machine, compiled it to verify accuracy with predict(...) and then saved it, to synthesize it on a remote node). This is trickier though and can open up non-clear issues....e.g, if the CPU ISA is different between nodes the comiled hls4ml library is useless. But also useful, since compilation can sometimes take long (especially with lots of code generation).

np.testing.assert_equal(y_original, y_clone)


@pytest.mark.parametrize('backend', ['Vitis']) # Disabling OneAPI for now excessive run time
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OneAPI is enabled in the above QKeras test. Is it faster with QKeras than with QKeras or should the above one be also commented out?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually due to a bug in oneAPI that causes very long running time at 100% CPU usage due to context switches between threads. Reported to Intel.

test_root_path = Path(__file__).parent
example_model_path = (test_root_path / '../../example-models').resolve()


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we also have a PyTorch test for serialization?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you wish, but will be the same result. Ultimately we're serializing the ModelGraph, where it originated from should not matter much. I added QONNX just to be sure more advanced models are working

config.pop('OnnxModel', None)
config.pop('PytorchModel', None)

# Much of this may not be needed and is already in 'config' dict but is kept here to be sure
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems a bit error-prone, because if there's a new attribute added to HLS config and it's not reflected in these two functions, it could fail later on. I guess there are three ways to handle this:

  • As the comment suggests, assume most (all?) of the variables are reflected in config, so the others are simply fail-safe mechanisms... do we know which of this variables aren't directly stored in config?
  • Find a way to implicitly iterate through all the internal member variables of HLSConfig and store them to state. This seems like something that's doable in Python
  • Ignore for now, as it doesn't seem the HLS config will change a lot in the (near) future.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another one of those cases of "if only we had better config infrastructure" 😄

The problem we have is that there's the initial config dict, and the current config class. Then there may be code that uses one or the other, so we need to save both. Your second solution is error prone too, as it would assume we know how to handle all possible values of internal members so we can save them.

In another project we're investigating Pydantic for for config schemas, maybe we can apply that to hls4ml in the future. For now I feel we could maybe make it clear to future developers that if they add new internal members they need to also follow up with serialization functions.

self._applied_flows = []
@classmethod
def from_layer_list(cls, config_dict, layer_list, inputs=None, outputs=None):
def _find_output_variable_names(layer_list, layer_names):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a personal preference - but I am not the biggest fan of functions inside of functions unless absolutely necessary. Is there a reason this function got moved?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No other function uses it and it doesn't depend on self. If it was separate it would also have to be a classmethod, but it felt weird to have a class method intended only for internal use. I actually considered moving it to a separate utility file, but that also felt like too much given that it really is used only once.

@classmethod
def deserialize(cls, state):
raise Exception(
'{cls.__name__} is not intended to be deserialized directly. Use {cls.__name__}.from_saved_state instead.'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this missing a string formatter? f'...{}...'

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed it does

@@ -76,6 +90,10 @@ def __call__(self, data):
ones = np.ones_like(data)
return np.where(data > 0.5, ones, np.where(data <= -0.5, -ones, zeros))

def serialize_state(self):
state = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor, but not sure if I am missing something obvious here: the BinaryQuantizer has a state but the TernaryQuantizer doesn't? In my understanding of the serialization, neither should have the state? Or if yes, both?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a weird one... The "state" here is simply the choice of the number of bits used (1 or 2). For ternary we will always use 2 bits so there's no state to save. State is something you will need to pass to __init__(...) to recreate the object.

@@ -588,5 +731,16 @@ def __init__(self, code):
def __str__(self):
return str(self.code)

def serialize_class_name(self):
cls = self.__class__
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this function re-implemented here? To me it seems the same as the one in the Serializable class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's different. The wrangling of the name is needed because types are created dynamically (belonging to a backend). Source isn't one of them

from .._version import version


def serialize_model(model, file_path):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add a short docstring to this function? Could be a copy of the docs, but still nice to have since it's a user-facing function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure I'll do that

@bo3z bo3z added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Apr 7, 2025
@bo3z bo3z modified the milestones: v1.1.0, v1.2.0 Apr 7, 2025
@vloncar
Copy link
Contributor Author

vloncar commented Apr 8, 2025

For the general comments:

  1. Can be done. I actually tried not to copy Keras naming convention, but if that's what people prefer, I'm fine with renaming the function
  2. You can call it afterwards. We save the IR, not what the IR ultimately generates (writes as a project). But we ensure that what is written is the same whether the model was serialized or not. write() doesn't change the internal state, so you're safe to call it many times at whichever point you want. compile() will call write(), then compile to create the .so and link it. This .so and the subsequent link to python runtime cannot be saved as they are not portable at all. Even on the same CPU type, but different OS family. For example on a desktop/laptop people would use Ubuntu or something else modern, and then move it to server for synthesis which is likely to run RHEL or clone. The compiled .so won't work and will need recompilation. Luckily you can be sure that when you call write(), compile() or build() you that first stage of writing the project to disk will result in the the same thing as the original model.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
please test Trigger testing by creating local PR branch
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants