from collections import defaultdict
import time

- import configargparse
import numpy as np
import torch
from torch import optim
from torchvision.transforms import Normalize
from tqdm import tqdm

+ from arguments import train_parser
from model import GraphSuperResolutionNet
from data import MiddleburyDataset, NYUv2Dataset, DIMLDataset
from utils import new_log, to_cuda, seed_all

- parser = configargparse.ArgumentParser()
- parser.add_argument('-c', '--config', is_config_file=True, help='Path to the config file', type=str)
-
- # general
- parser.add_argument('--save-dir', required=True, help='Path to directory where models and logs should be saved')
- parser.add_argument('--logstep-train', default=10, type=int, help='Training log interval in steps')
- parser.add_argument('--save-model', default='both', choices=['last', 'best', 'no', 'both'])
- parser.add_argument('--val-every-n-epochs', type=int, default=1, help='Validation interval in epochs')
- parser.add_argument('--resume', type=str, default=None, help='Checkpoint path to resume')
- parser.add_argument('--seed', type=int, default=12345, help='Random seed')
- parser.add_argument('--wandb', action='store_true', default=False, help='Use Weights & Biases instead of TensorBoard')
- parser.add_argument('--wandb-project', type=str, default='graph-sr', help='Wandb project name')
-
- # data
- parser.add_argument('--dataset', type=str, required=True, help='Name of the dataset')
- parser.add_argument('--data-dir', type=str, required=True, help='Root directory of the dataset')
- parser.add_argument('--num-workers', type=int, default=8, metavar='N', help='Number of dataloader worker processes')
- parser.add_argument('--batch-size', type=int, default=8)
- parser.add_argument('--crop-size', type=int, default=256, help='Size of the input (squared) patches')
- parser.add_argument('--scaling', type=int, default=8, help='Scaling factor')
- parser.add_argument('--max-rotation', type=float, default=15., help='Maximum rotation angle (degrees)')
- parser.add_argument('--no-flip', action='store_true', default=False, help='Switch off random flipping')
- parser.add_argument('--in-memory', action='store_true', default=False, help='Hold data in memory during training')
-
- # training
- parser.add_argument('--loss', default='l1', type=str, choices=['l1', 'mse'])
- parser.add_argument('--num-epochs', type=int, default=250)
- parser.add_argument('--optimizer', default='adam', choices=['sgd', 'adam'])
- parser.add_argument('--lr', type=float, default=0.0001)
- parser.add_argument('--momentum', type=float, default=0.9)
- parser.add_argument('--w-decay', type=float, default=1e-5)
- parser.add_argument('--lr-scheduler', type=str, default='step', choices=['no', 'step', 'plateau'])
- parser.add_argument('--lr-step', type=int, default=10, help='LR scheduler step size (epochs)')
- parser.add_argument('--lr-gamma', type=float, default=0.9, help='LR decay rate')
- parser.add_argument('--skip-first', action='store_true', help='Don\'t optimize during first epoch')
- parser.add_argument('--gradient-clip', type=float, default=0., help='If > 0, clips gradient norm to that value')
-
- # model
- parser.add_argument('--feature-extractor', type=str, default='UResNet', help='Feature extractor for edge potentials')
- parser.add_argument('--pretrained', action='store_true', help='Initialize feature extractor with weights '
-                                                               'pretrained on ImageNet')
- parser.add_argument('--lambda-init', type=float, default=1., help='Graph lambda parameter initialization')
- parser.add_argument('--mu-init', type=float, default=.1, help='Graph mu parameter initialization')
-

class Trainer:

@@ -76,9 +32,8 @@ def __init__(self, args: argparse.Namespace):
            args.scaling,
            args.crop_size,
            args.feature_extractor,
-           args.pretrained,
-           args.lambda_init,
-           args.mu_init
+           lambda_init=args.lambda_init,
+           mu_init=args.mu_init
        )
        self.model.cuda()

@@ -272,8 +227,8 @@ def resume(self, path):


if __name__ == '__main__':
-   args = parser.parse_args()
-   print(parser.format_values())
+   args = train_parser.parse_args()
+   print(train_parser.format_values())

    if args.wandb:
        import wandb
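
Note: the configargparse setup removed above is presumably what the new arguments module (imported at the top as train_parser) now provides; that file is not shown in this diff. A minimal sketch of what it might look like, assuming the parser definition simply moves there unchanged and is exported as train_parser:

# arguments.py -- hypothetical sketch, not part of this commit's visible diff.
# Assumes the parser removed from train.py moves here verbatim and is exported
# under the name `train_parser`; only a few representative options are repeated.
import configargparse

train_parser = configargparse.ArgumentParser()
train_parser.add_argument('-c', '--config', is_config_file=True, help='Path to the config file', type=str)

# general
train_parser.add_argument('--save-dir', required=True, help='Path to directory where models and logs should be saved')
train_parser.add_argument('--seed', type=int, default=12345, help='Random seed')

# model
train_parser.add_argument('--feature-extractor', type=str, default='UResNet', help='Feature extractor for edge potentials')
train_parser.add_argument('--lambda-init', type=float, default=1., help='Graph lambda parameter initialization')
train_parser.add_argument('--mu-init', type=float, default=.1, help='Graph mu parameter initialization')
# ...the remaining data/training options would follow the same pattern as the block removed above.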