@@ -1488,6 +1488,7 @@ def test_convolution1d(
14881488 padding = padding ,
14891489 dilation = dilation ,
14901490 bias = bias ,
1491+ groups = groups ,
14911492 )
14921493 self .run_compare_torch (
14931494 (1 , in_channels , length ),
@@ -1557,6 +1558,7 @@ def test_convolution2d(
15571558 padding = padding ,
15581559 dilation = dilation ,
15591560 bias = bias ,
1561+ groups = groups ,
15601562 )
15611563 self .run_compare_torch (
15621564 (1 , in_channels , height , width ),
@@ -1628,6 +1630,7 @@ def test_convolution3d(
16281630 padding = padding ,
16291631 dilation = dilation ,
16301632 bias = bias ,
1633+ groups = groups ,
16311634 )
16321635 self .run_compare_torch (
16331636 (1 , in_channels , depth , height , width ),
@@ -1687,7 +1690,7 @@ def test_convolution1d(
16871690 class DynamicConv (nn .Module ):
16881691 def forward (self , input_data , weights ):
16891692 return nn .functional .conv1d (
1690- input_data , weights , stride = stride , padding = padding
1693+ input_data , weights , stride = stride , padding = padding , groups = groups
16911694 )
16921695
16931696 model = DynamicConv ()
@@ -1754,7 +1757,7 @@ def test_convolution2d(
17541757 class DynamicConv (nn .Module ):
17551758 def forward (self , input_data , weights ):
17561759 return nn .functional .conv2d (
1757- input_data , weights , stride = stride , padding = padding
1760+ input_data , weights , stride = stride , padding = padding , groups = groups
17581761 )
17591762
17601763 model = DynamicConv ()
@@ -1890,6 +1893,7 @@ def test_convolution_transpose2d(
18901893 stride = stride ,
18911894 padding = padding ,
18921895 dilation = dilation ,
1896+ groups = groups ,
18931897 )
18941898 self .run_compare_torch (
18951899 (1 , in_channels , height , width ),
@@ -2016,6 +2020,7 @@ def test_convolution_transpose2d_output_padding(
20162020 padding = padding ,
20172021 dilation = dilation ,
20182022 output_padding = output_padding ,
2023+ groups = groups ,
20192024 )
20202025 self .run_compare_torch (
20212026 (1 , in_channels , height , width ),
@@ -2099,6 +2104,7 @@ def test_convolution_transpose3d(
20992104 stride ,
21002105 padding ,
21012106 dilation ,
2107+ groups = 1 ,
21022108 ):
21032109 model = nn .ConvTranspose3d (
21042110 in_channels = in_channels ,
0 commit comments