28
28
import onnx_graphsurgeon as gs
29
29
30
30
from Deeploy .DeeployTypes import NetworkContext
31
- from Deeploy .Targets .Generic .Parsers import ConvParser , RQSParserInterface
31
+ from Deeploy .Targets .Generic .Parsers import Conv2DParser , ConvParser , RQSParserInterface
32
32
33
33
34
- class NeurekaConv2DBaseParser (ConvParser ):
35
-
36
- def __init__ (self , noBiasHoisting : bool = True ):
37
- super ().__init__ (noBiasHoisting )
34
+ class NeurekaConv2DBaseParser (Conv2DParser ):
38
35
39
36
def parseNode (self , node : gs .Node ) -> bool :
40
37
if not super ().parseNode (node ):
41
38
return False
42
39
43
40
if not all ([
44
- len (node .attrs ['pads' ]) == 4 ,
45
41
# No dilation support
46
42
self .operatorRepresentation ['dilations' ] == [1 , 1 ],
47
43
# Channels have to be last
@@ -51,16 +47,6 @@ def parseNode(self, node: gs.Node) -> bool:
51
47
]):
52
48
return False
53
49
54
- self .operatorRepresentation ['dim_kernel_x' ] = int (self .operatorRepresentation ['kernel_shape' ][0 ])
55
- self .operatorRepresentation ['dim_kernel_y' ] = int (self .operatorRepresentation ['kernel_shape' ][1 ])
56
- self .operatorRepresentation ['dilation_x' ] = int (self .operatorRepresentation ['dilations' ][0 ])
57
- self .operatorRepresentation ['dilation_y' ] = int (self .operatorRepresentation ['dilations' ][1 ])
58
- self .operatorRepresentation ['padding_x' ] = int (self .operatorRepresentation ['pads' ][0 ])
59
- self .operatorRepresentation ['padding_y' ] = int (self .operatorRepresentation ['pads' ][1 ])
60
- self .operatorRepresentation ['stride_x' ] = int (self .operatorRepresentation ['strides' ][0 ])
61
- self .operatorRepresentation ['stride_y' ] = int (self .operatorRepresentation ['strides' ][1 ])
62
- self .operatorRepresentation ['bias_shift' ] = int (0 )
63
- self .operatorRepresentation ['out_shift' ] = int (0 )
64
50
self .operatorRepresentation ['padding_y_top' ] = int (self .operatorRepresentation ['pads' ][0 ])
65
51
self .operatorRepresentation ['padding_x_left' ] = int (self .operatorRepresentation ['pads' ][1 ])
66
52
self .operatorRepresentation ['padding_y_bottom' ] = int (self .operatorRepresentation ['pads' ][2 ])
@@ -73,31 +59,36 @@ def parseNodeCtxt(self,
73
59
ctxt : NetworkContext ,
74
60
node : gs .Node ,
75
61
channels_first : bool = True ) -> Tuple [NetworkContext , bool ]:
76
-
77
- newCtxt , ret = super ().parseNodeCtxt (ctxt , node , channels_first )
62
+ # LMACAN: Cannot reuse the Conv2DParser's parserNodeCtxt because it requires the weight shape
63
+ # to be of length 4 whereas neureka does a specific weight encoding so the shape
64
+ # ends up being equal to 3.
65
+ newCtxt , ret = ConvParser .parseNodeCtxt (self , ctxt , node , channels_first )
78
66
79
67
if not ret :
80
68
return ctxt , False
81
69
70
+ # LMACAN: c/p of Conv2DParser's parserNodeCtxt but with a different weight shape check
71
+ # and enforcing that the channels_first is false
82
72
data_in = newCtxt .lookup (self .operatorRepresentation ['data_in' ])
83
73
data_out = newCtxt .lookup (self .operatorRepresentation ['data_out' ])
84
74
weight = newCtxt .lookup (self .operatorRepresentation ['weight' ])
85
75
76
+ if not all ([
77
+ channels_first == False ,
78
+ len (data_in .shape ) == 4 ,
79
+ # LMACAN: weight shape should be equal to 3 because we have to do the neureka's
80
+ # special weight encoding
81
+ len (weight .shape ) == 3 ,
82
+ ]):
83
+ return newCtxt , False
84
+
86
85
self .operatorRepresentation ['batch' ] = data_in .shape [0 ]
87
- if channels_first :
88
- self .operatorRepresentation ['ch_im_in' ] = data_in .shape [1 ]
89
- self .operatorRepresentation ['dim_im_in_x' ] = data_in .shape [2 ]
90
- self .operatorRepresentation ['dim_im_in_y' ] = data_in .shape [3 ]
91
- self .operatorRepresentation ['ch_im_out' ] = data_out .shape [1 ]
92
- self .operatorRepresentation ['dim_im_out_x' ] = data_out .shape [2 ]
93
- self .operatorRepresentation ['dim_im_out_y' ] = data_out .shape [3 ]
94
- else :
95
- self .operatorRepresentation ['ch_im_in' ] = data_in .shape [3 ]
96
- self .operatorRepresentation ['dim_im_in_x' ] = data_in .shape [1 ]
97
- self .operatorRepresentation ['dim_im_in_y' ] = data_in .shape [2 ]
98
- self .operatorRepresentation ['ch_im_out' ] = data_out .shape [3 ]
99
- self .operatorRepresentation ['dim_im_out_x' ] = data_out .shape [1 ]
100
- self .operatorRepresentation ['dim_im_out_y' ] = data_out .shape [2 ]
86
+ self .operatorRepresentation ['dim_im_in_x' ] = data_in .shape [1 ]
87
+ self .operatorRepresentation ['dim_im_in_y' ] = data_in .shape [2 ]
88
+ self .operatorRepresentation ['ch_im_in' ] = data_in .shape [3 ]
89
+ self .operatorRepresentation ['dim_im_out_x' ] = data_out .shape [1 ]
90
+ self .operatorRepresentation ['dim_im_out_y' ] = data_out .shape [2 ]
91
+ self .operatorRepresentation ['ch_im_out' ] = data_out .shape [3 ]
101
92
102
93
# No requantization
103
94
self .operatorRepresentation ['mul' ] = 'NULL'
@@ -113,28 +104,18 @@ def parseNode(self, node: gs.Node) -> bool:
113
104
if not super ().parseNode (node ):
114
105
return False
115
106
116
- if not self .operatorRepresentation ['kernel_shape' ] == [3 , 3 ]:
117
- return False
118
- if self .operatorRepresentation ['group' ] == 1 :
107
+ ch_im_out = node .inputs [1 ].shape [0 ]
108
+ ch_im_in = node .inputs [1 ].shape [1 ]
109
+
110
+ if not all ([
111
+ self .operatorRepresentation ['kernel_shape' ] == [3 , 3 ],
112
+ self .operatorRepresentation ['group' ] == ch_im_out ,
113
+ self .operatorRepresentation ['group' ] == ch_im_in ,
114
+ ]):
119
115
return False
120
116
121
117
return True
122
118
123
- def parseNodeCtxt (self ,
124
- ctxt : NetworkContext ,
125
- node : gs .Node ,
126
- channels_first : bool = True ) -> Tuple [NetworkContext , bool ]:
127
-
128
- newCtxt , ret = super ().parseNodeCtxt (ctxt , node , channels_first )
129
-
130
- data_in = ctxt .lookup (self .operatorRepresentation ['data_in' ])
131
- weight = ctxt .lookup (self .operatorRepresentation ['weight' ])
132
-
133
- if len (data_in .shape ) != 4 or len (weight .shape ) != 4 :
134
- return ctxt , False
135
-
136
- return newCtxt , True
137
-
138
119
139
120
class NeurekaRQSDWConv2DParser (NeurekaDWConv2DParser , RQSParserInterface ):
140
121
@@ -168,29 +149,14 @@ def parseNode(self, node: gs.Node) -> bool:
168
149
if not super ().parseNode (node ):
169
150
return False
170
151
171
- if not self .operatorRepresentation ['kernel_shape' ] == [1 , 1 ]:
152
+ if not all ([
153
+ self .operatorRepresentation ['kernel_shape' ] == [1 , 1 ],
154
+ self .operatorRepresentation ['group' ] == 1 ,
155
+ ]):
172
156
return False
173
157
174
- # if not self.operatorRepresentation['strides'] == [1, 1]:
175
- # return False
176
-
177
158
return True
178
159
179
- def parseNodeCtxt (self ,
180
- ctxt : NetworkContext ,
181
- node : gs .Node ,
182
- channels_first : bool = True ) -> Tuple [NetworkContext , bool ]:
183
-
184
- newCtxt , ret = super ().parseNodeCtxt (ctxt , node , channels_first )
185
-
186
- data_in = newCtxt .lookup (self .operatorRepresentation ['data_in' ])
187
- weight = newCtxt .lookup (self .operatorRepresentation ['weight' ])
188
-
189
- if len (data_in .shape ) != 4 or len (weight .shape ) != 3 :
190
- return ctxt , False
191
-
192
- return newCtxt , True
193
-
194
160
195
161
class NeurekaRQSPWConv2DParser (NeurekaPWConv2DParser , RQSParserInterface ):
196
162
@@ -223,29 +189,14 @@ def parseNode(self, node: gs.Node) -> bool:
223
189
if not super ().parseNode (node ):
224
190
return False
225
191
226
- if not self . operatorRepresentation [ 'kernel_shape' ] == [ 3 , 3 ]:
227
- return False
228
-
229
- if not self . operatorRepresentation [ 'group' ] == 1 :
192
+ if not all ([
193
+ self . operatorRepresentation [ 'kernel_shape' ] == [ 3 , 3 ],
194
+ self . operatorRepresentation [ 'group' ] == 1 ,
195
+ ]) :
230
196
return False
231
197
232
198
return True
233
199
234
- def parseNodeCtxt (self ,
235
- ctxt : NetworkContext ,
236
- node : gs .Node ,
237
- channels_first : bool = True ) -> Tuple [NetworkContext , bool ]:
238
-
239
- newCtxt , ret = super ().parseNodeCtxt (ctxt , node , channels_first )
240
-
241
- data_in = newCtxt .lookup (self .operatorRepresentation ['data_in' ])
242
- weight = newCtxt .lookup (self .operatorRepresentation ['weight' ])
243
-
244
- if len (data_in .shape ) != 4 or len (weight .shape ) != 4 :
245
- return ctxt , False
246
-
247
- return newCtxt , True
248
-
249
200
250
201
class NeurekaRQSDenseConv2DParser (NeurekaDenseConv2DParser , RQSParserInterface ):
251
202
0 commit comments