@@ -173,18 +173,6 @@ def get_inputs(self):
173
173
return (torch .randn (2 , 2 , 4 , 4 ),)
174
174
175
175
176
- class Conv2dDQ (torch .nn .Module ):
177
- def __init__ (self ):
178
- super ().__init__ ()
179
- self .conv = torch .nn .Conv2d (in_channels = 3 , out_channels = 10 , kernel_size = 3 )
180
-
181
- def forward (self , x ):
182
- return self .conv (x )
183
-
184
- def get_inputs (self ):
185
- return (torch .randn (1 , 3 , 8 , 8 ),)
186
-
187
-
188
176
class Conv2dDQSeq (torch .nn .Module ):
189
177
def __init__ (self ):
190
178
super ().__init__ ()
@@ -210,7 +198,7 @@ def __init__(self):
210
198
in_channels = 3 , out_channels = 8 , kernel_size = 3 , padding = 1
211
199
)
212
200
self .second = torch .nn .Conv2d (
213
- in_channels = 3 , out_channels = 10 , kernel_size = 3 , padding = 1
201
+ in_channels = 3 , out_channels = 8 , kernel_size = 3 , padding = 1
214
202
)
215
203
216
204
def forward (self , x ):
@@ -785,13 +773,24 @@ def forward(self, x):
785
773
)
786
774
787
775
def test_dq_conv2d (self ) -> None :
788
- model = Conv2dDQ ()
776
+ model = Conv2d (
777
+ in_channels = 3 ,
778
+ out_channels = 10 ,
779
+ kernel_size = (3 , 3 ),
780
+ stride = (1 , 1 ),
781
+ padding = (0 , 0 ),
782
+ batches = 1 ,
783
+ width = 8 ,
784
+ height = 8 ,
785
+ )
789
786
self ._test_dq (model )
790
787
791
788
def test_dq_conv2d_seq (self ) -> None :
792
789
model = Conv2dDQSeq ()
793
- self ._test_dq (model , conv_count = 2 )
790
+ conv_count = sum (1 for m in model .modules () if type (m ) is torch .nn .Conv2d )
791
+ self ._test_dq (model , conv_count )
794
792
795
793
def test_dq_conv2d_parallel (self ) -> None :
796
794
model = Conv2dDQParallel ()
797
- self ._test_dq (model , conv_count = 2 )
795
+ conv_count = sum (1 for m in model .modules () if type (m ) is torch .nn .Conv2d )
796
+ self ._test_dq (model , conv_count )
0 commit comments