Skip to content

Commit 1168a43

Browse files
Committed: Fix default parameters
1 parent 3309211 commit 1168a43

File tree

7 files changed

+69
-67
lines changed

7 files changed

+69
-67
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Run the training script via
4242
```bash
4343
python run_train.py --dataset <...> --data-dir <...> --save-dir <...>
4444
```
45-
You can see all available training options by running
45+
Hyperparameter defaults are set to the values from the paper. Depending on the dataset, you have to adjust the number of epochs (`--num-epochs`) and the scheduler step size (`--lr-step`), see appendix A of the paper. You can see all available training options by running
4646
```bash
4747
python run_train.py -h
4848
```

arguments/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .train import parser as train_parser
2+
from .eval import parser as eval_parser

arguments/eval.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import configargparse
2+
3+
parser = configargparse.ArgumentParser()
4+
parser.add_argument('-c', '--config', is_config_file=True, help='Path to the config file', type=str)
5+
6+
parser.add_argument('--checkpoint', type=str, required=True, help='Checkpoint path to evaluate')
7+
parser.add_argument('--dataset', type=str, required=True, help='Name of the dataset')
8+
parser.add_argument('--data-dir', type=str, required=True, help='Root directory of the dataset')
9+
parser.add_argument('--num-workers', type=int, default=8, metavar='N', help='Number of dataloader worker processes')
10+
parser.add_argument('--batch-size', type=int, default=8)
11+
parser.add_argument('--crop-size', type=int, default=256, help='Size of the input (squared) patches')
12+
parser.add_argument('--scaling', type=int, default=8, help='Scaling factor')
13+
parser.add_argument('--in-memory', default=False, action='store_true', help='Hold data in memory during evaluation')
14+
parser.add_argument('--feature-extractor', type=str, default='UResNet', help='Feature extractor for edge potentials')

arguments/train.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import configargparse
2+
3+
parser = configargparse.ArgumentParser()
4+
parser.add_argument('-c', '--config', is_config_file=True, help='Path to the config file', type=str)
5+
6+
# general
7+
parser.add_argument('--save-dir', required=True, help='Path to directory where models and logs should be saved')
8+
parser.add_argument('--logstep-train', default=10, type=int, help='Training log interval in steps')
9+
parser.add_argument('--save-model', default='both', choices=['last', 'best', 'no', 'both'])
10+
parser.add_argument('--val-every-n-epochs', type=int, default=1, help='Validation interval in epochs')
11+
parser.add_argument('--resume', type=str, default=None, help='Checkpoint path to resume')
12+
parser.add_argument('--seed', type=int, default=12345, help='Random seed')
13+
parser.add_argument('--wandb', action='store_true', default=False, help='Use Weights & Biases instead of TensorBoard')
14+
parser.add_argument('--wandb-project', type=str, default='graph-sr', help='Wandb project name')
15+
16+
# data
17+
parser.add_argument('--dataset', type=str, required=True, help='Name of the dataset')
18+
parser.add_argument('--data-dir', type=str, required=True, help='Root directory of the dataset')
19+
parser.add_argument('--num-workers', type=int, default=8, metavar='N', help='Number of dataloader worker processes')
20+
parser.add_argument('--batch-size', type=int, default=8)
21+
parser.add_argument('--crop-size', type=int, default=256, help='Size of the input (squared) patches')
22+
parser.add_argument('--scaling', type=int, default=8, help='Scaling factor')
23+
parser.add_argument('--max-rotation', type=float, default=15., help='Maximum rotation angle (degrees)')
24+
parser.add_argument('--no-flip', action='store_true', default=False, help='Switch off random flipping')
25+
parser.add_argument('--in-memory', action='store_true', default=False, help='Hold data in memory during training')
26+
27+
# training
28+
parser.add_argument('--loss', default='l1', type=str, choices=['l1', 'mse'])
29+
parser.add_argument('--num-epochs', type=int, default=250)
30+
parser.add_argument('--optimizer', default='adam', choices=['sgd', 'adam'])
31+
parser.add_argument('--lr', type=float, default=0.0001)
32+
parser.add_argument('--momentum', type=float, default=0.9)
33+
parser.add_argument('--w-decay', type=float, default=1e-5)
34+
parser.add_argument('--lr-scheduler', type=str, default='step', choices=['no', 'step', 'plateau'])
35+
parser.add_argument('--lr-step', type=int, default=10, help='LR scheduler step size (epochs)')
36+
parser.add_argument('--lr-gamma', type=float, default=0.9, help='LR decay rate')
37+
parser.add_argument('--skip-first', action='store_true', help='Don\'t optimize during first epoch')
38+
parser.add_argument('--gradient-clip', type=float, default=0.01, help='If > 0, clips gradient norm to that value')
39+
40+
# model
41+
parser.add_argument('--feature-extractor', type=str, default='UResNet', help='Feature extractor for edge potentials')
42+
parser.add_argument('--lambda-init', type=float, default=1., help='Graph lambda parameter initialization')
43+
parser.add_argument('--mu-init', type=float, default=.1, help='Graph mu parameter initialization')

model/graph_sr_net.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(
4646
scaling: int,
4747
crop_size=256,
4848
feature_extractor='UResNet',
49-
pretrained=False,
49+
pretrained=True,
5050
lambda_init=1.0,
5151
mu_init=0.1
5252
):

run_eval.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,16 @@
33
from collections import defaultdict
44
import time
55

6-
import configargparse
76
import torch
87
from torchvision.transforms import Normalize
98
from torch.utils.data import DataLoader
109
from tqdm import tqdm
1110

11+
from arguments import eval_parser
1212
from model import GraphSuperResolutionNet
1313
from data import MiddleburyDataset, NYUv2Dataset, DIMLDataset
1414
from utils import to_cuda
1515

16-
parser = configargparse.ArgumentParser()
17-
parser.add_argument('-c', '--config', is_config_file=True, help='Path to the config file', type=str)
18-
parser.add_argument('--checkpoint', type=str, required=True, help='Checkpoint path to evaluate')
19-
parser.add_argument('--dataset', type=str, required=True, help='Name of the dataset')
20-
parser.add_argument('--data-dir', type=str, required=True, help='Root directory of the dataset')
21-
parser.add_argument('--num-workers', type=int, default=8, metavar='N', help='Number of dataloader worker processes')
22-
parser.add_argument('--batch-size', type=int, default=8)
23-
parser.add_argument('--crop-size', type=int, default=256, help='Size of the input (squared) patches')
24-
parser.add_argument('--scaling', type=int, default=8, help='Scaling factor')
25-
parser.add_argument('--in-memory', default=False, action='store_true', help='Hold data in memory during evaluation')
26-
parser.add_argument('--feature-extractor', type=str, default='UResNet', help='Feature extractor for edge potentials')
27-
2816

2917
class Evaluator:
3018

@@ -96,8 +84,8 @@ def resume(self, path):
9684

9785

9886
if __name__ == '__main__':
99-
args = parser.parse_args()
100-
print(parser.format_values())
87+
args = eval_parser.parse_args()
88+
print(eval_parser.format_values())
10189

10290
evaluator = Evaluator(args)
10391

run_train.py

Lines changed: 5 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from collections import defaultdict
44
import time
55

6-
import configargparse
76
import numpy as np
87
import torch
98
from torch import optim
@@ -13,54 +12,11 @@
1312
from torchvision.transforms import Normalize
1413
from tqdm import tqdm
1514

15+
from arguments import train_parser
1616
from model import GraphSuperResolutionNet
1717
from data import MiddleburyDataset, NYUv2Dataset, DIMLDataset
1818
from utils import new_log, to_cuda, seed_all
1919

20-
parser = configargparse.ArgumentParser()
21-
parser.add_argument('-c', '--config', is_config_file=True, help='Path to the config file', type=str)
22-
23-
# general
24-
parser.add_argument('--save-dir', required=True, help='Path to directory where models and logs should be saved')
25-
parser.add_argument('--logstep-train', default=10, type=int, help='Training log interval in steps')
26-
parser.add_argument('--save-model', default='both', choices=['last', 'best', 'no', 'both'])
27-
parser.add_argument('--val-every-n-epochs', type=int, default=1, help='Validation interval in epochs')
28-
parser.add_argument('--resume', type=str, default=None, help='Checkpoint path to resume')
29-
parser.add_argument('--seed', type=int, default=12345, help='Random seed')
30-
parser.add_argument('--wandb', action='store_true', default=False, help='Use Weights & Biases instead of TensorBoard')
31-
parser.add_argument('--wandb-project', type=str, default='graph-sr', help='Wandb project name')
32-
33-
# data
34-
parser.add_argument('--dataset', type=str, required=True, help='Name of the dataset')
35-
parser.add_argument('--data-dir', type=str, required=True, help='Root directory of the dataset')
36-
parser.add_argument('--num-workers', type=int, default=8, metavar='N', help='Number of dataloader worker processes')
37-
parser.add_argument('--batch-size', type=int, default=8)
38-
parser.add_argument('--crop-size', type=int, default=256, help='Size of the input (squared) patches')
39-
parser.add_argument('--scaling', type=int, default=8, help='Scaling factor')
40-
parser.add_argument('--max-rotation', type=float, default=15., help='Maximum rotation angle (degrees)')
41-
parser.add_argument('--no-flip', action='store_true', default=False, help='Switch off random flipping')
42-
parser.add_argument('--in-memory', action='store_true', default=False, help='Hold data in memory during training')
43-
44-
# training
45-
parser.add_argument('--loss', default='l1', type=str, choices=['l1', 'mse'])
46-
parser.add_argument('--num-epochs', type=int, default=250)
47-
parser.add_argument('--optimizer', default='adam', choices=['sgd', 'adam'])
48-
parser.add_argument('--lr', type=float, default=0.0001)
49-
parser.add_argument('--momentum', type=float, default=0.9)
50-
parser.add_argument('--w-decay', type=float, default=1e-5)
51-
parser.add_argument('--lr-scheduler', type=str, default='step', choices=['no', 'step', 'plateau'])
52-
parser.add_argument('--lr-step', type=int, default=10, help='LR scheduler step size (epochs)')
53-
parser.add_argument('--lr-gamma', type=float, default=0.9, help='LR decay rate')
54-
parser.add_argument('--skip-first', action='store_true', help='Don\'t optimize during first epoch')
55-
parser.add_argument('--gradient-clip', type=float, default=0., help='If > 0, clips gradient norm to that value')
56-
57-
# model
58-
parser.add_argument('--feature-extractor', type=str, default='UResNet', help='Feature extractor for edge potentials')
59-
parser.add_argument('--pretrained', action='store_true', help='Initialize feature extractor with weights '
60-
'pretrained on ImageNet')
61-
parser.add_argument('--lambda-init', type=float, default=1., help='Graph lambda parameter initialization')
62-
parser.add_argument('--mu-init', type=float, default=.1, help='Graph mu parameter initialization')
63-
6420

6521
class Trainer:
6622

@@ -76,9 +32,8 @@ def __init__(self, args: argparse.Namespace):
7632
args.scaling,
7733
args.crop_size,
7834
args.feature_extractor,
79-
args.pretrained,
80-
args.lambda_init,
81-
args.mu_init
35+
lambda_init=args.lambda_init,
36+
mu_init=args.mu_init
8237
)
8338
self.model.cuda()
8439

@@ -272,8 +227,8 @@ def resume(self, path):
272227

273228

274229
if __name__ == '__main__':
275-
args = parser.parse_args()
276-
print(parser.format_values())
230+
args = train_parser.parse_args()
231+
print(train_parser.format_values())
277232

278233
if args.wandb:
279234
import wandb

0 commit comments

Comments (0)