diff --git a/mipcandy/layer.py b/mipcandy/layer.py index 6393ea8..dc6a6e3 100644 --- a/mipcandy/layer.py +++ b/mipcandy/layer.py @@ -105,6 +105,9 @@ def build_network(self, example_shape: tuple[int, ...]) -> nn.Module: raise NotImplementedError def build_network_from_checkpoint(self, example_shape: tuple[int, ...], checkpoint: Mapping[str, Any]) -> nn.Module: + """ + Internally exposed interface for overriding. Use `load_model()` instead. + """ network = self.build_network(example_shape) network.load_state_dict(checkpoint) return network diff --git a/mipcandy/training.py b/mipcandy/training.py index 2a62243..fecdfc2 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -60,7 +60,7 @@ class TrainerTracker(object): class Trainer(WithPaddingModule, WithNetwork, metaclass=ABCMeta): def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], - validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], *, + validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], *, recoverable: bool = True, device: torch.device | str = "cpu", console: Console = Console()) -> None: WithPaddingModule.__init__(self, device) WithNetwork.__init__(self, device) @@ -69,6 +69,7 @@ def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[t self._experiment_id: str = "tbd" self._dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = dataloader self._validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = validation_dataloader + self._unrecoverable: bool | None = not recoverable # None if the trainer is recovered self._console: Console = console self._metrics: dict[str, list[float]] = {} self._epoch_metrics: dict[str, list[float]] = {} @@ -80,9 +81,11 @@ def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[t def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: TrainerTracker, **training_arguments) -> None: - 
torch.save(toolbox.optimizer, f"{self.experiment_folder()}/optimizer.pt") - torch.save(toolbox.scheduler, f"{self.experiment_folder()}/scheduler.pt") - torch.save(toolbox.criterion, f"{self.experiment_folder()}/criterion.pt") + if self._unrecoverable: + return + torch.save(toolbox.optimizer.state_dict(), f"{self.experiment_folder()}/optimizer.pth") + torch.save(toolbox.scheduler.state_dict(), f"{self.experiment_folder()}/scheduler.pth") + torch.save(toolbox.criterion.state_dict(), f"{self.experiment_folder()}/criterion.pth") torch.save(tracker, f"{self.experiment_folder()}/tracker.pt") with open(f"{self.experiment_folder()}/training_arguments.json", "w") as f: dump(training_arguments, f) @@ -98,13 +101,27 @@ def load_metrics(self) -> dict[str, list[float]]: df = read_csv(f"{self.experiment_folder()}/metrics.csv", index_col="epoch") return {column: df[column].astype(float).tolist() for column in df.columns} + def load_toolbox(self, num_epochs: int, example_shape: tuple[int, ...]) -> TrainerToolbox: + toolbox = self._build_toolbox(num_epochs, example_shape, model=self.load_model( + example_shape, checkpoint=torch.load(f"{self.experiment_folder()}/checkpoint_latest.pth") + )) + toolbox.optimizer.load_state_dict(torch.load(f"{self.experiment_folder()}/optimizer.pth")) + toolbox.scheduler.load_state_dict(torch.load(f"{self.experiment_folder()}/scheduler.pth")) + toolbox.criterion.load_state_dict(torch.load(f"{self.experiment_folder()}/criterion.pth")) + return toolbox + def recover_from(self, experiment_id: str) -> Self: self._experiment_id = experiment_id + if not exists(self.experiment_folder()): + raise FileNotFoundError(f"Experiment folder {self.experiment_folder()} not found") self._metrics = self.load_metrics() self._tracker = self.load_tracker() + self._unrecoverable = None return self def continue_training(self, num_epochs: int) -> None: + if not self.recovery(): + raise RuntimeError("Must call `recover_from()` before continuing training") self.train(num_epochs, 
**self.load_training_arguments()) # Getters @@ -141,6 +158,9 @@ def tracker(self) -> TrainerTracker: def initialized(self) -> bool: return self._experiment_id != "tbd" + def recovery(self) -> bool: + return self._unrecoverable is None + def experiment_folder(self) -> str: return f"{self._trainer_folder}/{self._trainer_variant}/{self._experiment_id}" @@ -189,6 +209,9 @@ def allocate_experiment_folder(self) -> str: return self.experiment_folder() if self.initialized() else self._allocate_experiment_folder() def init_experiment(self) -> None: + if self.recovery(): + self.log(f"Training progress recovered from {self._experiment_id} from epoch {self._tracker.epoch}") + return if self.initialized(): raise RuntimeError("Experiment already initialized") makedirs(self._trainer_folder, exist_ok=True) @@ -307,13 +330,18 @@ def build_scheduler(self, optimizer: optim.Optimizer, num_epochs: int) -> optim. def build_criterion(self) -> nn.Module: raise NotImplementedError - def build_toolbox(self, num_epochs: int, example_shape: tuple[int, ...]) -> TrainerToolbox: - model = self.load_model(example_shape) + def _build_toolbox(self, num_epochs: int, example_shape: tuple[int, ...], *, + model: nn.Module | None = None) -> TrainerToolbox: + if model is None: + model = self.load_model(example_shape) optimizer = self.build_optimizer(model.parameters()) scheduler = self.build_scheduler(optimizer, num_epochs) criterion = self.build_criterion().to(self._device) return TrainerToolbox(model, optimizer, scheduler, criterion) + def build_toolbox(self, num_epochs: int, example_shape: tuple[int, ...]) -> TrainerToolbox: + return self._build_toolbox(num_epochs, example_shape) + # Training methods @abstractmethod @@ -365,7 +393,8 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, em example_input = padding_module(example_input) example_shape = tuple(example_input.shape[1:]) self.log(f"Example input shape: {example_shape}") - toolbox = self.build_toolbox(num_epochs, 
example_shape) + toolbox = self.load_toolbox(num_epochs, example_shape) if self.recovery() else self.build_toolbox( + num_epochs, example_shape) model_name = toolbox.model.__class__.__name__ sanity_check_result = sanity_check(toolbox.model, example_shape, device=self._device) self.log(f"Model: {model_name}")