mipcandy/training.py (20 changes: 18 additions & 2 deletions)
@@ -60,7 +60,7 @@ class TrainerTracker(object):

class Trainer(WithPaddingModule, WithNetwork, metaclass=ABCMeta):
def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]],
validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], *,
validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], *, recoverable: bool = True,
device: torch.device | str = "cpu", console: Console = Console()) -> None:
WithPaddingModule.__init__(self, device)
WithNetwork.__init__(self, device)
@@ -69,6 +69,7 @@ def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[t
self._experiment_id: str = "tbd"
self._dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = dataloader
self._validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = validation_dataloader
self._unrecoverable: bool | None = not recoverable  # becomes None once the trainer has been recovered
self._console: Console = console
self._metrics: dict[str, list[float]] = {}
self._epoch_metrics: dict[str, list[float]] = {}
@@ -80,6 +81,8 @@ def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[t

def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: TrainerTracker,
**training_arguments) -> None:
if self._unrecoverable:
return  # recovery is disabled for this trainer, so skip persisting its state
torch.save(toolbox.optimizer, f"{self.experiment_folder()}/optimizer.pt")
torch.save(toolbox.scheduler, f"{self.experiment_folder()}/scheduler.pt")
torch.save(toolbox.criterion, f"{self.experiment_folder()}/criterion.pt")
@@ -98,13 +101,25 @@ def load_metrics(self) -> dict[str, list[float]]:
df = read_csv(f"{self.experiment_folder()}/metrics.csv", index_col="epoch")
return {column: df[column].astype(float).tolist() for column in df.columns}

def load_toolbox(self, example_shape: tuple[int, ...]) -> TrainerToolbox:
return TrainerToolbox(
self.build_network_from_checkpoint(
example_shape, torch.load(f"{self.experiment_folder()}/checkpoint_latest.pth")
),
torch.load(f"{self.experiment_folder()}/optimizer.pt", weights_only=False),
torch.load(f"{self.experiment_folder()}/scheduler.pt", weights_only=False),
torch.load(f"{self.experiment_folder()}/criterion.pt", weights_only=False)
)

def recover_from(self, experiment_id: str) -> Self:
self._experiment_id = experiment_id
self._metrics = self.load_metrics()
self._tracker = self.load_tracker()
self._unrecoverable = None
return self

def continue_training(self, num_epochs: int) -> None:
self._unrecoverable = None
self.train(num_epochs, **self.load_training_arguments())

# Getters
@@ -365,7 +380,8 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, em
example_input = padding_module(example_input)
example_shape = tuple(example_input.shape[1:])
self.log(f"Example input shape: {example_shape}")
toolbox = self.build_toolbox(num_epochs, example_shape)
toolbox = (self.load_toolbox(example_shape) if self._unrecoverable is None
           else self.build_toolbox(num_epochs, example_shape))
model_name = toolbox.model.__class__.__name__
sanity_check_result = sanity_check(toolbox.model, example_shape, device=self._device)
self.log(f"Model: {model_name}")
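A minimal resume sketch for the recovery flow added above. `MyTrainer`, `train_loader`, `val_loader`, and the experiment id are hypothetical placeholders; only recover_from() and continue_training() come from this change:

# Resume an earlier run: recover_from() reloads the saved metrics and tracker for the
# given experiment, and continue_training() replays the saved training arguments.
# MyTrainer, train_loader, val_loader, and "exp-0042" are placeholders.
trainer = MyTrainer("experiments", train_loader, val_loader, recoverable=True, device="cuda")
trainer.recover_from("exp-0042").continue_training(10)  # train for 10 more epochs

With recoverable=False, save_everything_for_recovery() becomes a no-op, so such a run cannot be resumed later.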