Commit 9f43bbc

fix bug for OFA (#464)
* fix bugs for ernie
1 parent c6fdcc3 commit 9f43bbc

6 files changed: +354 / -84 lines changed

demo/one_shot/ofa_train.py

Lines changed: 22 additions & 16 deletions
@@ -14,14 +14,14 @@
 
 import numpy as np
 import paddle
-import paddle.fluid as fluid
-import paddle.fluid.dygraph.nn as nn
+import paddle.nn as nn
+import paddle.nn.functional as F
 from paddle.nn import ReLU
 from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
 from paddleslim.nas.ofa import supernet
 
 
-class Model(fluid.dygraph.Layer):
+class Model(nn.Layer):
     def __init__(self):
         super(Model, self).__init__()
         with supernet(
@@ -50,18 +50,20 @@ def forward(self, inputs, label, depth=None):
 
         for idx, layer in enumerate(models):
             if idx == 6:
-                inputs = fluid.layers.flatten(inputs, 1)
+                inputs = paddle.flatten(inputs, 1)
             inputs = layer(inputs)
 
-        inputs = fluid.layers.softmax(inputs)
+        inputs = F.softmax(inputs)
         return inputs
 
 
 def test_ofa():
 
+    model = Model()
+    teacher_model = Model()
+
     default_run_config = {
         'train_batch_size': 256,
-        'eval_batch_size': 64,
         'n_epochs': [[1], [2, 3], [4, 5]],
         'init_learning_rate': [[0.001], [0.003, 0.001], [0.003, 0.001]],
         'dynamic_batch_size': [1, 1, 1],
@@ -72,42 +74,46 @@ def test_ofa():
 
     default_distill_config = {
         'lambda_distill': 0.01,
-        'teacher_model': Model,
+        'teacher_model': teacher_model,
         'mapping_layers': ['models.0.fn']
     }
     distill_config = DistillConfig(**default_distill_config)
 
-    fluid.enable_dygraph()
-    model = Model()
     ofa_model = OFA(model, run_config, distill_config=distill_config)
 
-    train_reader = paddle.fluid.io.batch(
-        paddle.dataset.mnist.train(), batch_size=256, drop_last=True)
+    train_dataset = paddle.vision.datasets.MNIST(
+        mode='train', backend='cv2', transform=transform)
+    train_loader = paddle.io.DataLoader(
+        train_dataset,
+        places=place,
+        feed_list=[image, label],
+        drop_last=True,
+        batch_size=64)
 
     start_epoch = 0
    for idx in range(len(run_config.n_epochs)):
         cur_idx = run_config.n_epochs[idx]
         for ph_idx in range(len(cur_idx)):
             cur_lr = run_config.init_learning_rate[idx][ph_idx]
-            adam = fluid.optimizer.Adam(
+            adam = paddle.optimizer.Adam(
                 learning_rate=cur_lr,
                 parameter_list=(ofa_model.parameters() + ofa_model.netAs_param))
             for epoch_id in range(start_epoch,
                                   run_config.n_epochs[idx][ph_idx]):
-                for batch_id, data in enumerate(train_reader()):
+                for batch_id, data in enumerate(train_loader()):
                     dy_x_data = np.array(
                         [x[0].reshape(1, 28, 28)
                          for x in data]).astype('float32')
                     y_data = np.array(
                         [x[1] for x in data]).astype('int64').reshape(-1, 1)
 
-                    img = fluid.dygraph.to_variable(dy_x_data)
-                    label = fluid.dygraph.to_variable(y_data)
+                    img = paddle.dygraph.to_variable(dy_x_data)
+                    label = paddle.dygraph.to_variable(y_data)
                     label.stop_gradient = True
 
                     for model_no in range(run_config.dynamic_batch_size[idx]):
                         output, _ = ofa_model(img, label)
-                        loss = fluid.layers.reduce_mean(output)
+                        loss = F.mean(output)
                         dis_loss = ofa_model.calc_distill_loss()
                         loss += dis_loss
                         loss.backward()
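
Taken together, the demo now sets up OFA entirely on the paddle 2.x dygraph API: the student and teacher are built up front, and DistillConfig receives the teacher as an instance rather than as the Model class. A condensed sketch of the resulting flow, using only names visible in the diff above (default_run_config is the dict shown there; helpers such as transform and place referenced by the DataLoader call are not visible in these hunks and are left out):

    # Sketch of the updated setup; all identifiers come from the diff above.
    model = Model()            # student supernet
    teacher_model = Model()    # teacher is passed as an instance, not as a class

    run_config = RunConfig(**default_run_config)
    distill_config = DistillConfig(
        lambda_distill=0.01,
        teacher_model=teacher_model,
        mapping_layers=['models.0.fn'])

    ofa_model = OFA(model, run_config, distill_config=distill_config)
    # training then follows the loop above: output, _ = ofa_model(img, label),
    # with ofa_model.calc_distill_loss() added to the task loss.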

paddleslim/nas/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
 from .rl_nas import *
 from ..nas import darts
 from .darts import *
+from .ofa import *
 
 __all__ = []
 __all__ += sa_nas.__all__
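
The new export presumably just surfaces the OFA symbols alongside the other NAS entry points when paddleslim.nas is imported; the demo keeps importing them from the subpackage directly, which works either way:

    from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig, supernet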

paddleslim/nas/ofa/convert_super.py

Lines changed: 81 additions & 7 deletions
@@ -16,17 +16,16 @@
 import decorator
 import logging
 import paddle
-import paddle.fluid as fluid
-from paddle.fluid import framework
-from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm
+import numbers
+from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm, LayerNorm, Embedding
 from .layers import *
 from ...common import get_logger
 
 _logger = get_logger(__name__, level=logging.INFO)
 
 __all__ = ['supernet']
 
-WEIGHT_LAYER = ['conv', 'linear']
+WEIGHT_LAYER = ['conv', 'linear', 'embedding']
 
 
 ### TODO: add decorator
@@ -45,7 +44,7 @@ def convert(self, model):
         cur_channel = None
         for idx, layer in enumerate(model):
             cls_name = layer.__class__.__name__.lower()
-            if 'conv' in cls_name or 'linear' in cls_name:
+            if 'conv' in cls_name or 'linear' in cls_name or 'embedding' in cls_name:
                 weight_layer_count += 1
                 last_weight_layer_idx = idx
                 if first_weight_layer_idx == -1:
@@ -63,7 +62,7 @@ def convert(self, model):
 
                 new_attr_name = [
                     '_stride', '_dilation', '_groups', '_param_attr',
-                    '_bias_attr', '_use_cudnn', '_act', '_dtype'
+                    '_bias_attr', '_use_cudnn', '_act', '_dtype', '_padding'
                 ]
 
                 new_attr_dict = dict()
@@ -179,6 +178,8 @@ def convert(self, model):
                             layer._parameters['weight'].shape[0])
                 elif self.context.channel:
                     new_attr_dict['num_channels'] = max(cur_channel)
+                else:
+                    new_attr_dict['num_channels'] = attr_dict['_num_channels']
 
                 for attr in new_attr_name:
                     new_attr_dict[attr[1:]] = attr_dict[attr]
@@ -196,7 +197,8 @@ def convert(self, model):
 
                 new_attr_name = [
                     '_stride', '_dilation', '_groups', '_param_attr',
-                    '_bias_attr', '_use_cudnn', '_act', '_dtype', '_output_size'
+                    '_padding', '_bias_attr', '_use_cudnn', '_act', '_dtype',
+                    '_output_size'
                 ]
                 assert attr_dict[
                     '_filter_size'] != None, "Conv2DTranspose only support filter size != None now"
@@ -371,6 +373,8 @@ def convert(self, model):
                             layer._parameters['scale'].shape[0])
                 elif self.context.channel:
                     new_attr_dict['num_channels'] = max(cur_channel)
+                else:
+                    new_attr_dict['num_channels'] = attr_dict['_num_channels']
 
                 for attr in new_attr_name:
                     new_attr_dict[attr[1:]] = attr_dict[attr]
@@ -380,6 +384,76 @@ def convert(self, model):
                 layer = SuperInstanceNorm(**new_attr_dict)
                 model[idx] = layer
 
+            elif isinstance(layer, LayerNorm) and (
+                    getattr(self.context, 'expand', None) != None or
+                    getattr(self.context, 'channel', None) != None):
+                ### TODO(ceci3): fix when normalized_shape != last_dim_of_input
+                if idx > last_weight_layer_idx:
+                    continue
+
+                attr_dict = layer.__dict__
+                new_attr_name = [
+                    '_scale', '_shift', '_param_attr', '_bias_attr', '_act',
+                    '_dtype', '_epsilon'
+                ]
+                new_attr_dict = dict()
+                if self.context.expand:
+                    new_attr_dict[
+                        'normalized_shape'] = self.context.expand * int(
+                            attr_dict['_normalized_shape'][0])
+                elif self.context.channel:
+                    new_attr_dict['normalized_shape'] = max(cur_channel)
+                else:
+                    new_attr_dict['normalized_shape'] = attr_dict[
+                        '_normalized_shape']
+
+                for attr in new_attr_name:
+                    new_attr_dict[attr[1:]] = attr_dict[attr]
+
+                del layer, attr_dict
+                layer = SuperLayerNorm(**new_attr_dict)
+                model[idx] = layer
+
+            elif isinstance(layer, Embedding) and (
+                    getattr(self.context, 'expand', None) != None or
+                    getattr(self.context, 'channel', None) != None):
+                attr_dict = layer.__dict__
+                key = attr_dict['_full_name']
+                new_attr_name = [
+                    '_is_sparse', '_is_distributed', '_padding_idx',
+                    '_param_attr', '_dtype'
+                ]
+
+                new_attr_dict = dict()
+                new_attr_dict['candidate_config'] = dict()
+                bef_size = attr_dict['_size']
+                if self.context.expand:
+                    new_attr_dict['size'] = [
+                        bef_size[0], self.context.expand * bef_size[1]
+                    ]
+                    new_attr_dict['candidate_config'].update({
+                        'expand_ratio': self.context.expand_ratio
+                    })
+
+                elif self.context.channel:
+                    cur_channel = self.context.channel[0]
+                    self.context.channel = self.context.channel[1:]
+                    new_attr_dict['size'] = [bef_size[0], max(cur_channel)]
+                    new_attr_dict['candidate_config'].update({
+                        'channel': cur_channel
+                    })
+                    pre_channel = cur_channel
+                else:
+                    new_attr_dict['size'] = bef_size
+
+                for attr in new_attr_name:
+                    new_attr_dict[attr[1:]] = attr_dict[attr]
+
+                del layer, attr_dict
+
+                layer = Block(SuperEmbedding(**new_attr_dict), key=key)
+                model[idx] = layer
+
         return model
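
These LayerNorm and Embedding branches are what make ERNIE-style (transformer) models convertible: inside a supernet scope, a LayerNorm is rebuilt as a SuperLayerNorm whose normalized_shape tracks the expanded or searched width, and an Embedding is rebuilt as Block(SuperEmbedding(...)) with an expand_ratio or channel entry in its candidate_config. A minimal sketch of the intended usage, following the same "with supernet(...) as ofa_super" / ofa_super.convert(...) pattern as the demo above; the layer sizes and expand_ratio values are illustrative assumptions, not taken from this commit:

    # Hedged sketch: converting an Embedding / LayerNorm / Linear stack.
    # The converter matches on the paddle.fluid dygraph classes it imports above.
    import paddle.fluid.dygraph.nn as nn
    from paddleslim.nas.ofa import supernet

    with supernet(expand_ratio=[1, 2]) as ofa_super:
        models = []
        models += [nn.Embedding(size=[1000, 64])]  # becomes Block(SuperEmbedding(...))
        models += [nn.LayerNorm(64)]               # becomes SuperLayerNorm(...)
        models += [nn.Linear(64, 10)]              # 'linear' counts as a weight layer
        models = ofa_super.convert(models)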