diff --git a/psipy/core/io/saveable_sklearn.py b/psipy/core/io/saveable_sklearn.py index 1bde6b0..49dbbf1 100644 --- a/psipy/core/io/saveable_sklearn.py +++ b/psipy/core/io/saveable_sklearn.py @@ -13,6 +13,7 @@ SaveableSklearnMixin """ +import numpy as np from typing import Any, Dict @@ -53,7 +54,12 @@ def get_config(self) -> Dict[str, Any]: def _save(self, zipfile: MemoryZipFile) -> MemoryZipFile: zipfile.add("config.json", self.get_config()) - zipfile.add_mixed_dict("state", self.__getstate__()) + state_dict = self.__getstate__() + for key, arr in state_dict.items(): + if isinstance(arr, np.ndarray) and arr.dtype == object: + if all(isinstance(x, str) for x in arr.flat): + state_dict[key] = arr.astype(str) + zipfile.add_mixed_dict("state", state_dict) return zipfile @classmethod