Skip to content

Commit d82e080

Browse files
committed
Use existing Conv2d class; get conv count from model
1 parent e336df6 commit d82e080

File tree

1 file changed

+15
-16
lines changed

1 file changed

+15
-16
lines changed

backends/xnnpack/test/ops/test_conv2d.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -173,18 +173,6 @@ def get_inputs(self):
173173
return (torch.randn(2, 2, 4, 4),)
174174

175175

176-
class Conv2dDQ(torch.nn.Module):
177-
def __init__(self):
178-
super().__init__()
179-
self.conv = torch.nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3)
180-
181-
def forward(self, x):
182-
return self.conv(x)
183-
184-
def get_inputs(self):
185-
return (torch.randn(1, 3, 8, 8),)
186-
187-
188176
class Conv2dDQSeq(torch.nn.Module):
189177
def __init__(self):
190178
super().__init__()
@@ -210,7 +198,7 @@ def __init__(self):
210198
in_channels=3, out_channels=8, kernel_size=3, padding=1
211199
)
212200
self.second = torch.nn.Conv2d(
213-
in_channels=3, out_channels=10, kernel_size=3, padding=1
201+
in_channels=3, out_channels=8, kernel_size=3, padding=1
214202
)
215203

216204
def forward(self, x):
@@ -785,13 +773,24 @@ def forward(self, x):
785773
)
786774

787775
def test_dq_conv2d(self) -> None:
788-
model = Conv2dDQ()
776+
model = Conv2d(
777+
in_channels=3,
778+
out_channels=10,
779+
kernel_size=(3, 3),
780+
stride=(1, 1),
781+
padding=(0, 0),
782+
batches=1,
783+
width=8,
784+
height=8,
785+
)
789786
self._test_dq(model)
790787

791788
def test_dq_conv2d_seq(self) -> None:
792789
model = Conv2dDQSeq()
793-
self._test_dq(model, conv_count=2)
790+
conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d)
791+
self._test_dq(model, conv_count)
794792

795793
def test_dq_conv2d_parallel(self) -> None:
796794
model = Conv2dDQParallel()
797-
self._test_dq(model, conv_count=2)
795+
conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d)
796+
self._test_dq(model, conv_count)

0 commit comments

Comments
 (0)