Skip to content

Commit de1e559

Browse files
committed
test_DotLayer_mask_dyn_seq_after_softmax
1 parent b1dcb4e commit de1e559

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

tests/test_TFNetworkLayer.py

+32
Original file line numberDiff line numberDiff line change
@@ -3247,6 +3247,38 @@ def test_DotLayer_mask_dyn_seq():
32473247
session.run(layer.output.placeholder, feed_dict=feed_dict)
32483248

32493249

3250+
def test_DotLayer_mask_dyn_seq_after_softmax():
3251+
batch = DimensionTag(kind=DimensionTag.Types.Batch, description="batch")
3252+
time = DimensionTag(kind=DimensionTag.Types.Spatial, description="time")
3253+
feat1 = DimensionTag(kind=DimensionTag.Types.Feature, description="feature 1", dimension=3)
3254+
feat2 = DimensionTag(kind=DimensionTag.Types.Feature, description="feature 2", dimension=5)
3255+
config = Config({
3256+
"extern_data": {
3257+
"src1": {"dim_tags": [batch, time, feat1]},
3258+
"src2": {"dim_tags": [batch, time, feat2]},
3259+
},
3260+
"network": {
3261+
"sm1": {"class": "softmax_over_spatial", "from": "data:src1"},
3262+
"dot": {
3263+
"class": "dot", "from": ["sm1", "data:src2"], "is_output_layer": True,
3264+
"red1": time, "red2": time, "var1": feat1, "var2": feat2
3265+
},
3266+
},
3267+
"debug_print_layer_output_template": True,
3268+
})
3269+
3270+
with make_scope() as session:
3271+
net = TFNetwork(config=config)
3272+
net.construct_from_dict(config.typed_dict["network"])
3273+
layer = net.layers["dot"]
3274+
assert isinstance(layer, DotLayer)
3275+
assert layer.output.dim_tags == (batch, feat1, feat2)
3276+
assert layer._info_reduce_mask == "source-0-already-masked"
3277+
3278+
feed_dict = make_feed_dict(net.extern_data)
3279+
session.run(layer.output.placeholder, feed_dict=feed_dict)
3280+
3281+
32503282
def test_subnet_load_on_init():
32513283
import tempfile
32523284
model_tmp_dir = tempfile.mkdtemp("tmp-checkpoint")

0 commit comments

Comments
 (0)