
Commit 66cb35e

Tests for self attention using CumConcatLayer
1 parent 210cd38 commit 66cb35e

1 file changed: +147 -0 lines changed


tests/test_TFNetworkRecLayer.py (+147)
@@ -6550,6 +6550,153 @@ def test_RelativePositionalEncodingLayer():
   print(out)  # random...


+def _build_self_attention_layer(d, input, output, inside_rec_layer, query_axis, num_heads=8, key_dim=64,
+                                value_dim=64, dropout=0.0):
+  """
+  Essentially this does
+    d[output + '_att'] = {"class": "self_attention", "num_heads": num_heads,
+      "total_key_dim": num_heads * key_dim,
+      "n_out": num_heads * value_dim, "from": [input],
+      "attention_left_only": inside_rec_layer,
+      "attention_dropout": dropout, "forward_weights_init": self.ff_init}
+  But using multiple layers.
+  """
+  # Create (non-accumulated) query, key and value
+  d[output + '_qkv0'] = {
+    'class': 'linear', 'activation': None, 'with_bias': False, 'from': [input],
+    'n_out': num_heads * (2 * key_dim + value_dim)}  # [B,T?,F|n*(2d_k+d_v)]
+  d[output + '_qkv'] = {
+    'class': 'split_dims', 'axis': 'F', 'dims': (num_heads, 2 * key_dim + value_dim),
+    'from': [output + '_qkv0']}  # [B,T?,n,F|2d_k+d_v]
+  d[output + '_qkv_split'] = {
+    'class': 'split', 'axis': 'F', 'size_splits': (key_dim, key_dim, value_dim),
+    'from': [output + '_qkv']}
+  d[output + '_query'] = {
+    'class': 'copy', 'from': [output + '_qkv_split/0']}  # [B,T?,n,F|d_k]
+  d[output + '_key'] = {
+    'class': 'copy', 'from': [output + '_qkv_split/1']}  # [B,T?,n,F|d_k]
+  d[output + '_value'] = {
+    'class': 'copy', 'from': [output + '_qkv_split/2']}  # [B,T?,n,F|d_v]
+
+  # Accumulate keys/values or rename the axis
+  if inside_rec_layer:
+    d[output + '_key_accum'] = {
+      'class': 'cum_concat', 'from': [output + '_key']}  # [B,T|rec-history,n,F|d_k]
+    d[output + '_value_accum'] = {
+      'class': 'cum_concat', 'from': [output + '_value']}  # [B,T|rec-history,n,F|d_v]
+    key_axis = 'stag:rec-history'
+  else:
+    key_dim_tag = DimensionTag(kind=DimensionTag.Types.Time, description='self-att-keys')
+    d[output + '_key_accum'] = {
+      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
+      'from': [output + '_key']}  # [B,T|keys,n,F|d_k]
+    d[output + '_value_accum'] = {
+      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
+      'from': [output + '_value']}  # [B,T|keys,n,F|d_v]
+    key_axis = 'stag:' + key_dim_tag.description
+
+  # Calculate the energies
+  d[output + '_energy'] = {
+    'class': 'dot', 'from': [output + '_query', output + '_key_accum'],
+    'red1': 'static:-1', 'red2': 'static:-1', 'common': ['B', 'static:0']}  # [B,n,T?,T|rec-history]
+
+  d[output + '_weights'] = {
+    'class': 'softmax_over_spatial', 'from': [output + '_energy'], 'axis': key_axis,
+    'energy_factor': key_dim ** -0.5}  # [B,n,T?,T|rec-history]
+  d[output + '_weights_drop'] = {
+    'class': 'dropout', 'dropout_noise_shape': {'*': None}, 'from': [output + '_weights'],
+    'dropout': dropout}  # [B,n,T?,T|rec-history]
+
+  d[output + '_output'] = {
+    'class': 'dot', 'from': [output + '_weights_drop', output + '_value_accum'],
+    'red1': key_axis, 'red2': key_axis, 'common': ['B', query_axis, 'static:0']}  # [B,n,T?,F|d_v]
+  d[output + '_att'] = {
+    'class': 'merge_dims', 'axes': 'static', 'from': [output + '_output']}  # [B,T?,F|n*d_v]
+
+
+def test_CumConcatLayer_self_attention_equal_to_SelfAttentionLayer():
+  n_time = 13
+  num_heads, key_dim, value_dim = 2, 3, 3
+  for inside_rec_layer in [False, True]:
+    with make_scope() as session:
+      print('Testing inside_rec_layer=%s' % inside_rec_layer)
+
+      # build net dict
+      single_layer_net_dict = {
+        "class": "self_attention", "from": "data", "num_heads": num_heads, "total_key_dim": num_heads * key_dim,
+        "n_out": num_heads * value_dim, "attention_left_only": inside_rec_layer, 'is_output_layer': True}  # [B,T,F]
+      if inside_rec_layer:
+        net_dict = {
+          "output": {
+            "class": "rec", "target": "classes",
+            "unit": {
+              "single_layer_att": single_layer_net_dict,  # [B,T,F]
+              "multi_layer_att": None  # [B,T,F], added below.
+            }}}
+        _build_self_attention_layer(
+          net_dict["output"], 'data', 'multi_layer', inside_rec_layer=False, query_axis='stag:extern_data:data',
+          num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
+        net_dict["output"]["multi_layer_att"]["is_output_layer"] = True
+        time_axis = 'stag:extern_data:classes'
+      else:
+        net_dict = {
+          "single_layer_att": single_layer_net_dict,  # [B,T,F]
+          "multi_layer_att": None  # [B,T,F], added below.
+        }
+        _build_self_attention_layer(
+          net_dict, 'data', 'multi_layer', inside_rec_layer=False, query_axis='stag:extern_data:data',
+          num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
+        net_dict["multi_layer_att"]["is_output_layer"] = True
+        time_axis = 'stag:extern_data:data'
+
+      config = Config({"debug_print_layer_output_template": True, "debug_add_check_numerics_ops": True})
+      config.update(dict(num_inputs=num_heads*key_dim, num_outputs=num_heads*value_dim))
+      network = TFNetwork(config=config, train_flag=True)
+      network.construct_from_dict(net_dict)
+
+      if inside_rec_layer:
+        single_layer = network.get_layer("output/single_layer_att")
+        multi_layer = network.get_layer("output/multi_layer_att")
+      else:
+        single_layer = network.get_layer("single_layer_att")
+        multi_layer = network.get_layer("multi_layer_att")
+
+      assert_equal(single_layer.output.shape, (None, num_heads * value_dim))
+      assert_equal(multi_layer.output.shape, (None, num_heads * value_dim))
+
+      # set weights equal.
+      single_weights = single_layer.params["QKV"]
+      multi_weights = multi_layer.params["W"]
+      assert_equal(single_weights.shape, multi_weights.shape)
+      weights = numpy.random.rand(*single_weights.shape)
+      session.run(tf.assign(single_weights, weights))
+      session.run(tf.assign(multi_weights, weights))
+
+      # fetch/compare outputs
+      from tests.test_TFNetworkLayer import make_feed_dict
+      feed_dict = make_feed_dict(network.extern_data.data.values(), same_time=True, n_time=n_time)
+      single, multi = session.run(
+        [single_layer.output.placeholder, multi_layer.output.placeholder], feed_dict=feed_dict)
+      print('single layer output:')
+      pprint(single)
+      print('multi layer output:')
+      pprint(multi)
+      numpy.testing.assert_almost_equal(single, multi, decimal=5)
+      print('They are equal!')
+
+
+def test_self_attention_optimize_out():
+  num_heads, key_dim, value_dim = 2, 3, 3
+  network = {}
+  _build_self_attention_layer(
+    network, 'data:source', 'att', inside_rec_layer=True, query_axis='stag:extern_data:data',
+    num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
+
+  check_reclayer_optimize_out(
+    {'class': 'copy', 'from': 'att_att', 'n_out': value_dim * num_heads},
+    other_subnet_layers=network)
+
+
 if __name__ == "__main__":
   try:
     better_exchook.install()
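
For orientation, here is a minimal sketch (not part of the commit) of how the `_build_self_attention_layer` helper above could be wired into a RecLayer decoder unit with `inside_rec_layer=True`, analogous to what `check_reclayer_optimize_out` exercises in `test_self_attention_optimize_out`. The layer names (`embed`, `att`, `output_prob`) and all dimensions are made up for illustration only.

# Sketch only, with hypothetical names and dims; assumes the helper defined in the diff above.
unit = {
  "embed": {"class": "linear", "activation": None, "from": "prev:output", "n_out": 6},  # [B,F]
}
# Inside the rec loop, CumConcatLayer accumulates keys/values over the rec-history axis,
# so attention is causal (left-only), like attention_left_only=True in SelfAttentionLayer.
_build_self_attention_layer(
  unit, "embed", "att", inside_rec_layer=True, query_axis="stag:extern_data:classes",
  num_heads=2, key_dim=3, value_dim=3)
unit["output_prob"] = {"class": "softmax", "from": "att_att", "target": "classes"}
unit["output"] = {"class": "choice", "from": "output_prob", "target": "classes", "beam_size": 4}
net_dict = {"output": {"class": "rec", "from": [], "target": "classes", "unit": unit}}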

0 commit comments
