Commit dc58c5a

Tests for self attention using CumConcatLayer

1 parent b2b5cb5


tests/test_TFNetworkRecLayer.py

Lines changed: 154 additions & 0 deletions
@@ -6837,6 +6837,160 @@ def test_RelativePositionalEncodingLayer():
    print(out)  # random...


def _build_self_attention_layer(d, input, output, inside_rec_layer, query_axis, num_heads=8, key_dim=64,
                                value_dim=64, dropout=0.0):
  """
  Essentially this does
    d[output + '_att'] = {"class": "self_attention", "num_heads": num_heads,
      "total_key_dim": num_heads * key_dim,
      "n_out": num_heads * value_dim, "from": [input],
      "attention_left_only": inside_rec_layer,
      "attention_dropout": dropout, "forward_weights_init": self.ff_init}
  But using multiple layers.
  """
  # Create (non-accumulated) query, key and value
  d[output + '_qkv0'] = {
    'class': 'linear', 'activation': None, 'with_bias': False, 'from': [input],
    'n_out': num_heads * (2 * key_dim + value_dim)}  # [B,T?,F|n*(2d_k+d_v)]
  d[output + '_qkv'] = {
    'class': 'split_dims', 'axis': 'F', 'dims': (num_heads, 2 * key_dim + value_dim),
    'from': [output + '_qkv0']}  # [B,T?,n,F|2d_k+d_v]
  d[output + '_qkv_split'] = {
    'class': 'split', 'axis': 'F', 'size_splits': (key_dim, key_dim, value_dim), 'from': [output + '_qkv']}
  d[output + '_query'] = {'class': 'copy', 'from': [output + '_qkv_split/0']}  # [B,T?,n,F|d_k]
  d[output + '_key'] = {'class': 'copy', 'from': [output + '_qkv_split/1']}  # [B,T?,n,F|d_k]
  d[output + '_value'] = {'class': 'copy', 'from': [output + '_qkv_split/2']}  # [B,T?,n,F|d_v]

  # Accumulate keys/values or rename the axis
  key_dim_tag = DimensionTag(kind=DimensionTag.Types.Time, description='self-att-keys')
  key_axis = 'stag:' + key_dim_tag.description
  if inside_rec_layer:
    d[output + '_key_accum'] = {
      'class': 'cum_concat', 'from': [output + '_key'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_k]
    d[output + '_value_accum'] = {
      'class': 'cum_concat', 'from': [output + '_value'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_v]
  else:
    d[output + '_key_accum'] = {
      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
      'from': [output + '_key']}  # [B,T|keys,n,F|d_k]
    d[output + '_value_accum'] = {
      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
      'from': [output + '_value']}  # [B,T|keys,n,F|d_v]

  # Calculate the energies
  d[output + '_energy'] = {
    'class': 'dot', 'from': [output + '_query', output + '_key_accum'],
    'red1': 'static:-1', 'red2': 'static:-1',
    'var1': None if inside_rec_layer else query_axis, 'var2': key_dim_tag}  # [B,n,T?,T|rec-history]

  d[output + '_weights'] = {
    'class': 'softmax_over_spatial', 'from': [output + '_energy'], 'axis': key_axis,
    'energy_factor': key_dim ** -0.5}  # [B,n,T?,T|rec-history]
  d[output + '_weights_drop'] = {
    'class': 'dropout', 'dropout_noise_shape': {'*': None}, 'from': [output + '_weights'],
    'dropout': dropout}  # [B,n,T?,T|rec-history]

  d[output + '_output'] = {
    'class': 'dot', 'from': [output + '_weights_drop', output + '_value_accum'],
    'red1': key_axis, 'red2': key_axis,
    "var1": None if inside_rec_layer else query_axis, "var2": "static:-1"}  # [B,n,T?,F|d_v]
  d[output + '_att'] = {'class': 'merge_dims', 'axes': 'static', 'from': [output + '_output']}  # [B,T?,F|n*d_v]
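
# A minimal usage sketch (layer names here are illustrative, not used by the tests below):
#   net = {}
#   _build_self_attention_layer(
#     net, 'data', 'att', inside_rec_layer=False, query_axis='stag:extern_data:data',
#     num_heads=8, key_dim=64, value_dim=64)
#   # net['att_att'] is then the merged multi-head output, shape [B,T,8*64].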


def test_CumConcatLayer_self_attention_equal_to_SelfAttentionLayer():
  n_time = 13
  num_heads, key_dim, value_dim = 2, 3, 3
  for inside_rec_layer in [False, True]:
    with make_scope() as session:
      print('Testing inside_rec_layer=%s' % inside_rec_layer)

      # build net dict
      if inside_rec_layer:
        net_dict = {
          "output": {
            "class": "rec", "target": "classes", "from": [],
            "unit": {
              "single_layer_att": {
                "class": "self_attention", "from": "prev:single_layer_att", "num_heads": num_heads,
                "total_key_dim": num_heads * key_dim, "n_out": num_heads * value_dim,
                "attention_left_only": inside_rec_layer, 'is_output_layer': True},  # [B,T,F]
              "multi_layer_att": None,  # [B,T,F], added below.
              "output": {"class": "compare", "from": ["single_layer_att", "multi_layer_att"]}}}}
        _build_self_attention_layer(
          net_dict["output"]["unit"], 'prev:multi_layer_att', 'multi_layer', inside_rec_layer=True,
          query_axis='stag:extern_data:classes', num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
        net_dict["output"]["unit"]["multi_layer_att"]["is_output_layer"] = True
        net_dict["output"]["unit"]["multi_layer_qkv0"]["is_output_layer"] = True  # we need to set the matrix here
      else:
        net_dict = {
          "single_layer_att": {
            "class": "self_attention", "from": "data", "num_heads": num_heads, "total_key_dim": num_heads * key_dim,
            "n_out": num_heads * value_dim, "attention_left_only": inside_rec_layer,
            'is_output_layer': True},  # [B,T,F]
          "multi_layer_att": None,  # [B,T,F], added below.
          "output": {"class": "compare", "from": ["single_layer_att", "multi_layer_att"]}
        }
        _build_self_attention_layer(
          net_dict, 'data', 'multi_layer', inside_rec_layer=False, query_axis='stag:extern_data:data',
          num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
        net_dict["multi_layer_att"]["is_output_layer"] = True

      config = Config({
        "debug_print_layer_output_template": True, "optimize_move_layers_out": True})
      config.update(dict(num_inputs=num_heads * key_dim, num_outputs=num_heads * value_dim))
      network = TFNetwork(config=config, train_flag=True)
      from pprint import pprint
      pprint(net_dict)
      network.construct_from_dict(net_dict)

      if inside_rec_layer:
        single_layer = network.get_layer("output/single_layer_att")
        multi_layer = network.get_layer("output/multi_layer_att")

        # Note: single_layer.params etc. do not contain the params, need to access rec cell directly
        rec_layer = network.get_layer("output")
        single_weights = rec_layer.cell.net.get_layer("single_layer_att").params["QKV"]
        multi_weights = rec_layer.cell.net.get_layer("multi_layer_qkv0").params["W"]
      else:
        single_layer = network.get_layer("single_layer_att")
        multi_layer = network.get_layer("multi_layer_att")
        single_weights = single_layer.params["QKV"]
        multi_weights = network.get_layer("multi_layer_qkv0").params["W"]

      assert_equal(single_layer.output.batch_shape, (None, None, num_heads * value_dim))
      assert_equal(multi_layer.output.batch_shape, (None, None, num_heads * value_dim))

      # set weights equal.
      assert_equal(single_weights.shape, multi_weights.shape)
      weights = numpy.random.rand(*single_weights.shape)
      session.run(tf.compat.v1.assign(single_weights, weights))
      session.run(tf.compat.v1.assign(multi_weights, weights))

      # fetch/compare outputs
      from tests.test_TFNetworkLayer import make_feed_dict
      feed_dict = make_feed_dict(network.extern_data.data.values(), same_time=True, n_time=n_time)
      single, multi = session.run(
        [single_layer.output.placeholder, multi_layer.output.placeholder], feed_dict=feed_dict)
      print('single layer output:')
      pprint(single)
      print('multi layer output:')
      pprint(multi)
      numpy.testing.assert_almost_equal(single, multi, decimal=5)
      print('They are equal!')


def test_self_attention_optimize_out():
  num_heads, key_dim, value_dim = 2, 3, 3
  network = {}
  _build_self_attention_layer(
    network, 'data:source', 'att', inside_rec_layer=True, query_axis='stag:extern_data:data',
    num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)

  check_reclayer_optimize_out(
    {'class': 'copy', 'from': 'att_att', 'n_out': value_dim * num_heads},
    other_subnet_layers=network)
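  # The check above (check_reclayer_optimize_out, defined earlier in this file) roughly verifies that
  # the CumConcat-based attention behaves the same whether it runs step-by-step inside the rec loop
  # or is moved out of the loop and computed over the whole sequence at once.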


if __name__ == "__main__":
  try:
    better_exchook.install()
