-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathexperiments.py
More file actions
98 lines (77 loc) · 2.61 KB
/
experiments.py
File metadata and controls
98 lines (77 loc) · 2.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import numpy as np
from itertools import product
def combinations(grid):
    """Expand a hyperparameter grid into its full Cartesian product.

    Args:
        grid: Mapping from hyperparameter name to an iterable of candidate
            values.

    Returns:
        A list with one dict per combination. Every dict has the same keys
        as *grid*, in the same order. Combinations are ordered as
        ``itertools.product`` yields them, so the last key varies fastest.
        An empty *grid* yields ``[{}]``. Any empty value iterable yields
        ``[]``.
    """
    keys = list(grid.keys())
    return [dict(zip(keys, values)) for values in product(*grid.values())]
def _get_experiment(experiment, attr):
    """Return the module-level experiment object named *experiment*.

    Args:
        experiment: Name of an experiment class defined in this module.
        attr: Attribute the caller is about to access (``'hparams'`` or
            ``'fname'``). A global without it is not treated as an experiment.

    Raises:
        NotImplementedError: If no module global named *experiment* exists,
            or that global has no attribute *attr*.
    """
    obj = globals().get(experiment)
    # Module globals also contain imports and helpers (np, product,
    # combinations, these lookup functions, ...). Reject them here instead of
    # failing later with an AttributeError.
    if not hasattr(obj, attr):
        raise NotImplementedError('Unknown experiment: {!r}'.format(experiment))
    return obj


def get_hparams(experiment):
    """Return the list of hyperparameter dicts for the named experiment.

    Raises:
        NotImplementedError: If *experiment* does not name an experiment
            class with an ``hparams`` attribute.
    """
    return _get_experiment(experiment, 'hparams').hparams()


def get_script_name(experiment):
    """Return the ``fname`` (script file name) of the named experiment.

    Raises:
        NotImplementedError: If *experiment* does not name an experiment
            class with an ``fname`` attribute.
    """
    return _get_experiment(experiment, 'fname').fname
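#### Experiment definitions: add new experiments here. Each class provides
#### `fname` (the script name returned by get_script_name) and a static
#### `hparams()` (the list of hyperparameter dicts returned by get_hparams).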
class Camelyon:
    """Sweep definition for the ``'camelyon'`` data, run through ``train.py``."""

    fname = 'train.py'

    @staticmethod
    def hparams():
        """Return every combination of the Camelyon sweep as a list of dicts.

        4 types x 4 data types x 4 correlation values x 5 seeds = 320 dicts.
        ``'data'`` and ``'domains'`` have a single value each.
        """
        sweep = {}
        sweep['type'] = ['back', 'front', 'back_front', 'label_flip']
        sweep['data'] = ['camelyon']
        sweep['data_type'] = ['Conf', 'Deconf', 'DA', 'IF']
        # One candidate value: the whole (2, 3) tuple is a single setting.
        sweep['domains'] = [(2, 3)]
        # Four evenly spaced values from 0.65 to 0.95 inclusive (np.float64).
        sweep['corr-coff'] = list(np.linspace(0.65, 0.95, 4))
        sweep['seed'] = list(range(5))
        return combinations(sweep)
class CXR:
    """Sweep definition for the ``'CXR'`` data, run through ``train.py``."""

    fname = 'train.py'

    @staticmethod
    def hparams():
        """Return every combination of the CXR sweep as a list of dicts.

        4 types x 4 data types x 4 correlation values x 5 seeds = 320 dicts.
        The other keys hold a single value each.
        """
        sweep = {}
        sweep['type'] = ['back', 'front', 'back_front', 'label_flip']
        sweep['data'] = ['CXR']
        sweep['data_type'] = ['Conf', 'Deconf', 'DA', 'IF']
        # Four evenly spaced values from 0.65 to 0.95 inclusive (np.float64).
        sweep['corr-coff'] = list(np.linspace(0.65, 0.95, 4))
        sweep['seed'] = list(range(5))
        # Single-valued entries: the same value appears in every generated dict.
        sweep['samples'] = [6500]
        sweep['use_pretrained'] = [True]
        sweep['cache_cxr'] = [True]
        return combinations(sweep)
class Poverty:
    """Sweep definition for the ``'poverty'`` data, run through ``train.py``."""

    fname = 'train.py'

    @staticmethod
    def hparams():
        """Return every combination of the Poverty sweep as a list of dicts.

        4 types x 4 data types x 4 correlation values x 5 seeds = 320 dicts.
        ``'data'``, ``'samples'`` and ``'domains'`` have a single value each.
        """
        sweep = {}
        sweep['type'] = ['back', 'front', 'back_front', 'label_flip']
        sweep['data'] = ['poverty']
        sweep['data_type'] = ['Conf', 'Deconf', 'DA', 'IF']
        # Four evenly spaced values from 0.65 to 0.95 inclusive (np.float64).
        sweep['corr-coff'] = list(np.linspace(0.65, 0.95, 4))
        sweep['seed'] = list(range(5))
        sweep['samples'] = [300]
        # One candidate value: the whole 4-tuple of names is a single setting.
        sweep['domains'] = [('malawi', 'kenya', 'tanzania', 'nigeria')]
        return combinations(sweep)
class Synthetic:
    """Sweep definition for the synthetic runs, using ``train_synthetic.py``."""

    fname = 'train_synthetic.py'

    @staticmethod
    def hparams():
        """Return the synthetic sweep as a list of hyperparameter dicts.

        ``'corr-coff'`` and ``'test-corr'`` are paired, not crossed: each
        dict uses the same value for both. The list is ordered by correlation
        value first, then by ``'type'``, giving 4 x 5 = 20 dicts. Each dict
        has its keys in the order ``'type'``, ``'corr-coff'``, ``'test-corr'``.
        """
        sweep_types = ['back', 'front', 'back_front', 'par_back_front', 'label_flip']
        # Iterate the array directly. The elements are np.float64, the same
        # values list(np.linspace(...)) would have produced.
        return [
            {'type': sweep_type, 'corr-coff': corr, 'test-corr': corr}
            for corr in np.linspace(0.65, 0.95, 4)
            for sweep_type in sweep_types
        ]
class EnvClf:
    """Sweep definition for runs launched through ``train_env.py``."""

    fname = 'train_env.py'

    @staticmethod
    def hparams():
        """Return every (dataset, seed) combination as a list of dicts.

        6 datasets x 5 seeds = 30 dicts. ``'use_pretrained'`` is always True.
        """
        datasets = ['CXR', 'camelyon', 'poverty', 'NIH', 'MNIST', 'CelebA']
        seeds = list(range(5))
        return combinations({
            'data': datasets,
            'seed': seeds,
            'use_pretrained': [True],
        })