diff --git a/SwinUNETR/Pretrain/main.py b/SwinUNETR/Pretrain/main.py
index 93d3aca5..e7ff588a 100644
--- a/SwinUNETR/Pretrain/main.py
+++ b/SwinUNETR/Pretrain/main.py
@@ -91,13 +91,18 @@ def train(args, global_step, train_loader, val_best, scaler):
                 writer.add_image("Validation/x1_aug", img_list[1], global_step, dataformats="HW")
                 writer.add_image("Validation/x1_recon", img_list[2], global_step, dataformats="HW")

+                checkpoint = {
+                    "global_step": global_step,
+                    "state_dict": model.state_dict(),
+                    "optimizer": optimizer.state_dict(),
+                    "scheduler": scheduler.state_dict(),
+                    "val_best": val_best,
+                }
+                if args.amp:
+                    checkpoint["scaler"] = scaler.state_dict()
+
                 if val_loss_recon < val_best:
                     val_best = val_loss_recon
-                    checkpoint = {
-                        "global_step": global_step,
-                        "state_dict": model.state_dict(),
-                        "optimizer": optimizer.state_dict(),
-                    }
                     save_ckp(checkpoint, logdir + "/model_bestValRMSE.pt")
                     print(
                         "Model was saved ! Best Recon. Val Loss: {:.4f}, Recon. Val Loss: {:.4f}".format(
@@ -110,6 +115,8 @@ def train(args, global_step, train_loader, val_best, scaler):
                             val_best, val_loss_recon
                         )
                     )
+
+                save_ckp(checkpoint, logdir + "/last.pt")
         return global_step, loss, val_best

     def validation(args, test_loader):
@@ -234,13 +241,6 @@ def validation(args, test_loader):
     elif args.opt == "sgd":
         optimizer = optim.SGD(params=model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.decay)

-    if args.resume:
-        model_pth = args.resume
-        model_dict = torch.load(model_pth)
-        model.load_state_dict(model_dict["state_dict"])
-        model.epoch = model_dict["epoch"]
-        model.optimizer = model_dict["optimizer"]
-
     if args.lrdecay:
         if args.lr_schedule == "warmup_cosine":
             scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=args.num_steps)
@@ -252,21 +252,45 @@ def lambdas(epoch):

             scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambdas)

+    if args.amp:
+        scaler = GradScaler(1024)
+    else:
+        scaler = None
+
+    global_step = 0
+    best_val = 1e8
+    if args.resume:
+        model_pth = args.resume
+        model_dict = torch.load(model_pth)
+        model.load_state_dict({k.replace("module.", "", 1): v for k, v in model_dict["state_dict"].items()})
+        optimizer.load_state_dict(model_dict["optimizer"])
+        global_step = model_dict["global_step"]
+        if args.amp and "scaler" in model_dict:
+            scaler.load_state_dict(model_dict["scaler"])
+        if "val_best" in model_dict:
+            best_val = model_dict["val_best"]
+        if "scheduler" in model_dict:
+            scheduler.load_state_dict(model_dict["scheduler"])
+        else:
+            scheduler.last_epoch = global_step - 1
+
     loss_function = Loss(args.batch_size * args.sw_batch_size, args)
     if args.distributed:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
         model = DistributedDataParallel(model, device_ids=[args.local_rank])
     train_loader, test_loader = get_loader(args)

-    global_step = 0
-    best_val = 1e8
-    if args.amp:
-        scaler = GradScaler()
-    else:
-        scaler = None
     while global_step < args.num_steps:
         global_step, loss, best_val = train(args, global_step, train_loader, best_val, scaler)
-    checkpoint = {"epoch": args.epochs, "state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
+    checkpoint = {
+        "global_step": global_step,
+        "state_dict": model.state_dict(),
+        "optimizer": optimizer.state_dict(),
+        "scheduler": scheduler.state_dict(),
+        "val_best": best_val,
+    }
+    if args.amp:
+        checkpoint["scaler"] = scaler.state_dict()

     if args.distributed:
         if dist.get_rank() == 0:
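For reference, a minimal sketch of how the rolling checkpoint written by this patch can be inspected and then used to resume pretraining; the "./runs/test/last.pt" path is only an example and depends on the logdir the original run was launched with:

    # Sketch only: open the checkpoint written by save_ckp(checkpoint, logdir + "/last.pt")
    # and check the fields the resume path above expects to find in it.
    import torch

    ckpt = torch.load("./runs/test/last.pt", map_location="cpu")  # example path
    print(ckpt["global_step"], ckpt["val_best"])   # step counter and best recon. val loss so far
    print(sorted(ckpt.keys()))                     # state_dict, optimizer, scheduler, (scaler when AMP was on)

    # Resuming then amounts to re-running pretraining with the same flags plus:
    #   python main.py --resume ./runs/test/last.pt <other flags as in the original run>

Note that the resume block strips an optional "module." prefix from the saved state_dict keys, so checkpoints written from either a plain or a DistributedDataParallel-wrapped model load cleanly, and the scaler state is restored only when AMP is enabled.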