import torch

from fullbatch.models.resnets import resnet_depths_to_config, ResNet

# Optional list of dependencies required by the package
dependencies = ["torch"]
names = ["highreg"]
url = "https://github.com/JonasGeiping/fullbatchtraining/releases/download/v1/"
model_urls = {
    "final_fbaug_highreg_lr08_resnet18": url + "final_fbaug_highreg_lr08_resnet18" + ".pth",
    "final_fbaug_gradreg_lr08_resnet18": url + "final_fbaug_gradreg_lr08_resnet18" + ".pth",
    "final_fbaug_gradreg_lr16_resnet18": url + "final_fbaug_gradreg_lr16_resnet18" + ".pth",
    "final_fbaug_clip_lr04_resnet18": url + "final_fbaug_clip_lr04_resnet18" + ".pth",
    "final_fbaug_highreg_lr08_shuffle_resnet152": url + "final_fbaug_highreg_lr08_shuffle_resnet152" + ".pth",
}
def _resnet18(name, pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 with the default configuration used in this repo."""
    # Architecture:
    block, layers = resnet_depths_to_config(18)
    model = ResNet(
        block,
        layers,
        channels=3,
        classes=10,
        stem="CIFAR",
        convolution_type="Standard",
        nonlin="ReLU",
        norm="BatchNorm2d",
        downsample="C",
        width_per_group=64,
        zero_init_residual="skip-residual",
    )
    if pretrained:
        # The released checkpoints are serialized as tuples; the state dict is the second entry.
        _, state_dict, _, _, _ = torch.hub.load_state_dict_from_url(
            model_urls[name], progress=progress, map_location=torch.device("cpu")
        )
        model.load_state_dict(state_dict)
    return model
def _resnet152(name, pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 with the default configuration used in this repo."""
    # Architecture:
    block, layers = resnet_depths_to_config(152)
    model = ResNet(
        block,
        layers,
        channels=3,
        classes=10,
        stem="CIFAR",
        convolution_type="Standard",
        nonlin="ReLU",
        norm="BatchNorm2d",
        downsample="C",
        width_per_group=64,
        zero_init_residual="skip-residual",
    )
    if pretrained:
        # The released checkpoints are serialized as tuples; the state dict is the second entry.
        _, state_dict, _, _, _ = torch.hub.load_state_dict_from_url(
            model_urls[name], progress=progress, map_location=torch.device("cpu")
        )
        model.load_state_dict(state_dict)
    return model
def resnet18_fbaug_clip(pretrained=False, progress=True, **kwargs):
    r"""Loads a ResNet-18 model pretrained with full-batch gradient descent using the "clip" hyperparameters,
    trained with data augmentations but without data shuffling, as described in section 3."""
    return _resnet18("final_fbaug_clip_lr04_resnet18", pretrained, progress, **kwargs)


def resnet18_fbaug_gradreg(pretrained=False, progress=True, **kwargs):
    r"""Loads a ResNet-18 model pretrained with full-batch gradient descent using the "gradreg" hyperparameters,
    trained with data augmentations but without data shuffling, as described in section 3."""
    return _resnet18("final_fbaug_gradreg_lr08_resnet18", pretrained, progress, **kwargs)


def resnet18_fbaug_gradreg_v2(pretrained=False, progress=True, **kwargs):
    r"""Loads a ResNet-18 model pretrained with full-batch gradient descent using the "gradreg" hyperparameters,
    but with a learning rate doubled compared to the arXiv version,
    trained with data augmentations but without data shuffling, as described in section 3."""
    return _resnet18("final_fbaug_gradreg_lr16_resnet18", pretrained, progress, **kwargs)


def resnet18_fbaug_highreg(pretrained=False, progress=True, **kwargs):
    r"""Loads a ResNet-18 model pretrained with full-batch gradient descent using the "highreg" hyperparameters,
    trained with data augmentations but without data shuffling, as described in section 3."""
    return _resnet18("final_fbaug_highreg_lr08_resnet18", pretrained, progress, **kwargs)


def resnet152_fbaug_highreg(pretrained=False, progress=True, **kwargs):
    r"""Loads a ResNet-152 model pretrained with full-batch gradient descent using the "highreg" hyperparameters,
    trained with data augmentations and data shuffling, as described in section 3."""
    return _resnet152("final_fbaug_highreg_lr08_shuffle_resnet152", pretrained, progress, **kwargs)
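

# A minimal usage sketch, not part of the hub interface itself. It assumes this file sits on
# the repo's default branch and that torch.hub can reach the GitHub release assets; the entry
# point name below matches the functions defined above. Guarded so it never runs on hub import.
if __name__ == "__main__":
    # Fetch the model definition and pretrained weights through the hub entry point.
    model = torch.hub.load("JonasGeiping/fullbatchtraining", "resnet18_fbaug_highreg", pretrained=True)
    model.eval()  # inference mode for the CIFAR-10 classifier
    dummy = torch.randn(1, 3, 32, 32)  # CIFAR-shaped input batch
    print(model(dummy).shape)  # expected: torch.Size([1, 10])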