
Commit feaef4f

Merge pull request #15 from automl/release_0.0.2 (Release 0.0.2)
2 parents 6996050 + 8cd4e0e, commit feaef4f

File tree

217 files changed: +9343 / -175 lines


.gitignore (+15 -7)

@@ -1,13 +1,13 @@
-
 # Visual Studio
 *.vs/*
 
 # Visual Studio Code
 *.vscode/*
 
 # Python
-*__pycache__/
+*__pycache__*
 *.pyc
+.ipynb_checkpoints*
 
 # Zipped
 *.tar.gz
@@ -24,19 +24,27 @@ results.json
 outputs/
 jobs.txt
 .pylintrc
+*worker_logs*
 
 # Build
 *build/
-*autonet.egg-info
+*autoPyTorch.egg-info
 *.simg
-
-
-# Datasets
-/datasets/
+.DS_Store
+dist/
 
 # Meta GPU
 *meta_logs/
+runs.log
+runs.log.lock
+logs/
 
 # ensemble data
 predictions_for_ensemble.npy
 test_predictions_for_ensemble.npy
+
+# testing
+tests.ipynb
+
+# venv
+env/

README.md (+4 -2)

@@ -3,7 +3,7 @@
 Copyright (C) 2019 [AutoML Group Freiburg](http://www.automl.org/)
 
 This a very early pre-alpha version of our upcoming Auto-PyTorch.
-So far, Auto-PyTorch only supports featurized data.
+So far, Auto-PyTorch supports featurized data (classification, regression) and image data (classification).
 
 ## Installation
 
@@ -33,6 +33,8 @@ $ python setup.py install
 
 ## Examples
 
+For a detailed tutorial, please refer to the jupyter notebook in https://github.com/automl/Auto-PyTorch/tree/master/examples/basics.
+
 In a nutshell:
 
 ```py
@@ -112,7 +114,7 @@ search_space_updates.append(node_name="NetworkSelector",
 autoPyTorch = AutoNetClassification(hyperparameter_search_space_updates=search_space_updates)
 ```
 
-Enable ensemble building:
+Enable ensemble building (for featurized data):
 
 ```py
 from autoPyTorch import AutoNetEnsemble
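As a rough illustration of the "in a nutshell" flow the README points to, here is a hedged sketch of fitting a classifier and wrapping it in an ensemble. Only AutoNetClassification, AutoNetEnsemble, and the hyperparameter_search_space_updates argument appear in this commit; the keyword arguments max_runtime, min_budget, and max_budget and the scikit-learn-style fit/predict calls are assumptions about the 0.0.2 API, so consult examples/basics for the authoritative tutorial.

```py
# Hedged sketch, not confirmed by this diff beyond the imported names.
import numpy as np
from autoPyTorch import AutoNetClassification, AutoNetEnsemble

X = np.random.rand(200, 10)          # toy featurized data
y = np.random.randint(0, 2, 200)     # binary labels

# Plain classifier; the budget keywords below are assumed, not shown in this commit.
autonet = AutoNetClassification(max_runtime=300, min_budget=30, max_budget=90)
autonet.fit(X, y)
predictions = autonet.predict(X)

# Ensemble building (featurized data only, per the README change above).
# Passing the AutoNet class to AutoNetEnsemble is an assumption about its constructor.
ensemble = AutoNetEnsemble(AutoNetClassification, max_runtime=300, min_budget=30, max_budget=90)
```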

autoPyTorch/__init__.py (+1 -1)

@@ -2,7 +2,7 @@
 hpbandster = os.path.abspath(os.path.join(__file__, '..', '..', 'submodules', 'HpBandSter'))
 sys.path.append(hpbandster)
 
-from autoPyTorch.core.autonet_classes import AutoNetClassification, AutoNetMultilabel, AutoNetRegression
+from autoPyTorch.core.autonet_classes import AutoNetClassification, AutoNetMultilabel, AutoNetRegression, AutoNetImageClassification, AutoNetImageClassificationMultipleDatasets
 from autoPyTorch.data_management.data_manager import DataManager
 from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 from autoPyTorch.core.ensemble import AutoNetEnsemble
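With this change the two image classes become importable from the top-level package, just like the existing featurized classes. A minimal sketch follows; only the exports are confirmed by the diff, and constructing the classes with no arguments (library defaults) is an assumption.

```py
# Hedged sketch: the new image autonet classes are now exposed at package level.
from autoPyTorch import AutoNetImageClassification, AutoNetImageClassificationMultipleDatasets

# Assumed: a no-argument construction falls back to library defaults.
autonet_image = AutoNetImageClassification()
autonet_multi = AutoNetImageClassificationMultipleDatasets()
print(type(autonet_image).__name__, type(autonet_multi).__name__)
```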

autoPyTorch/components/lr_scheduler/lr_schedulers.py (+190 -4)

@@ -6,8 +6,11 @@
 
 from autoPyTorch.utils.config_space_hyperparameter import add_hyperparameter, get_hyperparameter
 
+import numpy as np
+import math
 import torch
 import torch.optim.lr_scheduler as lr_scheduler
+from torch.optim import Optimizer
 
 import ConfigSpace as CS
 import ConfigSpace.hyperparameters as CSH
@@ -16,6 +19,7 @@
 __version__ = "0.0.1"
 __license__ = "BSD"
 
+
 class AutoNetLearningRateSchedulerBase(object):
     def __new__(cls, optimizer, config):
         """Get a new instance of the scheduler
@@ -42,12 +46,17 @@ def _get_scheduler(self, optimizer, config):
     def get_config_space():
         return CS.ConfigurationSpace()
 
+
 class SchedulerNone(AutoNetLearningRateSchedulerBase):
 
     def _get_scheduler(self, optimizer, config):
         return NoScheduling(optimizer=optimizer)
 
+
 class SchedulerStepLR(AutoNetLearningRateSchedulerBase):
+    """
+    Step learning rate scheduler
+    """
 
     def _get_scheduler(self, optimizer, config):
         return lr_scheduler.StepLR(optimizer=optimizer, step_size=config['step_size'], gamma=config['gamma'], last_epoch=-1)
@@ -62,8 +71,12 @@ def get_config_space(
         add_hyperparameter(cs, CSH.UniformFloatHyperparameter, 'gamma', gamma)
         return cs
 
+
 class SchedulerExponentialLR(AutoNetLearningRateSchedulerBase):
-
+    """
+    Exponential learning rate scheduler
+    """
+
     def _get_scheduler(self, optimizer, config):
         return lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=config['gamma'], last_epoch=-1)
 
@@ -75,11 +88,17 @@ def get_config_space(
         add_hyperparameter(cs, CSH.UniformFloatHyperparameter, 'gamma', gamma)
         return cs
 
+
 class SchedulerReduceLROnPlateau(AutoNetLearningRateSchedulerBase):
+    """
+    Reduce LR on plateau learning rate scheduler
+    """
 
     def _get_scheduler(self, optimizer, config):
-        return lr_scheduler.ReduceLROnPlateau(optimizer=optimizer)
-
+        return lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
+                                              factor=config['factor'],
+                                              patience=config['patience'])
+
     @staticmethod
     def get_config_space(
         factor=(0.05, 0.5),
@@ -90,7 +109,112 @@ def get_config_space(
         add_hyperparameter(cs, CSH.UniformIntegerHyperparameter, 'patience', patience)
         return cs
 
+
+class SchedulerAdaptiveLR(AutoNetLearningRateSchedulerBase):
+    """
+    Adaptive cosine learning rate scheduler
+    """
+
+    def _get_scheduler(self, optimizer, config):
+        return AdaptiveLR(optimizer=optimizer,
+                          T_max=config['T_max'],
+                          T_mul=config['T_mult'],
+                          patience=config['patience'],
+                          threshold=config['threshold'])
+
+    @staticmethod
+    def get_config_space(
+        T_max=(300,1000),
+        patience=(2,5),
+        T_mult=(1.0,2.0),
+        threshold=(0.001, 0.5)
+    ):
+        cs = CS.ConfigurationSpace()
+        add_hyperparameter(cs, CSH.UniformIntegerHyperparameter, 'T_max', T_max)
+        add_hyperparameter(cs, CSH.UniformIntegerHyperparameter, 'patience', patience)
+        add_hyperparameter(cs, CSH.UniformFloatHyperparameter, 'T_mult', T_mult)
+        add_hyperparameter(cs, CSH.UniformFloatHyperparameter, 'threshold', threshold)
+        return cs
+
+
+class AdaptiveLR(object):
+
+    def __init__(self, optimizer, mode='min', T_max=30, T_mul=2.0, eta_min=0, patience=3, threshold=0.1, min_lr=0, eps=1e-8, last_epoch=-1):
+
+        if not isinstance(optimizer, Optimizer):
+            raise TypeError('{} is not an Optimizer'.format(
+                type(optimizer).__name__))
+
+        self.optimizer = optimizer
+
+        if last_epoch == -1:
+            for group in optimizer.param_groups:
+                group.setdefault('initial_lr', group['lr'])
+        else:
+            for i, group in enumerate(optimizer.param_groups):
+                if 'initial_lr' not in group:
+                    raise KeyError("param 'initial_lr' is not specified "
+                                   "in param_groups[{}] when resuming an optimizer".format(i))
+
+        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
+        self.last_epoch = last_epoch
+
+        if isinstance(min_lr, list) or isinstance(min_lr, tuple):
+            if len(min_lr) != len(optimizer.param_groups):
+                raise ValueError("expected {} min_lrs, got {}".format(
+                    len(optimizer.param_groups), len(min_lr)))
+            self.min_lrs = list(min_lr)
+        else:
+            self.min_lrs = [min_lr] * len(optimizer.param_groups)
+
+        self.T_max = T_max
+        self.T_mul = T_mul
+        self.eta_min = eta_min
+        self.current_base_lrs = self.base_lrs
+        self.metric_values = []
+        self.threshold = threshold
+        self.patience = patience
+        self.steps = 0
+
+    def step(self, metrics, epoch=None):
+        if epoch is None:
+            epoch = self.last_epoch + 1
+        self.last_epoch = epoch
+
+        self.metric_values.append(metrics)
+        if len(self.metric_values) > self.patience:
+            self.metric_values = self.metric_values[1:]
+
+        if max(self.metric_values) - metrics > self.threshold:
+            self.current_base_lrs = self.get_lr()
+            self.steps = 0
+        else:
+            self.steps += 1
+
+        self.last_metric_value = metrics
+
+        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+            param_group['lr'] = lr
+
+    def get_lr(self):
+        '''
+        Override this method to the existing get_lr() of the parent class
+        '''
+        if self.steps >= self.T_max:
+            self.T_max = self.T_max * self.T_mul
+            self.current_base_lrs = self.base_lrs
+            self.metric_values = []
+            self.steps = 0
+
+        return [self.eta_min + (base_lr - self.eta_min) *
+                (1 + math.cos(math.pi * self.steps / self.T_max)) / 2
+                for base_lr in self.current_base_lrs]
+
+
 class SchedulerCyclicLR(AutoNetLearningRateSchedulerBase):
+    """
+    Cyclic learning rate scheduler
+    """
 
     def _get_scheduler(self, optimizer, config):
         maf = config['max_factor']
@@ -118,7 +242,11 @@ def get_config_space(
         add_hyperparameter(cs, CSH.UniformIntegerHyperparameter, 'cycle_length', cycle_length)
         return cs
 
+
 class SchedulerCosineAnnealingWithRestartsLR(AutoNetLearningRateSchedulerBase):
+    """
+    Cosine annealing learning rate scheduler with warm restarts
+    """
 
     def _get_scheduler(self, optimizer, config):
         scheduler = CosineAnnealingWithRestartsLR(optimizer, T_max=config['T_max'], T_mult=config['T_mult'], last_epoch=-1)
@@ -151,7 +279,6 @@ def get_lr(self):
         return [None]
 
 
-import math
 class CosineAnnealingWithRestartsLR(torch.optim.lr_scheduler._LRScheduler):
 
     r"""Copyright: pytorch
@@ -205,3 +332,62 @@ def get_lr(self):
         if self.step_n >= self.restart_every:
             self.restart()
         return [self.cosine(base_lr) for base_lr in self.base_lrs]
+
+    def needs_checkpoint(self):
+        return self.step_n + 1 >= self.restart_every
+
+
+class SchedulerAlternatingCosineLR(AutoNetLearningRateSchedulerBase):
+    """
+    Alternating cosine learning rate scheduler
+    """
+
+    def _get_scheduler(self, optimizer, config):
+        scheduler = AlternatingCosineLR(optimizer, T_max=config['T_max'], T_mul=config['T_mult'], amplitude_reduction=config['amp_reduction'], last_epoch=-1)
+        return scheduler
+
+    @staticmethod
+    def get_config_space(
+        T_max=(1, 20),
+        T_mult=(1.0, 2.0),
+        amp_reduction=(0.1,1)
+    ):
+        cs = CS.ConfigurationSpace()
+        add_hyperparameter(cs, CSH.UniformIntegerHyperparameter, 'T_max', T_max)
+        add_hyperparameter(cs, CSH.UniformFloatHyperparameter, 'T_mult', T_mult)
+        add_hyperparameter(cs, CSH.UniformFloatHyperparameter, 'amp_reduction', amp_reduction)
+        return cs
+
+
+class AlternatingCosineLR(torch.optim.lr_scheduler._LRScheduler):
+    def __init__(self, optimizer, T_max, T_mul=1, amplitude_reduction=0.9, eta_min=0, last_epoch=-1):
+        '''
+        Here last_epoch actually means last_step since the
+        learning rate is decayed after each batch step.
+        '''
+
+        self.T_max = T_max
+        self.T_mul = T_mul
+        self.eta_min = eta_min
+        self.cumulative_time = 0
+        self.amplitude_mult = amplitude_reduction
+        self.base_lr_mult = 1
+        self.frequency_mult = 1
+        self.time_offset = 0
+        self.last_step = 0
+        super(AlternatingCosineLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        '''
+        Override this method to the existing get_lr() of the parent class
+        '''
+        if self.last_epoch >= self.T_max:
+            self.T_max = self.T_max * self.T_mul
+            self.time_offset = self.T_max / 2
+            self.last_epoch = 0
+            self.base_lr_mult *= self.amplitude_mult
+            self.frequency_mult = 2
+            self.cumulative_time = 0
+        return [self.eta_min + (base_lr * self.base_lr_mult - self.eta_min) *
+                (1 + math.cos(math.pi * (self.time_offset + self.cumulative_time) / self.T_max * self.frequency_mult)) / 2
+                for base_lr in self.base_lrs]
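To make the newly added AdaptiveLR concrete, here is a hedged sketch of driving it from a training loop. The toy model, optimizer, and placeholder validation metric are not from the diff; the class, its constructor keywords, and the step(metrics) signature are taken from the code above, and calling step once per epoch with a validation value is an assumption about intended use.

```py
# Hedged sketch of using the new AdaptiveLR scheduler; placeholders marked below.
import torch
from autoPyTorch.components.lr_scheduler.lr_schedulers import AdaptiveLR

model = torch.nn.Linear(10, 2)                                   # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# AdaptiveLR follows a cosine curve over T_max steps. A metric improvement of more
# than `threshold` below the recent maximum (window of `patience` values) restarts
# the decay from the current LR; exhausting T_max steps restarts from the base LR
# with T_max multiplied by T_mul.
scheduler = AdaptiveLR(optimizer, T_max=30, T_mul=2.0, patience=3, threshold=0.1)

for epoch in range(5):
    fake_validation_loss = 1.0 / (epoch + 1)   # placeholder metric
    scheduler.step(fake_validation_loss)       # writes optimizer.param_groups[*]['lr']
    print(epoch, optimizer.param_groups[0]['lr'])
```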
+1 -1

@@ -1,3 +1,3 @@
 from autoPyTorch.components.metrics.balanced_accuracy import balanced_accuracy
 from autoPyTorch.components.metrics.pac_score import pac_metric
-from autoPyTorch.components.metrics.standard_metrics import accuracy, auc_metric, mean_distance, multilabel_accuracy
+from autoPyTorch.components.metrics.standard_metrics import accuracy, auc_metric, mean_distance, multilabel_accuracy, cross_entropy, top1, top3, top5
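The newly exported top1, top3, and top5 metrics are defined in standard_metrics.py, which is not shown in this excerpt. As an illustration of what a top-k accuracy metric typically computes, here is a hedged, self-contained sketch; only the metric names come from the diff, and the real implementations and signatures may differ.

```py
# Hedged sketch of a top-k accuracy metric, not the actual implementation in
# autoPyTorch.components.metrics.standard_metrics.
import numpy as np

def top_k_accuracy(y_true, y_pred_scores, k=1):
    """Fraction of samples whose true label is among the k highest-scored classes."""
    top_k = np.argsort(y_pred_scores, axis=1)[:, -k:]   # indices of the k best classes
    hits = [label in row for label, row in zip(y_true, top_k)]
    return float(np.mean(hits))

# Example: 3 samples, 4 classes.
scores = np.array([[0.1, 0.5, 0.3, 0.1],
                   [0.7, 0.1, 0.1, 0.1],
                   [0.2, 0.2, 0.5, 0.1]])
labels = np.array([1, 2, 2])
print(top_k_accuracy(labels, scores, k=1))  # 0.666...
print(top_k_accuracy(labels, scores, k=3))  # 1.0
```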
