@@ -1640,7 +1640,7 @@ def test_convolution3d(
16401640 )
16411641
16421642
1643- class TestDynamicConv (TorchBaseTest ):
1643+ class TestFunctionalConv (TorchBaseTest ):
16441644 @pytest .mark .parametrize (
16451645 "," .join (
16461646 [
@@ -1687,13 +1687,13 @@ def test_convolution1d(
16871687 padding ,
16881688 groups = 1 ,
16891689 ):
1690- class DynamicConv (nn .Module ):
1690+ class FunctionalConv1D (nn .Module ):
16911691 def forward (self , input_data , weights ):
16921692 return nn .functional .conv1d (
16931693 input_data , weights , stride = stride , padding = padding , groups = groups
16941694 )
16951695
1696- model = DynamicConv ()
1696+ model = FunctionalConv1D ()
16971697 input_shape = [
16981698 (1 , in_channels , width ),
16991699 (out_channels , int (in_channels / groups ), kernel_size ),
@@ -1754,13 +1754,13 @@ def test_convolution2d(
17541754 padding ,
17551755 groups = 1 ,
17561756 ):
1757- class DynamicConv (nn .Module ):
1757+ class FunctionalConv2D (nn .Module ):
17581758 def forward (self , input_data , weights ):
17591759 return nn .functional .conv2d (
17601760 input_data , weights , stride = stride , padding = padding , groups = groups
17611761 )
17621762
1763- model = DynamicConv ()
1763+ model = FunctionalConv2D ()
17641764
17651765 input_shape = [
17661766 (1 , in_channels , height , width ),
@@ -1774,21 +1774,19 @@ def forward(self, input_data, weights):
17741774 use_scripting = use_scripting
17751775 )
17761776
1777-
1778- class TestConvTranspose (TorchBaseTest ):
17791777 @pytest .mark .parametrize (
17801778 "," .join (
17811779 [
17821780 "compute_unit" ,
17831781 "backend" ,
17841782 "use_scripting" ,
1783+ "height" ,
17851784 "width" ,
17861785 "in_channels" ,
17871786 "out_channels" ,
17881787 "kernel_size" ,
17891788 "stride" ,
17901789 "padding" ,
1791- "dilation" ,
17921790 ]
17931791 ),
17941792 [
@@ -1797,6 +1795,74 @@ class TestConvTranspose(TorchBaseTest):
17971795 compute_units ,
17981796 backends ,
17991797 [True , False ],
1798+ [
1799+ (5 , 3 , 2 , 1 , 1 , 1 , 2 , 0 ),
1800+ (3 , 3 , 1 , 1 , 1 , 1 , 2 , 1 ),
1801+ (4 , 3 , 3 , 3 , 3 , 1 , 2 , 0 ),
1802+ (7 , 3 , 4 , 3 , 3 , 1 , 3 , 0 ),
1803+ (5 , 5 , 3 , 3 , 3 , 2 , 1 , 0 ),
1804+ (3 , 5 , 1 , 3 , 3 , 1 , 3 , 0 ),
1805+ (3 , 5 , 4 , 3 , 3 , 1 , 3 , 1 ),
1806+ (7 , 5 , 6 , 3 , 3 , 2 , 3 , 1 ),
1807+ ],
1808+ )
1809+ ],
1810+ )
1811+ def test_convolution3d (
1812+ self ,
1813+ compute_unit ,
1814+ backend ,
1815+ use_scripting ,
1816+ height ,
1817+ width ,
1818+ in_channels ,
1819+ out_channels ,
1820+ kernel_size ,
1821+ stride ,
1822+ padding ,
1823+ groups = 1 ,
1824+ ):
1825+ class FunctionalConv3D (nn .Module ):
1826+ def forward (self , input_data , weights ):
1827+ return nn .functional .conv3d (
1828+ input_data , weights , stride = stride , padding = padding , groups = groups
1829+ )
1830+
1831+ model = FunctionalConv3D ()
1832+
1833+ input_shape = [
1834+ (1 , in_channels , height , width ),
1835+ (out_channels , int (in_channels / groups ), kernel_size , kernel_size ),
1836+ ]
1837+ self .run_compare_torch (
1838+ input_shape ,
1839+ model ,
1840+ backend = backend ,
1841+ compute_unit = compute_unit ,
1842+ use_scripting = use_scripting
1843+ )
1844+
1845+
1846+ class TestConvTranspose (TorchBaseTest ):
1847+ @pytest .mark .parametrize (
1848+ "," .join (
1849+ [
1850+ "compute_unit" ,
1851+ "backend" ,
1852+ "width" ,
1853+ "in_channels" ,
1854+ "out_channels" ,
1855+ "kernel_size" ,
1856+ "stride" ,
1857+ "padding" ,
1858+ "dilation" ,
1859+ ]
1860+ ),
1861+ [
1862+ (compute_unit , backend , * param )
1863+ for compute_unit , backend , param in itertools .product (
1864+ compute_units ,
1865+ backends ,
18001866 [
18011867 (3 , 1 , 1 , 1 , 2 , 0 , 1 ),
18021868 (3 , 1 , 1 , 1 , 2 , 1 , 3 ),
0 commit comments