Skip to content

Commit

Permalink
add clean checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
innnky committed Jan 17, 2023
1 parent 339aa74 commit 0dfc12a
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import glob
import re
import sys
import argparse
import logging
Expand Down Expand Up @@ -117,12 +118,6 @@ def load_checkpoint(checkpoint_path, model, optimizer=None):


def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
# ckptname = checkpoint_path.split(os.sep)[-1]
# newest_step = int(ckptname.split(".")[0].split("_")[1])
# val_steps = 2000
# last_ckptname = checkpoint_path.replace(str(newest_step), str(newest_step - val_steps*3))
# if newest_step >= val_steps*3:
# os.system(f"rm {last_ckptname}")
logger.info("Saving model and optimizer state at iteration {} to {}".format(
iteration, checkpoint_path))
if hasattr(model, 'module'):
Expand All @@ -133,7 +128,29 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
'iteration': iteration,
'optimizer': optimizer.state_dict(),
'learning_rate': learning_rate}, checkpoint_path)

clean_ckpt = False
if clean_ckpt:
clean_checkpoints(path_to_models='logs/32k/', n_ckpts_to_keep=3, sort_by_time=True)

def clean_checkpoints(path_to_models='logs/48k/', n_ckpts_to_keep=2, sort_by_time=True):
"""Freeing up space by deleting saved ckpts
Arguments:
path_to_models -- Path to the model directory
n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
sort_by_time -- True -> chronologically delete ckpts
False -> lexicographically delete ckpts
"""
ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1)))
time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
sort_key = time_key if sort_by_time else name_key
x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')], key=sort_key)
to_del = [os.path.join(path_to_models, fn) for fn in
(x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
del_routine = lambda x: [os.remove(x), del_info(x)]
rs = [del_routine(fn) for fn in to_del]

def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
for k, v in scalars.items():
Expand Down

0 comments on commit 0dfc12a

Please sign in to comment.