Skip to content

Commit 77ee626

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
Rename to functional and add 3d conv test
1 parent d65e753 commit 77ee626

File tree

1 file changed

+74
-8
lines changed

1 file changed

+74
-8
lines changed

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1640,7 +1640,7 @@ def test_convolution3d(
16401640
)
16411641

16421642

1643-
class TestDynamicConv(TorchBaseTest):
1643+
class TestFunctionalConv(TorchBaseTest):
16441644
@pytest.mark.parametrize(
16451645
",".join(
16461646
[
@@ -1687,13 +1687,13 @@ def test_convolution1d(
16871687
padding,
16881688
groups=1,
16891689
):
1690-
class DynamicConv(nn.Module):
1690+
class FunctionalConv1D(nn.Module):
16911691
def forward(self, input_data, weights):
16921692
return nn.functional.conv1d(
16931693
input_data, weights, stride=stride, padding=padding, groups=groups
16941694
)
16951695

1696-
model = DynamicConv()
1696+
model = FunctionalConv1D()
16971697
input_shape = [
16981698
(1, in_channels, width),
16991699
(out_channels, int(in_channels / groups), kernel_size),
@@ -1754,13 +1754,13 @@ def test_convolution2d(
17541754
padding,
17551755
groups=1,
17561756
):
1757-
class DynamicConv(nn.Module):
1757+
class FunctionalConv2D(nn.Module):
17581758
def forward(self, input_data, weights):
17591759
return nn.functional.conv2d(
17601760
input_data, weights, stride=stride, padding=padding, groups=groups
17611761
)
17621762

1763-
model = DynamicConv()
1763+
model = FunctionalConv2D()
17641764

17651765
input_shape = [
17661766
(1, in_channels, height, width),
@@ -1774,21 +1774,19 @@ def forward(self, input_data, weights):
17741774
use_scripting=use_scripting
17751775
)
17761776

1777-
1778-
class TestConvTranspose(TorchBaseTest):
17791777
@pytest.mark.parametrize(
17801778
",".join(
17811779
[
17821780
"compute_unit",
17831781
"backend",
17841782
"use_scripting",
1783+
"height",
17851784
"width",
17861785
"in_channels",
17871786
"out_channels",
17881787
"kernel_size",
17891788
"stride",
17901789
"padding",
1791-
"dilation",
17921790
]
17931791
),
17941792
[
@@ -1797,6 +1795,74 @@ class TestConvTranspose(TorchBaseTest):
17971795
compute_units,
17981796
backends,
17991797
[True, False],
1798+
[
1799+
(5, 3, 2, 1, 1, 1, 2, 0),
1800+
(3, 3, 1, 1, 1, 1, 2, 1),
1801+
(4, 3, 3, 3, 3, 1, 2, 0),
1802+
(7, 3, 4, 3, 3, 1, 3, 0),
1803+
(5, 5, 3, 3, 3, 2, 1, 0),
1804+
(3, 5, 1, 3, 3, 1, 3, 0),
1805+
(3, 5, 4, 3, 3, 1, 3, 1),
1806+
(7, 5, 6, 3, 3, 2, 3, 1),
1807+
],
1808+
)
1809+
],
1810+
)
1811+
def test_convolution3d(
1812+
self,
1813+
compute_unit,
1814+
backend,
1815+
use_scripting,
1816+
height,
1817+
width,
1818+
in_channels,
1819+
out_channels,
1820+
kernel_size,
1821+
stride,
1822+
padding,
1823+
groups=1,
1824+
):
1825+
class FunctionalConv3D(nn.Module):
1826+
def forward(self, input_data, weights):
1827+
return nn.functional.conv3d(
1828+
input_data, weights, stride=stride, padding=padding, groups=groups
1829+
)
1830+
1831+
model = FunctionalConv3D()
1832+
1833+
input_shape = [
1834+
(1, in_channels, height, width),
1835+
(out_channels, int(in_channels / groups), kernel_size, kernel_size),
1836+
]
1837+
self.run_compare_torch(
1838+
input_shape,
1839+
model,
1840+
backend=backend,
1841+
compute_unit=compute_unit,
1842+
use_scripting=use_scripting
1843+
)
1844+
1845+
1846+
class TestConvTranspose(TorchBaseTest):
1847+
@pytest.mark.parametrize(
1848+
",".join(
1849+
[
1850+
"compute_unit",
1851+
"backend",
1852+
"width",
1853+
"in_channels",
1854+
"out_channels",
1855+
"kernel_size",
1856+
"stride",
1857+
"padding",
1858+
"dilation",
1859+
]
1860+
),
1861+
[
1862+
(compute_unit, backend, *param)
1863+
for compute_unit, backend, param in itertools.product(
1864+
compute_units,
1865+
backends,
18001866
[
18011867
(3, 1, 1, 1, 2, 0, 1),
18021868
(3, 1, 1, 1, 2, 1, 3),

0 commit comments

Comments
 (0)