@@ -298,7 +298,7 @@ def _base_get_out_data_from_opts(cls, network, name,
298
298
out_type = None , out_dim = None , n_out = NotSpecified ,
299
299
out_shape = None ,
300
300
target = None , _target_layers = None , size_target = None ,
301
- sources = (), loss = None ,
301
+ sources = (), in_dim = None , loss = None ,
302
302
** kwargs ):
303
303
"""
304
304
Called via BaseLayer.get_out_data_from_opts().
@@ -313,6 +313,7 @@ def _base_get_out_data_from_opts(cls, network, name,
313
313
:param dict[str,LayerBase]|None _target_layers: if target.startswith("layer:"), then this is target -> layer
314
314
:param str|None size_target:
315
315
:param list[LayerBase] sources:
316
+ :param DimensionTag|None in_dim:
316
317
:param Loss|None loss:
317
318
:param kwargs: remaining kwargs of self.__init__(), ignored here
318
319
:return: Data template (placeholder not set)
@@ -338,6 +339,15 @@ def _base_get_out_data_from_opts(cls, network, name,
338
339
if n_out is not NotSpecified :
339
340
assert out_type ["dim" ] == n_out
340
341
sources_data_list = [src .output for src in sources if src ]
342
+ if in_dim :
343
+ assert len (sources_data_list ) == 1
344
+ if sources_data_list [0 ].feature_dim_or_sparse_dim != in_dim :
345
+ # Allow to specify some in_dim which is not the feature dim.
346
+ # However, the follow-up code will expect it to be the feature dim, thus reassign it if possible.
347
+ assert in_dim in sources_data_list [0 ].dim_tags
348
+ axis = sources_data_list [0 ].get_axis_from_description (in_dim )
349
+ sources_data_list = [sources_data_list [0 ].copy ()]
350
+ sources_data_list [0 ].feature_dim_axis = axis
341
351
allow_broadcast_all_sources = NotSpecified
342
352
if "shape" in out_type or "dim_tags" in out_type or out_shape is not None :
343
353
allow_broadcast_all_sources = True
0 commit comments