Commit f4d062a

cleanup
#116
1 parent f4ecd97 · commit f4d062a

1 file changed: nn/conv.py (+18 −24)
@@ -69,12 +69,10 @@ def _lazy_init(self, in_dim: nn.Dim):
 
   def _call_nd1(self, source: nn.Tensor, *,
                 in_dim: Optional[nn.Dim] = None,
-                in_spatial_dim: nn.Dim,
-                out_spatial_dim: Optional[nn.Dim] = None) -> Tuple[nn.Tensor, nn.Dim]:
+                in_spatial_dim: nn.Dim) -> Tuple[nn.Tensor, nn.Dim]:
     assert self.nd == 1
     out, (out_spatial_dim,) = self.__class__.__base__.__call__(
-      self, source, in_dim=in_dim, in_spatial_dims=[in_spatial_dim],
-      out_spatial_dims=[out_spatial_dim] if out_spatial_dim else None)
+      self, source, in_dim=in_dim, in_spatial_dims=[in_spatial_dim])
     return out, out_spatial_dim
 
 
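Note: after this hunk, _call_nd1 no longer accepts a caller-supplied out_spatial_dim; the base N-d __call__ always creates it. Call sites change accordingly. A hypothetical before/after sketch (conv, x, time_dim, and my_out_dim are placeholders, not from this commit):

  # before: out, time_out = conv(x, in_spatial_dim=time_dim, out_spatial_dim=my_out_dim)
  # after: the layer creates and returns the out spatial dim itself
  out, time_out = conv(x, in_spatial_dim=time_dim)
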
@@ -119,21 +117,19 @@ def __init__(self,
   @nn.scoped
   def __call__(self, source: nn.Tensor, *,
                in_dim: Optional[nn.Dim] = None,
-               in_spatial_dims: Sequence[nn.Dim],
-               out_spatial_dims: Optional[Sequence[nn.Dim]] = None
+               in_spatial_dims: Sequence[nn.Dim]
                ) -> Tuple[nn.Tensor, Sequence[nn.Dim]]:
     source = nn.check_in_feature_dim_lazy_init(source, in_dim, self.in_dim, self._lazy_init)
     for in_spatial_dim in in_spatial_dims:
       if in_spatial_dim not in source.shape:
         raise ValueError(f"{self}: source {source} does not have spatial dim {in_spatial_dim}")
-    if not out_spatial_dims:
-      out_spatial_dims = _default_out_spatial_dims(
-        description_prefix=nn.NameCtx.current_ctx().layer_abs_name_scope,
-        in_spatial_dims=in_spatial_dims,
-        filter_size=[d.dimension for d in self.filter_size],
-        strides=1 if not self.strides else self.strides,
-        dilation_rate=1 if not self.dilation_rate else self.dilation_rate,
-        padding=self.padding)
+    out_spatial_dims = _default_out_spatial_dims(
+      description_prefix=nn.NameCtx.current_ctx().layer_abs_name_scope,
+      in_spatial_dims=in_spatial_dims,
+      filter_size=[d.dimension for d in self.filter_size],
+      strides=1 if not self.strides else self.strides,
+      dilation_rate=1 if not self.dilation_rate else self.dilation_rate,
+      padding=self.padding)
     layer_dict = {
       "class": "conv", "from": source,
       "in_dim": self.in_dim, "in_spatial_dims": in_spatial_dims,
@@ -242,18 +238,16 @@ def __init__(self,
   @nn.scoped
   def __call__(self, source: nn.Tensor, *,
                in_dim: Optional[nn.Dim] = None,
-               in_spatial_dims: Sequence[nn.Dim],
-               out_spatial_dims: Optional[Sequence[nn.Dim]] = None
+               in_spatial_dims: Sequence[nn.Dim]
                ) -> Tuple[nn.Tensor, Sequence[nn.Dim]]:
     source = nn.check_in_feature_dim_lazy_init(source, in_dim, self.in_dim, self._lazy_init)
-    if not out_spatial_dims:
-      out_spatial_dims = [
-        nn.SpatialDim(f"{nn.NameCtx.current_ctx().layer_abs_name_scope}:out-spatial-dim{i}")
-        for i, s in enumerate(self.filter_size)]
-      for i in range(len(self.filter_size)):
-        s = self.filter_size[i].dimension if not self.strides else self.strides[i]
-        if self.filter_size[i].dimension == s == 1 or (s == 1 and self.padding.lower() == "same"):
-          out_spatial_dims[i] = in_spatial_dims[i]
+    out_spatial_dims = [
+      nn.SpatialDim(f"{nn.NameCtx.current_ctx().layer_abs_name_scope}:out-spatial-dim{i}")
+      for i, s in enumerate(self.filter_size)]
+    for i in range(len(self.filter_size)):
+      s = self.filter_size[i].dimension if not self.strides else self.strides[i]
+      if self.filter_size[i].dimension == s == 1 or (s == 1 and self.padding.lower() == "same"):
+        out_spatial_dims[i] = in_spatial_dims[i]
     layer_dict = {
       "class": "transposed_conv", "from": source,
       "in_dim": self.in_dim, "in_spatial_dims": in_spatial_dims,
