Skip to content

Commit cbf5d3f

Browse files
5574 Fix MBConvBlock issue in 3d (#6672)
Fixes #5574 . ### Description This PR is used to fix the `MBConvBlock` issue when `spatial_dims` is not 2. The PR follows the work in: #5695 As it was not updated after review. Thanks @swilson314 for posting the issue! ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang <[email protected]>
1 parent 9e14615 commit cbf5d3f

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

monai/networks/nets/efficientnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,9 @@ def __init__(
163163
self._se_adaptpool = adaptivepool_type(1)
164164
num_squeezed_channels = max(1, int(in_channels * self.se_ratio))
165165
self._se_reduce = conv_type(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
166-
self._se_reduce_padding = _make_same_padder(self._se_reduce, [1, 1])
166+
self._se_reduce_padding = _make_same_padder(self._se_reduce, [1] * spatial_dims)
167167
self._se_expand = conv_type(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
168-
self._se_expand_padding = _make_same_padder(self._se_expand, [1, 1])
168+
self._se_expand_padding = _make_same_padder(self._se_expand, [1] * spatial_dims)
169169

170170
# Pointwise convolution phase
171171
final_oup = out_channels

tests/test_vitautoenc.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@
6565

6666

6767
class TestPatchEmbeddingBlock(unittest.TestCase):
68+
def setUp(self):
69+
self.threads = torch.get_num_threads()
70+
torch.set_num_threads(4)
71+
72+
def tearDown(self):
73+
torch.set_num_threads(self.threads)
74+
6875
@parameterized.expand(TEST_CASE_Vitautoenc)
6976
@skip_if_windows
7077
def test_shape(self, input_param, input_shape, expected_shape):

tests/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def skip_if_downloading_fails():
153153
"md5 check",
154154
"limit", # HTTP Error 503: Egress is over the account limit
155155
"authenticate",
156+
"timed out", # urlopen error [Errno 110] Connection timed out
156157
)
157158
):
158159
raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e # incomplete download

0 commit comments

Comments
 (0)