From d42f29be274471ffc361b5ba608f891fc4d743a5 Mon Sep 17 00:00:00 2001 From: ATATC Date: Fri, 5 Dec 2025 22:16:29 -0500 Subject: [PATCH 1/8] Added `recoverable` flag to `Trainer` for recovery control. (#106) --- mipcandy/training.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mipcandy/training.py b/mipcandy/training.py index 2a62243..87a09c3 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -60,7 +60,7 @@ class TrainerTracker(object): class Trainer(WithPaddingModule, WithNetwork, metaclass=ABCMeta): def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], - validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], *, + validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], *, recoverable: bool = True, device: torch.device | str = "cpu", console: Console = Console()) -> None: WithPaddingModule.__init__(self, device) WithNetwork.__init__(self, device) @@ -69,6 +69,7 @@ def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[t self._experiment_id: str = "tbd" self._dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = dataloader self._validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = validation_dataloader + self._recoverable: bool = recoverable self._console: Console = console self._metrics: dict[str, list[float]] = {} self._epoch_metrics: dict[str, list[float]] = {} @@ -80,6 +81,8 @@ def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[t def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: TrainerTracker, **training_arguments) -> None: + if not self._recoverable: + return torch.save(toolbox.optimizer, f"{self.experiment_folder()}/optimizer.pt") torch.save(toolbox.scheduler, f"{self.experiment_folder()}/scheduler.pt") torch.save(toolbox.criterion, f"{self.experiment_folder()}/criterion.pt") @@ -102,9 +105,11 @@ def recover_from(self, experiment_id: 
str) -> Self: self._experiment_id = experiment_id self._metrics = self.load_metrics() self._tracker = self.load_tracker() + self._recoverable = True return self def continue_training(self, num_epochs: int) -> None: + self._recoverable = True self.train(num_epochs, **self.load_training_arguments()) # Getters From 453c3bae8af4b212d48985a83fdf26133036f948 Mon Sep 17 00:00:00 2001 From: ATATC Date: Fri, 5 Dec 2025 22:43:53 -0500 Subject: [PATCH 2/8] Renamed `recoverable` flag to `unrecoverable` in `Trainer` for `None` value logic and added `load_toolbox()`. (#106) --- mipcandy/training.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/mipcandy/training.py b/mipcandy/training.py index 87a09c3..d4dad42 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -69,7 +69,7 @@ def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[t self._experiment_id: str = "tbd" self._dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = dataloader self._validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = validation_dataloader - self._recoverable: bool = recoverable + self._unrecoverable: bool = not recoverable # None if the trainer is recovered self._console: Console = console self._metrics: dict[str, list[float]] = {} self._epoch_metrics: dict[str, list[float]] = {} @@ -81,7 +81,7 @@ def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: TrainerTracker, **training_arguments) -> None: - if not self._recoverable: + if self._unrecoverable: return torch.save(toolbox.optimizer, f"{self.experiment_folder()}/optimizer.pt") torch.save(toolbox.scheduler, f"{self.experiment_folder()}/scheduler.pt") @@ -101,15 +101,25 @@ def load_metrics(self) -> dict[str, list[float]]: df = read_csv(f"{self.experiment_folder()}/metrics.csv", index_col="epoch") return {column: df[column].astype(float).tolist() for column
in df.columns} + def load_toolbox(self, example_shape: tuple[int, ...]) -> TrainerToolbox: + return TrainerToolbox( + self.build_network_from_checkpoint( + example_shape, torch.load(f"{self.experiment_folder()}/checkpoint_latest.pth") + ), + torch.load(f"{self.experiment_folder()}/optimizer.pt", weights_only=False), + torch.load(f"{self.experiment_folder()}/scheduler.pt", weights_only=False), + torch.load(f"{self.experiment_folder()}/criterion.pt", weights_only=False) + ) + def recover_from(self, experiment_id: str) -> Self: self._experiment_id = experiment_id self._metrics = self.load_metrics() self._tracker = self.load_tracker() - self._recoverable = True + self._unrecoverable = None return self def continue_training(self, num_epochs: int) -> None: - self._recoverable = True + self._unrecoverable = None self.train(num_epochs, **self.load_training_arguments()) # Getters @@ -370,7 +380,8 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, em example_input = padding_module(example_input) example_shape = tuple(example_input.shape[1:]) self.log(f"Example input shape: {example_shape}") - toolbox = self.build_toolbox(num_epochs, example_shape) + toolbox = self.load_toolbox(example_shape) if self._unrecoverable is None else self.build_toolbox(num_epochs, + example_shape) model_name = toolbox.model.__class__.__name__ sanity_check_result = sanity_check(toolbox.model, example_shape, device=self._device) self.log(f"Model: {model_name}") From a98ff0fb0dc20986c86bc62b1b91c50749758fbe Mon Sep 17 00:00:00 2001 From: ATATC Date: Fri, 5 Dec 2025 23:33:08 -0500 Subject: [PATCH 3/8] Added `recovery` to `Trainer` for recovery state checks and updated related logic. 
(#106) --- mipcandy/training.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/mipcandy/training.py b/mipcandy/training.py index d4dad42..01ac3f3 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -69,7 +69,7 @@ def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[t self._experiment_id: str = "tbd" self._dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = dataloader self._validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = validation_dataloader - self._unrecoverable: bool = not recoverable # None if the trainer is recovered + self._unrecoverable: bool | None = not recoverable # None if the trainer is recovered self._console: Console = console self._metrics: dict[str, list[float]] = {} self._epoch_metrics: dict[str, list[float]] = {} @@ -119,7 +119,8 @@ def recover_from(self, experiment_id: str) -> Self: return self def continue_training(self, num_epochs: int) -> None: - self._unrecoverable = None + if not self.recovery(): + raise RuntimeError("Must call `recover_from()` before continuing training") self.train(num_epochs, **self.load_training_arguments()) # Getters @@ -156,6 +157,9 @@ def tracker(self) -> TrainerTracker: def initialized(self) -> bool: return self._experiment_id != "tbd" + def recovery(self) -> bool: + return self._unrecoverable is None + def experiment_folder(self) -> str: return f"{self._trainer_folder}/{self._trainer_variant}/{self._experiment_id}" @@ -204,6 +208,9 @@ def allocate_experiment_folder(self) -> str: return self.experiment_folder() if self.initialized() else self._allocate_experiment_folder() def init_experiment(self) -> None: + if self.recovery(): + self.log(f"Training progress recovered from {self._experiment_id} from epoch {self._tracker.epoch}") + return if self.initialized(): raise RuntimeError("Experiment already initialized") makedirs(self._trainer_folder, exist_ok=True) @@ -380,8 +387,7 @@ def train(self, num_epochs: int, *, 
note: str = "", num_checkpoints: int = 5, em example_input = padding_module(example_input) example_shape = tuple(example_input.shape[1:]) self.log(f"Example input shape: {example_shape}") - toolbox = self.load_toolbox(example_shape) if self._unrecoverable is None else self.build_toolbox(num_epochs, - example_shape) + toolbox = self.load_toolbox(example_shape) if self.recovery() else self.build_toolbox(num_epochs, example_shape) model_name = toolbox.model.__class__.__name__ sanity_check_result = sanity_check(toolbox.model, example_shape, device=self._device) self.log(f"Model: {model_name}") From 1e00030fdf72416683c15018cfdce9e2a1d727af Mon Sep 17 00:00:00 2001 From: ATATC Date: Sat, 6 Dec 2025 13:02:59 -0500 Subject: [PATCH 4/8] Refactored `build_network_from_checkpoint()` usage; switched to `load_model()`. (#106) --- mipcandy/layer.py | 3 +++ mipcandy/training.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mipcandy/layer.py b/mipcandy/layer.py index 6393ea8..dc6a6e3 100644 --- a/mipcandy/layer.py +++ b/mipcandy/layer.py @@ -105,6 +105,9 @@ def build_network(self, example_shape: tuple[int, ...]) -> nn.Module: raise NotImplementedError def build_network_from_checkpoint(self, example_shape: tuple[int, ...], checkpoint: Mapping[str, Any]) -> nn.Module: + """ + Internally exposed interface for overriding. Use `load_model()` instead. 
+ """ network = self.build_network(example_shape) network.load_state_dict(checkpoint) return network diff --git a/mipcandy/training.py b/mipcandy/training.py index 01ac3f3..79794c1 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -103,8 +103,8 @@ def load_metrics(self) -> dict[str, list[float]]: def load_toolbox(self, example_shape: tuple[int, ...]) -> TrainerToolbox: return TrainerToolbox( - self.build_network_from_checkpoint( - example_shape, torch.load(f"{self.experiment_folder()}/checkpoint_latest.pth") + self.load_model( + example_shape, checkpoint=torch.load(f"{self.experiment_folder()}/checkpoint_latest.pth") ), torch.load(f"{self.experiment_folder()}/optimizer.pt", weights_only=False), torch.load(f"{self.experiment_folder()}/scheduler.pt", weights_only=False), From 42ef8e65ad327a39e0740c9f733c8e7a2fc0abd8 Mon Sep 17 00:00:00 2001 From: ATATC Date: Sat, 6 Dec 2025 13:26:13 -0500 Subject: [PATCH 5/8] Switched `torch.save` and `torch.load` to use `state_dict()` for optimizer, scheduler, and criterion in `Trainer`. Updated `load_toolbox()` logic to align with this change. 
(#106) --- mipcandy/training.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/mipcandy/training.py b/mipcandy/training.py index 79794c1..4955dfd 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -83,9 +83,9 @@ def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: Trainer **training_arguments) -> None: if self._unrecoverable: return - torch.save(toolbox.optimizer, f"{self.experiment_folder()}/optimizer.pt") - torch.save(toolbox.scheduler, f"{self.experiment_folder()}/scheduler.pt") - torch.save(toolbox.criterion, f"{self.experiment_folder()}/criterion.pt") + torch.save(toolbox.optimizer.state_dict(), f"{self.experiment_folder()}/optimizer.pth") + torch.save(toolbox.scheduler.state_dict(), f"{self.experiment_folder()}/scheduler.pth") + torch.save(toolbox.criterion.state_dict(), f"{self.experiment_folder()}/criterion.pth") torch.save(tracker, f"{self.experiment_folder()}/tracker.pt") with open(f"{self.experiment_folder()}/training_arguments.json", "w") as f: dump(training_arguments, f) @@ -101,15 +101,15 @@ def load_metrics(self) -> dict[str, list[float]]: df = read_csv(f"{self.experiment_folder()}/metrics.csv", index_col="epoch") return {column: df[column].astype(float).tolist() for column in df.columns} - def load_toolbox(self, example_shape: tuple[int, ...]) -> TrainerToolbox: - return TrainerToolbox( - self.load_model( - example_shape, checkpoint=torch.load(f"{self.experiment_folder()}/checkpoint_latest.pth") - ), - torch.load(f"{self.experiment_folder()}/optimizer.pt", weights_only=False), - torch.load(f"{self.experiment_folder()}/scheduler.pt", weights_only=False), - torch.load(f"{self.experiment_folder()}/criterion.pt", weights_only=False) + def load_toolbox(self, num_epochs: int, example_shape: tuple[int, ...]) -> TrainerToolbox: + toolbox = self.build_toolbox(num_epochs, example_shape) + toolbox.model = self.load_model( + example_shape, 
checkpoint=torch.load(f"{self.experiment_folder()}/checkpoint_latest.pth") ) + toolbox.optimizer.load_state_dict(torch.load(f"{self.experiment_folder()}/optimizer.pth")) + toolbox.scheduler.load_state_dict(torch.load(f"{self.experiment_folder()}/scheduler.pth")) + toolbox.criterion.load_state_dict(torch.load(f"{self.experiment_folder()}/criterion.pth")) + return toolbox def recover_from(self, experiment_id: str) -> Self: self._experiment_id = experiment_id From ebaeccc75ebaa741513da4b5a4e701553591b0a6 Mon Sep 17 00:00:00 2001 From: ATATC Date: Sat, 6 Dec 2025 13:31:06 -0500 Subject: [PATCH 6/8] Raise `FileNotFoundError` in `recover_from()` if experiment folder is missing in `Trainer`. (#106) --- mipcandy/training.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mipcandy/training.py b/mipcandy/training.py index 4955dfd..8a6294b 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -113,6 +113,8 @@ def load_toolbox(self, num_epochs: int, example_shape: tuple[int, ...]) -> Train def recover_from(self, experiment_id: str) -> Self: self._experiment_id = experiment_id + if not exists(self.experiment_folder()): + raise FileNotFoundError(f"Experiment folder {self.experiment_folder()} not found") self._metrics = self.load_metrics() self._tracker = self.load_tracker() self._unrecoverable = None From e006f0ae33c92ff6dc8f3fae9517f9ea611e1300 Mon Sep 17 00:00:00 2001 From: ATATC Date: Sat, 6 Dec 2025 13:32:39 -0500 Subject: [PATCH 7/8] Updated `load_toolbox()` call to include `num_epochs` during recovery checks in `Trainer`. 
(#106) --- mipcandy/training.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mipcandy/training.py b/mipcandy/training.py index 8a6294b..59bafc9 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -389,7 +389,8 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, em example_input = padding_module(example_input) example_shape = tuple(example_input.shape[1:]) self.log(f"Example input shape: {example_shape}") - toolbox = self.load_toolbox(example_shape) if self.recovery() else self.build_toolbox(num_epochs, example_shape) + toolbox = self.load_toolbox(num_epochs, example_shape) if self.recovery() else self.build_toolbox( + num_epochs, example_shape) model_name = toolbox.model.__class__.__name__ sanity_check_result = sanity_check(toolbox.model, example_shape, device=self._device) self.log(f"Model: {model_name}") From 4f025a6adae9024b29ae9b9a516abf7743a0e4c6 Mon Sep 17 00:00:00 2001 From: ATATC Date: Sat, 6 Dec 2025 14:16:49 -0500 Subject: [PATCH 8/8] Refactored `build_toolbox()` by delegating implementation to the new `_build_toolbox()` helper, improving reusability and modularity in `Trainer`. 
(#106) --- mipcandy/training.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mipcandy/training.py b/mipcandy/training.py index 59bafc9..fecdfc2 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -102,10 +102,9 @@ def load_metrics(self) -> dict[str, list[float]]: return {column: df[column].astype(float).tolist() for column in df.columns} def load_toolbox(self, num_epochs: int, example_shape: tuple[int, ...]) -> TrainerToolbox: - toolbox = self.build_toolbox(num_epochs, example_shape) - toolbox.model = self.load_model( + toolbox = self._build_toolbox(num_epochs, example_shape, model=self.load_model( example_shape, checkpoint=torch.load(f"{self.experiment_folder()}/checkpoint_latest.pth") - ) + )) toolbox.optimizer.load_state_dict(torch.load(f"{self.experiment_folder()}/optimizer.pth")) toolbox.scheduler.load_state_dict(torch.load(f"{self.experiment_folder()}/scheduler.pth")) toolbox.criterion.load_state_dict(torch.load(f"{self.experiment_folder()}/criterion.pth")) @@ -331,13 +330,18 @@ def build_scheduler(self, optimizer: optim.Optimizer, num_epochs: int) -> optim. def build_criterion(self) -> nn.Module: raise NotImplementedError - def build_toolbox(self, num_epochs: int, example_shape: tuple[int, ...]) -> TrainerToolbox: - model = self.load_model(example_shape) + def _build_toolbox(self, num_epochs: int, example_shape: tuple[int, ...], *, + model: nn.Module | None = None) -> TrainerToolbox: + if not model: + model = self.load_model(example_shape) optimizer = self.build_optimizer(model.parameters()) scheduler = self.build_scheduler(optimizer, num_epochs) criterion = self.build_criterion().to(self._device) return TrainerToolbox(model, optimizer, scheduler, criterion) + def build_toolbox(self, num_epochs: int, example_shape: tuple[int, ...]) -> TrainerToolbox: + return self._build_toolbox(num_epochs, example_shape) + # Training methods @abstractmethod