add a simple checkpointing system for single node use
lucidrains committed Jul 21, 2022
1 parent dc3ba2a commit 2535012
Showing 3 changed files with 78 additions and 3 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -493,13 +493,14 @@ More the reason why you should start training your own model, starting today! Th
 - [x] build out CLI tool and one-line generation of image
 - [x] knock out any issues that arose from accelerate
 - [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
+- [x] build a simple checkpointing system, backed by a folder
+- [ ] add advanced checkpointing system that can be backed to a google bucket or s3
 - [ ] investigate https://arxiv.org/abs/2005.09007 architecture in context of ddpm
 - [ ] build out CLI tool for training, resuming training off config file
 - [ ] preencoding of text to memmapped embeddings
 - [ ] extend to video generation, using axial time attention as in Ho's video ddpm paper + https://github.com/lucidrains/flexible-diffusion-modeling-videos-pytorch for up to 25 minute video
 - [ ] be able to create dataloader iterators based on the old epoch style, also configure shuffling etc
 - [ ] be able to also pass in arguments (instead of requiring forward to be all keyword args on model)
-- [ ] build a simple checkpointing system, backed by a folder
 
 ## Citations
 
76 changes: 75 additions & 1 deletion imagen_pytorch/trainer.py
@@ -1,3 +1,4 @@
+import os
 import time
 import copy
 from pathlib import Path
@@ -208,6 +209,9 @@ def __init__(
         split_valid_fraction = 0.025,
         split_valid_from_train = False,
         split_random_seed = 42,
+        checkpoint_path = None,
+        checkpoint_every = None,
+        max_checkpoints_keep = 20,
         **kwargs
     ):
         super().__init__()
@@ -325,6 +329,19 @@ def __init__(
         self.imagen.to(self.device)
         self.to(self.device)
 
+        # checkpointing
+
+        assert not (exists(checkpoint_path) ^ exists(checkpoint_every))
+        self.checkpoint_path = checkpoint_path
+        self.checkpoint_every = checkpoint_every
+        self.max_checkpoints_keep = max_checkpoints_keep
+
+        if exists(checkpoint_path) and self.is_local_main:
+            self.checkpoint_path = Path(checkpoint_path)
+            self.checkpoint_path.mkdir(exist_ok = True, parents = True)
+
+            self.load_from_checkpoint_folder()
+
         # only allowing training for unet
 
         self.only_train_unet_number = only_train_unet_number
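A minimal usage sketch of the new constructor arguments (editorial, not part of this commit; the Imagen configuration is elided and the folder name is hypothetical). Note the assert above: checkpoint_path and checkpoint_every must be supplied together or not at all:

from imagen_pytorch import Imagen, ImagenTrainer

# unet and Imagen configuration omitted for brevity
imagen = Imagen(...)

trainer = ImagenTrainer(
    imagen,
    checkpoint_path = './checkpoints',  # folder is created on the local main process if missing
    checkpoint_every = 1000,            # write a checkpoint every 1000 total steps
    max_checkpoints_keep = 5            # prune all but the 5 newest checkpoints
)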
@@ -536,6 +553,53 @@ def step_with_dl_iter(self, dl_iter, **kwargs):
         loss = self.forward(**{**kwargs, **model_input})
         return loss
 
+    # checkpointing functions
+
+    @property
+    def all_checkpoints_sorted(self):
+        checkpoints = [*self.checkpoint_path.glob('*.pt')]
+        sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True)
+        return sorted_checkpoints
+
+    def load_from_checkpoint_folder(self, last_total_steps = -1):
+        if last_total_steps != -1:
+            filepath = str(self.checkpoint_path / f'checkpoint.{last_total_steps}.pt')
+            self.load(str(filepath))
+            return
+
+        sorted_checkpoints = self.all_checkpoints_sorted
+
+        if len(sorted_checkpoints) == 0:
+            self.print(f'no checkpoints found to load from at {str(self.checkpoint_path)}')
+            return
+
+        last_checkpoint = sorted_checkpoints[0]
+        self.load(str(last_checkpoint))
+
+        self.print(f'loading checkpoint from {str(last_checkpoint)}')
+
+    def save_to_checkpoint_folder(self):
+        self.accelerator.wait_for_everyone()
+
+        if not self.is_local_main:
+            return
+
+        total_steps = int(self.steps.sum().item())
+        filepath = self.checkpoint_path / f'checkpoint.{total_steps}.pt'
+
+        self.save(filepath)
+
+        self.print(f'saved checkpoint to {str(filepath)}')
+
+        if self.max_checkpoints_keep <= 0:
+            return
+
+        sorted_checkpoints = self.all_checkpoints_sorted
+        checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:]
+
+        for checkpoint in checkpoints_to_discard:
+            os.remove(str(checkpoint))
+
     # saving and loading functions
 
     def save(self, path, overwrite = True, **kwargs):
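The newest-first ordering above relies on the step count embedded in the checkpoint.<total_steps>.pt filename: splitting on '.' leaves the step count second from last. A standalone illustration with made-up paths (not part of the commit):

from pathlib import Path

paths = [Path('ckpts/checkpoint.500.pt'), Path('ckpts/checkpoint.12000.pt'), Path('ckpts/checkpoint.2000.pt')]

# same key as all_checkpoints_sorted above
newest_first = sorted(paths, key = lambda x: int(str(x).split('.')[-2]), reverse = True)

print(newest_first[0])  # ckpts/checkpoint.12000.pt, which load_from_checkpoint_folder would resume from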
@@ -602,7 +666,7 @@ def load(self, path, only_model = False, strict = True, noop_if_not_exist = Fals
             self.print(f'trainer checkpoint not found at {str(path)}')
             return
 
-        assert path.exists()
+        assert path.exists(), f'{str(path)} does not exist'
 
         self.reset_ema_unets_all_one_device()
@@ -774,6 +838,16 @@ def update(self, unet_number = None):
 
         self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps))
 
+        if not exists(self.checkpoint_path):
+            return
+
+        total_steps = int(self.steps.sum().item())
+
+        if total_steps % self.checkpoint_every:
+            return
+
+        self.save_to_checkpoint_folder()
+
     @torch.no_grad()
     @cast_torch_tensor
     @imagen_sample_in_chunks
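Since self.steps keeps one counter per unet, the checkpoint cadence in update() is driven by the sum over all unets rather than any single unet's count. A small illustration of the modulo gate with made-up numbers (not part of the commit):

import torch

steps = torch.tensor([600, 400])        # say unet 1 has trained 600 steps, unet 2 has trained 400
total_steps = int(steps.sum().item())   # 1000

checkpoint_every = 1000

# in update(), a nonzero remainder returns early; a zero remainder saves
if not (total_steps % checkpoint_every):
    print('save_to_checkpoint_folder() would run here')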
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.2.3'
+__version__ = '1.2.4'
