# gbdtmo_catboost.py
# (Web-page residue from the repository file viewer was removed from the top of this file.)
import argparse
from joblib import Parallel, delayed
from multiprocessing import Queue
from utils import run_cv_loop
# Command-line interface. The job count is mandatory: the entry point uses it
# both to size the GPU-token queue and as joblib's ``n_jobs``, so a missing
# value (None) would otherwise fail later with an opaque TypeError.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-n', '--njobs', type=int, required=True,
    help='number of parallel jobs; one GPU token is created per job',
)

# Run name forwarded to run_cv_loop for every dataset.
NAME = 'GBDTMO_BESTK'

# NOTE(review): the constants below are not referenced anywhere else in the
# visible part of this file; they are kept because other code may read them.
RUNNER_PATH = 'runner.py'
DEBUG = True
TIMEOUT = 3600 * 24       # presumably seconds (24 h) - TODO confirm
HARD_TIMEOUT = 3600 * 36  # presumably seconds (36 h) - TODO confirm
NTHREADS = 8
SEED = 42

# benchmark_params
# Both paths are forwarded unchanged to run_cv_loop.
data_path = 'data/processed'
benchmark_path = 'runs'
def get_baseline(name, benchmark_path, data_path, dataset, task, params, rewrite=False):
    """Run one cross-validation benchmark while holding a GPU token.

    A token is taken from the module-level queue ``q`` (blocking until one is
    available), passed to ``run_cv_loop`` as its ``gpu`` argument, and put
    back on the queue afterwards.

    Args:
        name: Run name forwarded to ``run_cv_loop``.
        benchmark_path: Forwarded to ``run_cv_loop``.
        data_path: Forwarded to ``run_cv_loop``.
        dataset: Dataset identifier forwarded to ``run_cv_loop``.
        task: Task identifier forwarded to ``run_cv_loop``.
        params: Parameter dict forwarded to ``run_cv_loop``.
        rewrite: Forwarded to ``run_cv_loop`` as the ``rewrite`` keyword.

    Raises:
        Whatever ``run_cv_loop`` raises; the token is still returned to ``q``.
    """
    gpu = q.get()
    try:
        run_cv_loop(name, gpu, benchmark_path, data_path, dataset, task, 'default', params, rewrite=rewrite)
    finally:
        # Always hand the token back: if a failed run kept it, the remaining
        # jobs could block forever in q.get().
        q.put(gpu)
def _gbdtmo_params(max_bin, max_depth, min_data_in_leaf):
    """Return a fresh parameter dict; only the three arguments vary per dataset."""
    return dict(
        lr=0.1,
        max_bin=max_bin,
        max_depth=max_depth,
        min_data_in_leaf=min_data_in_leaf,
        lambda_l2=1,
        es=25,
        ntrees=8000,
        cpu=True,
        subsample=1,
        acc=True,
    )


# One entry per dataset. The order matters: the entry point zips this list
# with ['caltech', 'nuswide', 'mnist', 'mnistreg'].
params = [
    _gbdtmo_params(max_bin=32, max_depth=10, min_data_in_leaf=16),  # caltech
    _gbdtmo_params(max_bin=64, max_depth=8, min_data_in_leaf=4),    # nuswide
    _gbdtmo_params(max_bin=8, max_depth=8, min_data_in_leaf=16),    # mnist
    _gbdtmo_params(max_bin=16, max_depth=7, min_data_in_leaf=4),    # mnist reg
]
if __name__ == '__main__':
    args = parser.parse_args()

    # One GPU token is created per job. With a missing or non-positive job
    # count the queue would start empty and every worker would block forever
    # in q.get(), so reject such values up front with a clear CLI error.
    if args.njobs is None or args.njobs < 1:
        parser.error('-n/--njobs must be a positive integer')

    # Module-level queue read by get_baseline. Each token is a one-element
    # list [i]; presumably a list of GPU device ids - TODO confirm against
    # run_cv_loop.
    q = Queue(maxsize=args.njobs)
    for i in range(args.njobs):
        q.put([i])

    # Dataset names, in the same order as the entries of ``params``.
    datasets = ['caltech', 'nuswide', 'mnist', 'mnistreg']

    # get runs
    # Threads (not processes) are used, so all workers share the queue ``q``.
    Parallel(n_jobs=args.njobs, backend="threading")(
        delayed(get_baseline)(NAME, benchmark_path, data_path, d, 'cb', p, rewrite=True)
        for (d, p) in zip(datasets, params)
    )