Skip to content

Commit 0c2cbac

Browse files
committed
Change Neureka conv parsers to inherit from Conv2d parser
1 parent 6a7fd3d commit 0c2cbac

File tree

1 file changed

+39
-88
lines changed

1 file changed

+39
-88
lines changed

Deeploy/Targets/Neureka/Parsers.py

+39-88
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,16 @@
2828
import onnx_graphsurgeon as gs
2929

3030
from Deeploy.DeeployTypes import NetworkContext
31-
from Deeploy.Targets.Generic.Parsers import ConvParser, RQSParserInterface
31+
from Deeploy.Targets.Generic.Parsers import Conv2DParser, ConvParser, RQSParserInterface
3232

3333

34-
class NeurekaConv2DBaseParser(ConvParser):
35-
36-
def __init__(self, noBiasHoisting: bool = True):
37-
super().__init__(noBiasHoisting)
34+
class NeurekaConv2DBaseParser(Conv2DParser):
3835

3936
def parseNode(self, node: gs.Node) -> bool:
4037
if not super().parseNode(node):
4138
return False
4239

4340
if not all([
44-
len(node.attrs['pads']) == 4,
4541
# No dilation support
4642
self.operatorRepresentation['dilations'] == [1, 1],
4743
# Channels have to be last
@@ -51,16 +47,6 @@ def parseNode(self, node: gs.Node) -> bool:
5147
]):
5248
return False
5349

54-
self.operatorRepresentation['dim_kernel_x'] = int(self.operatorRepresentation['kernel_shape'][0])
55-
self.operatorRepresentation['dim_kernel_y'] = int(self.operatorRepresentation['kernel_shape'][1])
56-
self.operatorRepresentation['dilation_x'] = int(self.operatorRepresentation['dilations'][0])
57-
self.operatorRepresentation['dilation_y'] = int(self.operatorRepresentation['dilations'][1])
58-
self.operatorRepresentation['padding_x'] = int(self.operatorRepresentation['pads'][0])
59-
self.operatorRepresentation['padding_y'] = int(self.operatorRepresentation['pads'][1])
60-
self.operatorRepresentation['stride_x'] = int(self.operatorRepresentation['strides'][0])
61-
self.operatorRepresentation['stride_y'] = int(self.operatorRepresentation['strides'][1])
62-
self.operatorRepresentation['bias_shift'] = int(0)
63-
self.operatorRepresentation['out_shift'] = int(0)
6450
self.operatorRepresentation['padding_y_top'] = int(self.operatorRepresentation['pads'][0])
6551
self.operatorRepresentation['padding_x_left'] = int(self.operatorRepresentation['pads'][1])
6652
self.operatorRepresentation['padding_y_bottom'] = int(self.operatorRepresentation['pads'][2])
@@ -73,31 +59,36 @@ def parseNodeCtxt(self,
7359
ctxt: NetworkContext,
7460
node: gs.Node,
7561
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
76-
77-
newCtxt, ret = super().parseNodeCtxt(ctxt, node, channels_first)
62+
# LMACAN: Cannot reuse the Conv2DParser's parserNodeCtxt because it requires the weight shape
63+
# to be of length 4 whereas neureka does a specific weight encoding so the shape
64+
# ends up being equal to 3.
65+
newCtxt, ret = ConvParser.parseNodeCtxt(self, ctxt, node, channels_first)
7866

7967
if not ret:
8068
return ctxt, False
8169

70+
# LMACAN: c/p of Conv2DParser's parserNodeCtxt but with a different weight shape check
71+
# and enforcing that the channels_first is false
8272
data_in = newCtxt.lookup(self.operatorRepresentation['data_in'])
8373
data_out = newCtxt.lookup(self.operatorRepresentation['data_out'])
8474
weight = newCtxt.lookup(self.operatorRepresentation['weight'])
8575

76+
if not all([
77+
channels_first == False,
78+
len(data_in.shape) == 4,
79+
# LMACAN: weight shape should be equal to 3 because we have to do the neureka's
80+
# special weight encoding
81+
len(weight.shape) == 3,
82+
]):
83+
return newCtxt, False
84+
8685
self.operatorRepresentation['batch'] = data_in.shape[0]
87-
if channels_first:
88-
self.operatorRepresentation['ch_im_in'] = data_in.shape[1]
89-
self.operatorRepresentation['dim_im_in_x'] = data_in.shape[2]
90-
self.operatorRepresentation['dim_im_in_y'] = data_in.shape[3]
91-
self.operatorRepresentation['ch_im_out'] = data_out.shape[1]
92-
self.operatorRepresentation['dim_im_out_x'] = data_out.shape[2]
93-
self.operatorRepresentation['dim_im_out_y'] = data_out.shape[3]
94-
else:
95-
self.operatorRepresentation['ch_im_in'] = data_in.shape[3]
96-
self.operatorRepresentation['dim_im_in_x'] = data_in.shape[1]
97-
self.operatorRepresentation['dim_im_in_y'] = data_in.shape[2]
98-
self.operatorRepresentation['ch_im_out'] = data_out.shape[3]
99-
self.operatorRepresentation['dim_im_out_x'] = data_out.shape[1]
100-
self.operatorRepresentation['dim_im_out_y'] = data_out.shape[2]
86+
self.operatorRepresentation['dim_im_in_x'] = data_in.shape[1]
87+
self.operatorRepresentation['dim_im_in_y'] = data_in.shape[2]
88+
self.operatorRepresentation['ch_im_in'] = data_in.shape[3]
89+
self.operatorRepresentation['dim_im_out_x'] = data_out.shape[1]
90+
self.operatorRepresentation['dim_im_out_y'] = data_out.shape[2]
91+
self.operatorRepresentation['ch_im_out'] = data_out.shape[3]
10192

10293
# No requantization
10394
self.operatorRepresentation['mul'] = 'NULL'
@@ -113,28 +104,18 @@ def parseNode(self, node: gs.Node) -> bool:
113104
if not super().parseNode(node):
114105
return False
115106

116-
if not self.operatorRepresentation['kernel_shape'] == [3, 3]:
117-
return False
118-
if self.operatorRepresentation['group'] == 1:
107+
ch_im_out = node.inputs[1].shape[0]
108+
ch_im_in = node.inputs[1].shape[1]
109+
110+
if not all([
111+
self.operatorRepresentation['kernel_shape'] == [3, 3],
112+
self.operatorRepresentation['group'] == ch_im_out,
113+
self.operatorRepresentation['group'] == ch_im_in,
114+
]):
119115
return False
120116

121117
return True
122118

123-
def parseNodeCtxt(self,
124-
ctxt: NetworkContext,
125-
node: gs.Node,
126-
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
127-
128-
newCtxt, ret = super().parseNodeCtxt(ctxt, node, channels_first)
129-
130-
data_in = ctxt.lookup(self.operatorRepresentation['data_in'])
131-
weight = ctxt.lookup(self.operatorRepresentation['weight'])
132-
133-
if len(data_in.shape) != 4 or len(weight.shape) != 4:
134-
return ctxt, False
135-
136-
return newCtxt, True
137-
138119

139120
class NeurekaRQSDWConv2DParser(NeurekaDWConv2DParser, RQSParserInterface):
140121

@@ -168,29 +149,14 @@ def parseNode(self, node: gs.Node) -> bool:
168149
if not super().parseNode(node):
169150
return False
170151

171-
if not self.operatorRepresentation['kernel_shape'] == [1, 1]:
152+
if not all([
153+
self.operatorRepresentation['kernel_shape'] == [1, 1],
154+
self.operatorRepresentation['group'] == 1,
155+
]):
172156
return False
173157

174-
# if not self.operatorRepresentation['strides'] == [1, 1]:
175-
# return False
176-
177158
return True
178159

179-
def parseNodeCtxt(self,
180-
ctxt: NetworkContext,
181-
node: gs.Node,
182-
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
183-
184-
newCtxt, ret = super().parseNodeCtxt(ctxt, node, channels_first)
185-
186-
data_in = newCtxt.lookup(self.operatorRepresentation['data_in'])
187-
weight = newCtxt.lookup(self.operatorRepresentation['weight'])
188-
189-
if len(data_in.shape) != 4 or len(weight.shape) != 3:
190-
return ctxt, False
191-
192-
return newCtxt, True
193-
194160

195161
class NeurekaRQSPWConv2DParser(NeurekaPWConv2DParser, RQSParserInterface):
196162

@@ -223,29 +189,14 @@ def parseNode(self, node: gs.Node) -> bool:
223189
if not super().parseNode(node):
224190
return False
225191

226-
if not self.operatorRepresentation['kernel_shape'] == [3, 3]:
227-
return False
228-
229-
if not self.operatorRepresentation['group'] == 1:
192+
if not all([
193+
self.operatorRepresentation['kernel_shape'] == [3, 3],
194+
self.operatorRepresentation['group'] == 1,
195+
]):
230196
return False
231197

232198
return True
233199

234-
def parseNodeCtxt(self,
235-
ctxt: NetworkContext,
236-
node: gs.Node,
237-
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
238-
239-
newCtxt, ret = super().parseNodeCtxt(ctxt, node, channels_first)
240-
241-
data_in = newCtxt.lookup(self.operatorRepresentation['data_in'])
242-
weight = newCtxt.lookup(self.operatorRepresentation['weight'])
243-
244-
if len(data_in.shape) != 4 or len(weight.shape) != 4:
245-
return ctxt, False
246-
247-
return newCtxt, True
248-
249200

250201
class NeurekaRQSDenseConv2DParser(NeurekaDenseConv2DParser, RQSParserInterface):
251202

0 commit comments

Comments
 (0)