@@ -69,12 +69,10 @@ def _lazy_init(self, in_dim: nn.Dim):
69
69
70
70
def _call_nd1 (self , source : nn .Tensor , * ,
71
71
in_dim : Optional [nn .Dim ] = None ,
72
- in_spatial_dim : nn .Dim ,
73
- out_spatial_dim : Optional [nn .Dim ] = None ) -> Tuple [nn .Tensor , nn .Dim ]:
72
+ in_spatial_dim : nn .Dim ) -> Tuple [nn .Tensor , nn .Dim ]:
74
73
assert self .nd == 1
75
74
out , (out_spatial_dim ,) = self .__class__ .__base__ .__call__ (
76
- self , source , in_dim = in_dim , in_spatial_dims = [in_spatial_dim ],
77
- out_spatial_dims = [out_spatial_dim ] if out_spatial_dim else None )
75
+ self , source , in_dim = in_dim , in_spatial_dims = [in_spatial_dim ])
78
76
return out , out_spatial_dim
79
77
80
78
@@ -119,21 +117,19 @@ def __init__(self,
119
117
@nn .scoped
120
118
def __call__ (self , source : nn .Tensor , * ,
121
119
in_dim : Optional [nn .Dim ] = None ,
122
- in_spatial_dims : Sequence [nn .Dim ],
123
- out_spatial_dims : Optional [Sequence [nn .Dim ]] = None
120
+ in_spatial_dims : Sequence [nn .Dim ]
124
121
) -> Tuple [nn .Tensor , Sequence [nn .Dim ]]:
125
122
source = nn .check_in_feature_dim_lazy_init (source , in_dim , self .in_dim , self ._lazy_init )
126
123
for in_spatial_dim in in_spatial_dims :
127
124
if in_spatial_dim not in source .shape :
128
125
raise ValueError (f"{ self } : source { source } does not have spatial dim { in_spatial_dim } " )
129
- if not out_spatial_dims :
130
- out_spatial_dims = _default_out_spatial_dims (
131
- description_prefix = nn .NameCtx .current_ctx ().layer_abs_name_scope ,
132
- in_spatial_dims = in_spatial_dims ,
133
- filter_size = [d .dimension for d in self .filter_size ],
134
- strides = 1 if not self .strides else self .strides ,
135
- dilation_rate = 1 if not self .dilation_rate else self .dilation_rate ,
136
- padding = self .padding )
126
+ out_spatial_dims = _default_out_spatial_dims (
127
+ description_prefix = nn .NameCtx .current_ctx ().layer_abs_name_scope ,
128
+ in_spatial_dims = in_spatial_dims ,
129
+ filter_size = [d .dimension for d in self .filter_size ],
130
+ strides = 1 if not self .strides else self .strides ,
131
+ dilation_rate = 1 if not self .dilation_rate else self .dilation_rate ,
132
+ padding = self .padding )
137
133
layer_dict = {
138
134
"class" : "conv" , "from" : source ,
139
135
"in_dim" : self .in_dim , "in_spatial_dims" : in_spatial_dims ,
@@ -242,18 +238,16 @@ def __init__(self,
242
238
@nn .scoped
243
239
def __call__ (self , source : nn .Tensor , * ,
244
240
in_dim : Optional [nn .Dim ] = None ,
245
- in_spatial_dims : Sequence [nn .Dim ],
246
- out_spatial_dims : Optional [Sequence [nn .Dim ]] = None
241
+ in_spatial_dims : Sequence [nn .Dim ]
247
242
) -> Tuple [nn .Tensor , Sequence [nn .Dim ]]:
248
243
source = nn .check_in_feature_dim_lazy_init (source , in_dim , self .in_dim , self ._lazy_init )
249
- if not out_spatial_dims :
250
- out_spatial_dims = [
251
- nn .SpatialDim (f"{ nn .NameCtx .current_ctx ().layer_abs_name_scope } :out-spatial-dim{ i } " )
252
- for i , s in enumerate (self .filter_size )]
253
- for i in range (len (self .filter_size )):
254
- s = self .filter_size [i ].dimension if not self .strides else self .strides [i ]
255
- if self .filter_size [i ].dimension == s == 1 or (s == 1 and self .padding .lower () == "same" ):
256
- out_spatial_dims [i ] = in_spatial_dims [i ]
244
+ out_spatial_dims = [
245
+ nn .SpatialDim (f"{ nn .NameCtx .current_ctx ().layer_abs_name_scope } :out-spatial-dim{ i } " )
246
+ for i , s in enumerate (self .filter_size )]
247
+ for i in range (len (self .filter_size )):
248
+ s = self .filter_size [i ].dimension if not self .strides else self .strides [i ]
249
+ if self .filter_size [i ].dimension == s == 1 or (s == 1 and self .padding .lower () == "same" ):
250
+ out_spatial_dims [i ] = in_spatial_dims [i ]
257
251
layer_dict = {
258
252
"class" : "transposed_conv" , "from" : source ,
259
253
"in_dim" : self .in_dim , "in_spatial_dims" : in_spatial_dims ,
0 commit comments