mipcandy/layer.py: 3 additions & 0 deletions
@@ -105,6 +105,9 @@ def build_network(self, example_shape: tuple[int, ...]) -> nn.Module:
        raise NotImplementedError

    def build_network_from_checkpoint(self, example_shape: tuple[int, ...], checkpoint: Mapping[str, Any]) -> nn.Module:
+        """
+        Internal interface exposed for overriding; callers should use `load_model()` instead.
+        """
        network = self.build_network(example_shape)
        network.load_state_dict(checkpoint)
        return network
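Since `build_network_from_checkpoint()` only layers `load_state_dict()` on top of `build_network()`, a subclass gets checkpoint loading for free by defining the architecture alone. A minimal sketch of such an override, assuming a concrete `Trainer` subclass (the class name and architecture are hypothetical, and the other abstract methods are omitted):

```python
import math

from torch import nn

from mipcandy.training import Trainer  # abstract base from this repo


class FlatTrainer(Trainer):  # hypothetical subclass for illustration only
    def build_network(self, example_shape: tuple[int, ...]) -> nn.Module:
        # Architecture only; build_network_from_checkpoint() restores weights on top.
        return nn.Sequential(nn.Flatten(), nn.Linear(math.prod(example_shape), 2))
```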
mipcandy/training.py: 30 additions & 5 deletions
@@ -60,7 +60,7 @@ class TrainerTracker(object):

class Trainer(WithPaddingModule, WithNetwork, metaclass=ABCMeta):
    def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]],
-                 validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], *,
+                 validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], *, recoverable: bool = True,
                 device: torch.device | str = "cpu", console: Console = Console()) -> None:
        WithPaddingModule.__init__(self, device)
        WithNetwork.__init__(self, device)
@@ -69,6 +69,7 @@ def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[t
        self._experiment_id: str = "tbd"
        self._dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = dataloader
        self._validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = validation_dataloader
+        self._unrecoverable: bool | None = not recoverable  # None if the trainer is recovered
        self._console: Console = console
        self._metrics: dict[str, list[float]] = {}
        self._epoch_metrics: dict[str, list[float]] = {}
@@ -80,9 +81,11 @@ def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[t

    def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: TrainerTracker,
                                     **training_arguments) -> None:
-        torch.save(toolbox.optimizer, f"{self.experiment_folder()}/optimizer.pt")
-        torch.save(toolbox.scheduler, f"{self.experiment_folder()}/scheduler.pt")
-        torch.save(toolbox.criterion, f"{self.experiment_folder()}/criterion.pt")
+        if self._unrecoverable:
+            return
+        torch.save(toolbox.optimizer.state_dict(), f"{self.experiment_folder()}/optimizer.pth")
+        torch.save(toolbox.scheduler.state_dict(), f"{self.experiment_folder()}/scheduler.pth")
+        torch.save(toolbox.criterion.state_dict(), f"{self.experiment_folder()}/criterion.pth")
        torch.save(tracker, f"{self.experiment_folder()}/tracker.pt")
        with open(f"{self.experiment_folder()}/training_arguments.json", "w") as f:
            dump(training_arguments, f)
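Switching from pickling the whole optimizer/scheduler/criterion objects to saving their `state_dict()`s follows the standard PyTorch serialization pattern: a pickled object is coupled to the exact class layout at save time, while a state dict is plain tensors and hyperparameters that reload into a freshly built object. A standalone sketch of the round trip (names here are illustrative, not mipcandy code):

```python
import torch
from torch import nn, optim

model = nn.Linear(4, 2)
opt = optim.AdamW(model.parameters(), lr=1e-3)

# Save only the state: tensors plus hyperparameters, robust to code refactors.
torch.save(opt.state_dict(), "optimizer.pth")

# Later: rebuild the optimizer from code, then restore its state.
opt2 = optim.AdamW(model.parameters(), lr=1e-3)
opt2.load_state_dict(torch.load("optimizer.pth"))
```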
@@ -98,13 +101,28 @@ def load_metrics(self) -> dict[str, list[float]]:
        df = read_csv(f"{self.experiment_folder()}/metrics.csv", index_col="epoch")
        return {column: df[column].astype(float).tolist() for column in df.columns}

+    def load_toolbox(self, num_epochs: int, example_shape: tuple[int, ...]) -> TrainerToolbox:
+        toolbox = self.build_toolbox(num_epochs, example_shape)
+        toolbox.model = self.load_model(
+            example_shape, checkpoint=torch.load(f"{self.experiment_folder()}/checkpoint_latest.pth")
+        )
+        toolbox.optimizer.load_state_dict(torch.load(f"{self.experiment_folder()}/optimizer.pth"))
+        toolbox.scheduler.load_state_dict(torch.load(f"{self.experiment_folder()}/scheduler.pth"))
+        toolbox.criterion.load_state_dict(torch.load(f"{self.experiment_folder()}/criterion.pth"))
+        return toolbox
+
    def recover_from(self, experiment_id: str) -> Self:
        self._experiment_id = experiment_id
        if not exists(self.experiment_folder()):
            raise FileNotFoundError(f"Experiment folder {self.experiment_folder()} not found")
        self._metrics = self.load_metrics()
        self._tracker = self.load_tracker()
+        self._unrecoverable = None
        return self

+    def continue_training(self, num_epochs: int) -> None:
+        if not self.recovery():
+            raise RuntimeError("Must call `recover_from()` before continuing training")
+        self.train(num_epochs, **self.load_training_arguments())

    # Getters
@@ -141,6 +159,9 @@ def tracker(self) -> TrainerTracker:
    def initialized(self) -> bool:
        return self._experiment_id != "tbd"

+    def recovery(self) -> bool:
+        return self._unrecoverable is None
+
    def experiment_folder(self) -> str:
        return f"{self._trainer_folder}/{self._trainer_variant}/{self._experiment_id}"

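`_unrecoverable` is a tri-state flag and easy to misread: `True` means recovery artifacts are never written, `False` means a fresh but recoverable run, and `None` means the run was restored via `recover_from()`. A standalone sketch of just that flag logic (extracted for illustration, not mipcandy code):

```python
class RecoveryFlag:
    """Mirrors the tri-state `_unrecoverable` logic from the diff above."""

    def __init__(self, recoverable: bool = True) -> None:
        self._unrecoverable: bool | None = not recoverable

    def mark_recovered(self) -> None:
        # What recover_from() does after loading metrics and tracker.
        self._unrecoverable = None

    def recovery(self) -> bool:
        # True only for a run restored from disk.
        return self._unrecoverable is None

    def should_save_artifacts(self) -> bool:
        # The guard in save_everything_for_recovery(): both False and None save.
        return not self._unrecoverable
```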
@@ -189,6 +210,9 @@ def allocate_experiment_folder(self) -> str:
        return self.experiment_folder() if self.initialized() else self._allocate_experiment_folder()

    def init_experiment(self) -> None:
+        if self.recovery():
+            self.log(f"Training progress recovered from {self._experiment_id} at epoch {self._tracker.epoch}")
+            return
        if self.initialized():
            raise RuntimeError("Experiment already initialized")
        makedirs(self._trainer_folder, exist_ok=True)
@@ -365,7 +389,8 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, em
            example_input = padding_module(example_input)
        example_shape = tuple(example_input.shape[1:])
        self.log(f"Example input shape: {example_shape}")
-        toolbox = self.build_toolbox(num_epochs, example_shape)
+        toolbox = self.load_toolbox(num_epochs, example_shape) if self.recovery() else self.build_toolbox(
+            num_epochs, example_shape)
        model_name = toolbox.model.__class__.__name__
        sanity_check_result = sanity_check(toolbox.model, example_shape, device=self._device)
        self.log(f"Model: {model_name}")
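Taken together, the pieces form a resume workflow: construct the trainer as usual, point it at a previous experiment, and continue with the saved training arguments. A hedged usage sketch; `SegTrainer`, the toy loaders, and the experiment id are assumptions, while `recover_from()` and `continue_training()` come from this PR:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data so the sketch is self-contained; real datasets replace this.
xs, ys = torch.randn(8, 1, 32, 32), torch.randn(8, 1, 32, 32)
train_loader = DataLoader(TensorDataset(xs, ys), batch_size=2)
val_loader = DataLoader(TensorDataset(xs, ys), batch_size=2)

trainer = SegTrainer("runs", train_loader, val_loader)  # hypothetical concrete subclass
trainer.recover_from("exp-001")            # restores metrics and tracker, flips recovery() on
trainer.continue_training(num_epochs=50)   # calls train() with the saved training arguments
```

Passing `recoverable=False` at construction makes `save_everything_for_recovery()` a no-op, trading resumability for less disk I/O per epoch.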