
Tests for self attention using CumConcatLayer #590

Merged: 1 commit, Sep 22, 2021
142 changes: 142 additions & 0 deletions tests/test_TFNetworkRecLayer.py
@@ -6837,6 +6837,148 @@ def test_RelativePositionalEncodingLayer():
  print(out)  # random...


def _build_self_attention_layer(d, input, output, inside_rec_layer, query_axis, num_heads=8, key_dim=64,
                                value_dim=64, dropout=0.0):
  """
  Essentially this does
    d[output + '_att'] = {"class": "self_attention", "num_heads": num_heads,
      "total_key_dim": num_heads * key_dim,
      "n_out": num_heads * value_dim, "from": [input],
      "attention_left_only": inside_rec_layer,
      "attention_dropout": dropout, "forward_weights_init": self.ff_init}
  But using multiple primitive layers instead of a single SelfAttentionLayer.
  """
  # Create (non-accumulated) query, key and value
  d[output + '_qkv0'] = {
    'class': 'linear', 'activation': None, 'with_bias': False, 'from': [input],
    'n_out': num_heads * (2 * key_dim + value_dim)}  # [B,T?,F|n*(2d_k+d_v)]
  d[output + '_qkv'] = {
    'class': 'split_dims', 'axis': 'F', 'dims': (num_heads, 2 * key_dim + value_dim),
    'from': [output + '_qkv0']}  # [B,T?,n,F|2d_k+d_v]
  d[output + '_qkv_split'] = {
    'class': 'split', 'axis': 'F', 'size_splits': (key_dim, key_dim, value_dim), 'from': [output + '_qkv']}
  d[output + '_query'] = {'class': 'copy', 'from': [output + '_qkv_split/0']}  # [B,T?,n,F|d_k]
  d[output + '_key'] = {'class': 'copy', 'from': [output + '_qkv_split/1']}  # [B,T?,n,F|d_k]
  d[output + '_value'] = {'class': 'copy', 'from': [output + '_qkv_split/2']}  # [B,T?,n,F|d_v]

  # Accumulate keys/values or rename the axis
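  # Inside the rec layer, cum_concat appends the current step's key/value to the accumulated history of
  # all previous steps; outside the rec layer, the keys/values already cover the whole sequence, so we
  # only assign the new dim tag to their time axis to keep it distinguishable from the query time axis.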
  key_dim_tag = DimensionTag(kind=DimensionTag.Types.Time, description='self-att-keys')
  key_axis = 'stag:' + key_dim_tag.description
  if inside_rec_layer:
    d[output + '_key_accum'] = {
      'class': 'cum_concat', 'from': [output + '_key'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_k]
    d[output + '_value_accum'] = {
      'class': 'cum_concat', 'from': [output + '_value'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_v]
  else:
    d[output + '_key_accum'] = {
      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
      'from': [output + '_key']}  # [B,T|keys,n,F|d_k]
    d[output + '_value_accum'] = {
      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
      'from': [output + '_value']}  # [B,T|keys,n,F|d_v]

  # Calculate the energies
  d[output + '_energy'] = {
    'class': 'dot', 'from': [output + '_query', output + '_key_accum'],
    'red1': 'static:-1', 'red2': 'static:-1',
    'var1': None if inside_rec_layer else query_axis, 'var2': key_dim_tag}  # [B,n,T?,T|rec-history]

  d[output + '_weights'] = {
    'class': 'softmax_over_spatial', 'from': [output + '_energy'], 'axis': key_axis,
    'energy_factor': key_dim ** -0.5}  # [B,n,T?,T|rec-history]
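  # energy_factor = 1 / sqrt(d_k) is the standard scaled dot-product attention scaling.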
  d[output + '_weights_drop'] = {
    'class': 'dropout', 'dropout_noise_shape': {'*': None}, 'from': [output + '_weights'],
    'dropout': dropout}  # [B,n,T?,T|rec-history]

  d[output + '_output'] = {
    'class': 'dot', 'from': [output + '_weights_drop', output + '_value_accum'],
    'red1': key_axis, 'red2': key_axis,
    "var1": None if inside_rec_layer else query_axis, "var2": "static:-1"}  # [B,n,T?,F|d_v]
  d[output + '_att'] = {'class': 'merge_dims', 'axes': 'static', 'from': [output + '_output']}  # [B,T?,F|n*d_v]
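
# Not part of the PR diff: a minimal single-head NumPy reference (function and parameter names here are
# illustrative assumptions) of the attention math that _build_self_attention_layer wires up out of
# RETURNN layers, for the inside_rec_layer=False case and ignoring dropout and sequence masking.
def _single_head_self_attention_numpy_sketch(x, w_qkv, key_dim, value_dim):
  """
  :param numpy.ndarray x: [T, n_in]
  :param numpy.ndarray w_qkv: [n_in, 2 * key_dim + value_dim]
  :param int key_dim:
  :param int value_dim:
  :return: attention output, shape [T, value_dim]
  :rtype: numpy.ndarray
  """
  qkv = x.dot(w_qkv)  # [T, 2*d_k+d_v], like the 'linear' layer above, for a single head
  query = qkv[:, :key_dim]  # [T, d_k]
  key = qkv[:, key_dim:2 * key_dim]  # [T, d_k]
  value = qkv[:, 2 * key_dim:]  # [T, d_v]
  energy = query.dot(key.T) * key_dim ** -0.5  # [T_query, T_key], scaled dot-product
  energy -= energy.max(axis=-1, keepdims=True)  # stabilize the softmax
  weights = numpy.exp(energy)
  weights /= weights.sum(axis=-1, keepdims=True)  # softmax over the key axis
  return weights.dot(value)  # [T, d_v]; multi-head attention concatenates this over all heads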


def test_CumConcatLayer_self_attention_equal_to_SelfAttentionLayer():
  n_time = 13
  num_heads, key_dim, value_dim = 2, 3, 3
  for inside_rec_layer in [False, True]:
    with make_scope() as session:
      print('Testing inside_rec_layer=%s' % inside_rec_layer)

      # build net dict
      if inside_rec_layer:
        net_dict = {
          "output": {
            "class": "rec", "target": "classes", "from": [],
            "unit": {
              "single_layer_att": {
                "class": "self_attention", "from": "prev:single_layer_att", "num_heads": num_heads,
                "total_key_dim": num_heads * key_dim, "n_out": num_heads * value_dim,
                "attention_left_only": inside_rec_layer, 'is_output_layer': True},  # [B,T,F]
              "multi_layer_att": None,  # [B,T,F], added below.
              "output": {"class": "compare", "from": ["single_layer_att", "multi_layer_att"]}}}}
        _build_self_attention_layer(
          net_dict["output"]["unit"], 'prev:multi_layer_att', 'multi_layer', inside_rec_layer=True,
          query_axis='stag:extern_data:classes', num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
        net_dict["output"]["unit"]["multi_layer_att"]["is_output_layer"] = True
        net_dict["output"]["unit"]["multi_layer_qkv0"]["is_output_layer"] = True  # we need to set the matrix here
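        # 'is_output_layer' makes these sub-layers (and thus their parameters) reachable from outside
        # the rec unit, so their weight matrices can be fetched and assigned below.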
      else:
        net_dict = {
          "single_layer_att": {
            "class": "self_attention", "from": "data", "num_heads": num_heads, "total_key_dim": num_heads * key_dim,
            "n_out": num_heads * value_dim, "attention_left_only": inside_rec_layer,
            'is_output_layer': True},  # [B,T,F]
          "multi_layer_att": None,  # [B,T,F], added below.
          "output": {"class": "compare", "from": ["single_layer_att", "multi_layer_att"]}
        }
        _build_self_attention_layer(
          net_dict, 'data', 'multi_layer', inside_rec_layer=False, query_axis='stag:extern_data:data',
          num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
        net_dict["multi_layer_att"]["is_output_layer"] = True

      config = Config({
        "debug_print_layer_output_template": True, "optimize_move_layers_out": True})
      config.update(dict(num_inputs=num_heads * key_dim, num_outputs=num_heads * value_dim))
      network = TFNetwork(config=config, train_flag=True)
      from pprint import pprint
      pprint(net_dict)
      network.construct_from_dict(net_dict)

      if inside_rec_layer:
        single_layer = network.get_layer("output/single_layer_att")
        multi_layer = network.get_layer("output/multi_layer_att")

        # Note: single_layer.params etc. do not contain the params, need to access rec cell directly
        rec_layer = network.get_layer("output")
        single_weights = rec_layer.cell.net.get_layer("single_layer_att").params["QKV"]
        multi_weights = rec_layer.cell.net.get_layer("multi_layer_qkv0").params["W"]
      else:
        single_layer = network.get_layer("single_layer_att")
        multi_layer = network.get_layer("multi_layer_att")
        single_weights = single_layer.params["QKV"]
        multi_weights = network.get_layer("multi_layer_qkv0").params["W"]

      assert_equal(single_layer.output.batch_shape, (None, None, num_heads * value_dim))
      assert_equal(multi_layer.output.batch_shape, (None, None, num_heads * value_dim))

      # set weights equal.
      assert_equal(single_weights.shape, multi_weights.shape)
      weights = numpy.random.rand(*single_weights.shape)
      session.run(tf.compat.v1.assign(single_weights, weights))
      session.run(tf.compat.v1.assign(multi_weights, weights))

      # fetch/compare outputs
      from tests.test_TFNetworkLayer import make_feed_dict
      feed_dict = make_feed_dict(network.extern_data.data.values(), same_time=True, n_time=n_time)
      single, multi = session.run(
        [single_layer.output.placeholder, multi_layer.output.placeholder], feed_dict=feed_dict)
      print('single layer output:')
      pprint(single)
      print('multi layer output:')
      pprint(multi)
      numpy.testing.assert_almost_equal(single, multi, decimal=5)
      print('They are equal!')
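
# Usage note (an assumption, not part of the PR diff): with the standard RETURNN test main block below,
# this test can typically be run on its own via
#   python tests/test_TFNetworkRecLayer.py test_CumConcatLayer_self_attention_equal_to_SelfAttentionLayer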


if __name__ == "__main__":
  try:
    better_exchook.install()