4
4
import logging
5
5
import numpy as np
6
6
7
- from chainer import links as L
8
-
9
7
from chainerpruner .rebuild import calc_pruning_connection
8
+ from chainerpruner .utils import named_modules
10
9
11
10
logger = logging .getLogger (__name__ )
12
11
@@ -17,10 +16,11 @@ class Mask():
17
16
18
17
def __init__ (self , model , graph , target_layers , mask_layer = None ):
19
18
self .model = model
20
- self ._model_dict = {name : link for name , link in model . namedlinks ( )}
19
+ self ._model_dict = {name : link for name , link in named_modules ( model )}
21
20
self .graph = graph
22
21
self .target_layers = target_layers
23
22
self .logger = logger
23
+ self ._is_chainer = graph .is_chainer
24
24
self .pruning_connection_info = calc_pruning_connection (graph )
25
25
self .masks = dict ()
26
26
self ._mask_layer = mask_layer
@@ -31,6 +31,9 @@ def __init__(self, model, graph, target_layers, mask_layer=None):
31
31
if mask_layer not in cand_mask_layer :
32
32
raise AttributeError ('mask_layer is expected which {}' .format (cand_mask_layer ))
33
33
34
def is_chainer(self):
    """Return the framework flag captured at construction time.

    True means the wrapped model/graph is a Chainer one, False means
    PyTorch; the flag is read from ``graph.is_chainer`` in ``__init__``.
    """
    flag = self._is_chainer
    return flag
34
37
def get_filter_norm (self , mask ):
35
38
"""get mask for pruning
36
39
@@ -58,11 +61,44 @@ def get_thresholds(self, name, mask):
58
61
raise NotImplementedError ()
59
62
60
63
def _get_mask (self , name ):
64
+ """
65
+
66
+ Args:
67
+ name:
68
+
69
+ Returns:
70
+ (NDArray, ndarray): (conv-weight, mask-tensor)
71
+ conv-weight: (oc, ic, k, k) kernel order
72
+ mask-tensor: (oc, ic, k, k) or (oc, 1, 1, 1)
73
+
74
+ """
75
+ if self .is_chainer ():
76
+ return self ._get_mask_chainer (name )
77
+ else :
78
+ return self ._get_mask_pytorch (name )
79
+
80
+ def _get_mask_pytorch (self , name ):
81
+ from torch import nn
82
+ conv = self ._model_dict [name ]
83
+ if self ._mask_layer is None :
84
+ mask = conv .weight .data .clone ()
85
+ elif self ._mask_layer == 'batchnorm' :
86
+ # propagate mask bn: conv-bn
87
+ post_conv_bn_name = self .pruning_connection_info [name ][0 ]
88
+ bn = self ._model_dict [post_conv_bn_name ]
89
+ if not isinstance (bn , nn .BatchNorm2d ):
90
+ raise ValueError ('expected {}(Conv) -> {}(BatchNorm)' .format (name , post_conv_bn_name ))
91
+ mask = bn .weight .data .clone ()
92
+ mask = mask .reshape (- 1 , 1 , 1 , 1 ) # to mask conv weight (oc, ic, kh, kw)
93
+ return conv .weight .data , mask
94
+
95
+ def _get_mask_chainer (self , name ):
96
+ from chainer import links as L
61
97
conv = self ._model_dict [name ]
62
98
if self ._mask_layer is None :
63
99
mask = conv .W .array .copy ()
64
100
elif self ._mask_layer == 'batchnorm' :
65
- # conv-bn
101
+ # propagate mask bn: conv-bn
66
102
post_conv_bn_name = self .pruning_connection_info [name ][0 ]
67
103
bn = self ._model_dict [post_conv_bn_name ]
68
104
if not isinstance (bn , L .BatchNormalization ):
@@ -80,7 +116,8 @@ def __call__(self):
80
116
81
117
# get mask vector
82
118
target_weights = []
83
- for name , link in self .model .namedlinks (skipself = True ):
119
+ options = {'skipself' : True } if self .is_chainer () else dict ()
120
+ for name , link in named_modules (self .model , ** options ):
84
121
85
122
self .logger .debug ('name: %s' , name )
86
123
@@ -93,6 +130,7 @@ def __call__(self):
93
130
out_channels = mask .shape [0 ]
94
131
mask = self .get_filter_norm (mask )
95
132
if mask .shape != (out_channels , 1 , 1 , 1 ):
133
+ # expected (oc, ic, k, k) kernel order
96
134
raise RuntimeError ()
97
135
98
136
self .masks [name ] = mask
@@ -109,7 +147,10 @@ def __call__(self):
109
147
110
148
# apply mask
111
149
mask = mask_ >= threshold # 0: pruning, 1: non-pruning
112
- mask = mask .astype (np .float32 )
150
+ try :
151
+ mask = mask .astype (np .float32 )
152
+ except AttributeError :
153
+ mask = mask .type_as (target_weight )
113
154
114
155
info_ = {
115
156
'name' : name ,
0 commit comments