@@ -16,17 +16,16 @@
 import decorator
 import logging
 import paddle
-import paddle.fluid as fluid
-from paddle.fluid import framework
-from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm
+import numbers
+from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm, LayerNorm, Embedding
 from .layers import *
 from ...common import get_logger

 _logger = get_logger(__name__, level=logging.INFO)

 __all__ = ['supernet']

-WEIGHT_LAYER = ['conv', 'linear']
+WEIGHT_LAYER = ['conv', 'linear', 'embedding']


 ### TODO: add decorator
@@ -45,7 +44,7 @@ def convert(self, model):
         cur_channel = None
         for idx, layer in enumerate(model):
             cls_name = layer.__class__.__name__.lower()
-            if 'conv' in cls_name or 'linear' in cls_name:
+            if 'conv' in cls_name or 'linear' in cls_name or 'embedding' in cls_name:
                 weight_layer_count += 1
                 last_weight_layer_idx = idx
                 if first_weight_layer_idx == -1:
@@ -63,7 +62,7 @@ def convert(self, model):

                 new_attr_name = [
                     '_stride', '_dilation', '_groups', '_param_attr',
-                    '_bias_attr', '_use_cudnn', '_act', '_dtype'
+                    '_bias_attr', '_use_cudnn', '_act', '_dtype', '_padding'
                 ]

                 new_attr_dict = dict()
@@ -179,6 +178,8 @@ def convert(self, model):
                         layer._parameters['weight'].shape[0])
                 elif self.context.channel:
                     new_attr_dict['num_channels'] = max(cur_channel)
+                else:
+                    new_attr_dict['num_channels'] = attr_dict['_num_channels']

                 for attr in new_attr_name:
                     new_attr_dict[attr[1:]] = attr_dict[attr]
@@ -196,7 +197,8 @@ def convert(self, model):

                 new_attr_name = [
                     '_stride', '_dilation', '_groups', '_param_attr',
-                    '_bias_attr', '_use_cudnn', '_act', '_dtype', '_output_size'
+                    '_padding', '_bias_attr', '_use_cudnn', '_act', '_dtype',
+                    '_output_size'
                 ]
                 assert attr_dict[
                     '_filter_size'] != None, "Conv2DTranspose only support filter size != None now"
@@ -371,6 +373,8 @@ def convert(self, model):
                         layer._parameters['scale'].shape[0])
                 elif self.context.channel:
                     new_attr_dict['num_channels'] = max(cur_channel)
+                else:
+                    new_attr_dict['num_channels'] = attr_dict['_num_channels']

                 for attr in new_attr_name:
                     new_attr_dict[attr[1:]] = attr_dict[attr]
@@ -380,6 +384,76 @@ def convert(self, model):
                 layer = SuperInstanceNorm(**new_attr_dict)
                 model[idx] = layer

+            elif isinstance(layer, LayerNorm) and (
+                    getattr(self.context, 'expand', None) != None or
+                    getattr(self.context, 'channel', None) != None):
+                ### TODO(ceci3): fix when normalized_shape != last_dim_of_input
+                if idx > last_weight_layer_idx:
+                    continue
+
+                attr_dict = layer.__dict__
+                new_attr_name = [
+                    '_scale', '_shift', '_param_attr', '_bias_attr', '_act',
+                    '_dtype', '_epsilon'
+                ]
+                new_attr_dict = dict()
+                if self.context.expand:
+                    new_attr_dict[
+                        'normalized_shape'] = self.context.expand * int(
+                            attr_dict['_normalized_shape'][0])
+                elif self.context.channel:
+                    new_attr_dict['normalized_shape'] = max(cur_channel)
+                else:
+                    new_attr_dict['normalized_shape'] = attr_dict[
+                        '_normalized_shape']
+
+                for attr in new_attr_name:
+                    new_attr_dict[attr[1:]] = attr_dict[attr]
+
+                del layer, attr_dict
+                layer = SuperLayerNorm(**new_attr_dict)
+                model[idx] = layer
+
+            elif isinstance(layer, Embedding) and (
+                    getattr(self.context, 'expand', None) != None or
+                    getattr(self.context, 'channel', None) != None):
+                attr_dict = layer.__dict__
+                key = attr_dict['_full_name']
+                new_attr_name = [
+                    '_is_sparse', '_is_distributed', '_padding_idx',
+                    '_param_attr', '_dtype'
+                ]
+
+                new_attr_dict = dict()
+                new_attr_dict['candidate_config'] = dict()
+                bef_size = attr_dict['_size']
+                if self.context.expand:
+                    new_attr_dict['size'] = [
+                        bef_size[0], self.context.expand * bef_size[1]
+                    ]
+                    new_attr_dict['candidate_config'].update({
+                        'expand_ratio': self.context.expand_ratio
+                    })
+
+                elif self.context.channel:
+                    cur_channel = self.context.channel[0]
+                    self.context.channel = self.context.channel[1:]
+                    new_attr_dict['size'] = [bef_size[0], max(cur_channel)]
+                    new_attr_dict['candidate_config'].update({
+                        'channel': cur_channel
+                    })
+                    pre_channel = cur_channel
+                else:
+                    new_attr_dict['size'] = bef_size
+
+                for attr in new_attr_name:
+                    new_attr_dict[attr[1:]] = attr_dict[attr]
+
+                del layer, attr_dict
+
+                layer = Block(SuperEmbedding(**new_attr_dict), key=key)
+                model[idx] = layer
+
         return model

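The hunks above extend the supernet converter so that Embedding and LayerNorm layers are also rewritten into their Super* counterparts (Block(SuperEmbedding(...)) and SuperLayerNorm), with the embedding width and normalized_shape either scaled by the expand ratio or taken from the channel search list. A minimal usage sketch follows; the paddleslim.nas.ofa.convert_super import path, the Convert wrapper class, and the expand_ratio keyword of supernet are assumptions inferred from this diff and from __all__ = ['supernet'], not a verified public API.

# Sketch only: build a tiny dygraph model and convert it into a supernet.
# The import path, the Convert class name, and supernet(expand_ratio=...)
# are assumptions inferred from the diff above.
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Embedding, LayerNorm, Linear
from paddleslim.nas.ofa.convert_super import supernet, Convert  # assumed path

with fluid.dygraph.guard():
    model = fluid.dygraph.Sequential(
        Embedding(size=[1000, 64]),      # handled by the new Embedding branch
        LayerNorm(normalized_shape=64),  # handled by the new LayerNorm branch
        Linear(64, 10))                  # weight layer, widened by the converter

    # expand_ratio widens every weight layer (including the Embedding's second
    # size dim); passing channel=... instead supplies explicit per-layer widths.
    config = supernet(expand_ratio=[0.5, 1.0])
    super_model = Convert(config).convert(model)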