diff --git a/src/nnssl/training/nnsslTrainer/masked_image_modeling/SparkTrainer.py b/src/nnssl/training/nnsslTrainer/masked_image_modeling/SparkTrainer.py index 07425368..eff431f3 100644 --- a/src/nnssl/training/nnsslTrainer/masked_image_modeling/SparkTrainer.py +++ b/src/nnssl/training/nnsslTrainer/masked_image_modeling/SparkTrainer.py @@ -127,6 +127,7 @@ def save_checkpoint(self, filename: str, live_upload: bool = False) -> None: "current_epoch": self.current_epoch + 1, "init_args": self.my_init_kwargs, "trainer_name": self.__class__.__name__, + "nnssl_adaptation_plan": self.adaptation_plan.serialize(), } checkpoint = self._convert_numpy(checkpoint) torch.save(checkpoint, filename)