Skip to content

Commit aba4232

Browse files
committed
LayerBase base out data, fixes for Data.sparse_dim
1 parent 208720b commit aba4232

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

returnn/tf/layers/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,10 @@ def _base_get_out_data_from_opts(cls, network, name, out_type=None, n_out=NotSpe
324324
if sources_data:
325325
out_type.setdefault("batch_dim_axis", sources_data.batch_dim_axis)
326326
out_type.setdefault("time_dim_axis", sources_data.time_dim_axis)
327-
if not out_type.get("sparse", False) and sources_data.feature_dim_axis_or_unspecified is not NotSpecified:
327+
if (
328+
not out_type.get("sparse", False) and
329+
not out_type.get("sparse_dim", None) and
330+
sources_data.feature_dim_axis_or_unspecified is not NotSpecified):
328331
if sources_data.feature_dim_axis_or_unspecified is not None:
329332
out_type.setdefault("feature_dim_axis", sources_data.feature_dim_axis_or_unspecified)
330333
else: # None
@@ -334,7 +337,7 @@ def _base_get_out_data_from_opts(cls, network, name, out_type=None, n_out=NotSpe
334337
out_type.setdefault("time_dim_axis", None)
335338
if "shape" not in out_type and "dim_tags" not in out_type:
336339
if sources_data:
337-
if out_type.get("sparse", False):
340+
if out_type.get("sparse", False) or out_type.get("sparse_dim", None):
338341
out_type["dim_tags"] = sources_data.dim_tags_sparse
339342
else: # not sparse
340343
feature_dim_axis = out_type.get("feature_dim_axis", NotSpecified)

0 commit comments

Comments
 (0)