Commit 2337f50

Tests for self attention using CumConcatLayer
1 parent 210cd38 commit 2337f50

File tree

1 file changed, +144 -0 lines changed


tests/test_TFNetworkRecLayer.py

@@ -6550,6 +6550,150 @@ def test_RelativePositionalEncodingLayer():
    print(out)  # random...


def _build_self_attention_layer(d, input, output, inside_rec_layer, query_axis, num_heads=8, key_dim=64,
                                value_dim=64, dropout=0.0):
  """
  Essentially this does
    d[output + '_att'] = {"class": "self_attention", "num_heads": num_heads,
      "total_key_dim": num_heads * key_dim,
      "n_out": num_heads * value_dim, "from": [input],
      "attention_left_only": inside_rec_layer,
      "attention_dropout": dropout, "forward_weights_init": self.ff_init}
  But using multiple layers.
  """
  # Create (non-accumulated) query, key and value
  d[output + '_qkv0'] = {
    'class': 'linear', 'activation': None, 'with_bias': False, 'from': [input],
    'n_out': num_heads * (2 * key_dim + value_dim)}  # [B,T?,F|n*(2d_k+d_v)]
  d[output + '_qkv'] = {
    'class': 'split_dims', 'axis': 'F', 'dims': (num_heads, 2 * key_dim + value_dim),
    'from': [output + '_qkv0']}  # [B,T?,n,F|2d_k+d_v]
  d[output + '_qkv_split'] = {
    'class': 'split', 'axis': 'F', 'size_splits': (key_dim, key_dim, value_dim),
    'from': [output + '_qkv']}
  d[output + '_query'] = {
    'class': 'copy', 'from': [output + '_qkv_split/0']}  # [B,T?,n,F|d_k]
  d[output + '_key'] = {
    'class': 'copy', 'from': [output + '_qkv_split/1']}  # [B,T?,n,F|d_k]
  d[output + '_value'] = {
    'class': 'copy', 'from': [output + '_qkv_split/2']}  # [B,T?,n,F|d_v]

  # Accumulate keys/values or rename the axis
  key_dim_tag = DimensionTag(kind=DimensionTag.Types.Time, description='self-att-keys')
  key_axis = 'stag:' + key_dim_tag.description
  if inside_rec_layer:
    d[output + '_key_accum'] = {
      'class': 'cum_concat', 'from': [output + '_key'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_k]
    d[output + '_value_accum'] = {
      'class': 'cum_concat', 'from': [output + '_value'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_v]
  else:
    d[output + '_key_accum'] = {
      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
      'from': [output + '_key']}  # [B,T|keys,n,F|d_k]
    d[output + '_value_accum'] = {
      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
      'from': [output + '_value']}  # [B,T|keys,n,F|d_v]

  # Calculate the energies
  d[output + '_energy'] = {
    'class': 'dot', 'from': [output + '_query', output + '_key_accum'],
    'red1': 'static:-1', 'red2': 'static:-1', 'common': ['B', 'static:0']}  # [B,n,T?,T|rec-history]

  d[output + '_weights'] = {
    'class': 'softmax_over_spatial', 'from': [output + '_energy'], 'axis': key_axis,
    'energy_factor': key_dim ** -0.5}  # [B,n,T?,T|rec-history]
  d[output + '_weights_drop'] = {
    'class': 'dropout', 'dropout_noise_shape': {'*': None}, 'from': [output + '_weights'],
    'dropout': dropout}  # [B,n,T?,T|rec-history]

  d[output + '_output'] = {
    'class': 'dot', 'from': [output + '_weights_drop', output + '_value_accum'],
    'red1': key_axis, 'red2': key_axis, 'common': ['B', query_axis, 'static:0']}  # [B,n,T?,F|d_v]
  d[output + '_att'] = {
    'class': 'merge_dims', 'axes': 'static', 'from': [output + '_output']}  # [B,T?,F|n*d_v]


def test_CumConcatLayer_self_attention_equal_to_SelfAttentionLayer():
  n_time = 13
  num_heads, key_dim, value_dim = 2, 3, 3
  for inside_rec_layer in [False, True]:
    with make_scope() as session:
      print('Testing inside_rec_layer=%s' % inside_rec_layer)

      # build net dict
      single_layer_net_dict = {
        "class": "self_attention", "from": "data", "num_heads": num_heads, "total_key_dim": num_heads * key_dim,
        "n_out": num_heads * value_dim, "attention_left_only": inside_rec_layer, 'is_output_layer': True}  # [B,T,F]
      if inside_rec_layer:
        net_dict = {
          "output": {
            "class": "rec", "target": "classes",
            "unit": {
              "single_layer_att": single_layer_net_dict,  # [B,T,F]
              "multi_layer_att": None  # [B,T,F], added below.
            }}}
        _build_self_attention_layer(
          net_dict["output"], 'data', 'multi_layer', inside_rec_layer=False, query_axis='stag:extern_data:classes',
          num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
        net_dict["output"]["multi_layer_att"]["is_output_layer"] = True
      else:
        net_dict = {
          "single_layer_att": single_layer_net_dict,  # [B,T,F]
          "multi_layer_att": None  # [B,T,F], added below.
        }
        _build_self_attention_layer(
          net_dict, 'data', 'multi_layer', inside_rec_layer=False, query_axis='stag:extern_data:data',
          num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
        net_dict["multi_layer_att"]["is_output_layer"] = True

      config = Config({"debug_print_layer_output_template": True, "debug_add_check_numerics_ops": True})
      config.update(dict(num_inputs=num_heads*key_dim, num_outputs=num_heads*value_dim))
      network = TFNetwork(config=config, train_flag=True)
      network.construct_from_dict(net_dict)

      if inside_rec_layer:
        single_layer = network.get_layer("output/single_layer_att")
        multi_layer = network.get_layer("output/multi_layer_att")
      else:
        single_layer = network.get_layer("single_layer_att")
        multi_layer = network.get_layer("multi_layer_att")

      assert_equal(single_layer.output.shape, (None, num_heads * value_dim))
      assert_equal(multi_layer.output.shape, (None, num_heads * value_dim))

      # set weights equal.
      single_weights = single_layer.params["QKV"]
      multi_weights = multi_layer.params["W"]
      assert_equal(single_weights.shape, multi_weights.shape)
      weights = numpy.random.rand(*single_weights.shape)
      session.run(tf.assign(single_weights, weights))
      session.run(tf.assign(multi_weights, weights))

      # fetch/compare outputs
      from tests.test_TFNetworkLayer import make_feed_dict
      feed_dict = make_feed_dict(network.extern_data.data.values(), same_time=True, n_time=n_time)
      single, multi = session.run(
        [single_layer.output.placeholder, multi_layer.output.placeholder], feed_dict=feed_dict)
      print('single layer output:')
      pprint(single)
      print('multi layer output:')
      pprint(multi)
      numpy.testing.assert_almost_equal(single, multi, decimal=5)
      print('They are equal!')


def test_self_attention_optimize_out():
  num_heads, key_dim, value_dim = 2, 3, 3
  network = {}
  _build_self_attention_layer(
    network, 'data:source', 'att', inside_rec_layer=True, query_axis='stag:extern_data:data',
    num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)

  check_reclayer_optimize_out(
    {'class': 'copy', 'from': 'att_att', 'n_out': value_dim * num_heads},
    other_subnet_layers=network)


if __name__ == "__main__":
  try:
    better_exchook.install()
