Using negative indexing (e.g. -2 for the second-to-last axis) for attention_axes
in MultiHeadAttention
returns wrong results.
I think the issue is with _dot_product_equation
, which is computed in _build_attention_equation
import numpy as np
import keras
# (batch0, batch1, seq, features)
x = np.random.normal(size=(10, 5, 128, 16))
# 0) flatten the batch dimensions
x_flat = keras.ops.reshape(x, (x.shape[0] * x.shape[1],) + x.shape[2:])
mha = keras.layers.MultiHeadAttention(num_heads=4, key_dim=16)
z_flat, a_flat = mha(x_flat, x_flat, return_attention_scores=True)
z = keras.ops.reshape(z_flat, x.shape[:2] + z_flat.shape[1:])
a = keras.ops.reshape(a_flat, x.shape[:2] + a_flat.shape[1:])
print(z.shape, a.shape) # (10, 5, 128, 16) (10, 5, 4, 128, 128)
# 1) pass attention_axes with regular indexing
mha = keras.layers.MultiHeadAttention(num_heads=4, key_dim=16, attention_axes=2)
z, a = mha(x, x, return_attention_scores=True)
print(z.shape, a.shape) # (10, 5, 128, 16) (10, 5, 4, 128, 128)
print(mha._dot_product_equation) # abfde,abcde->abdcf
# 2) pass attention_axes with negative indexing <- NOT WORKING AS EXPECTED
mha = keras.layers.MultiHeadAttention(num_heads=4, key_dim=16, attention_axes=-2)
z, a = mha(x, x, return_attention_scores=True)
print(z.shape, a.shape) # (10, 5, 128, 16) (10, 5, 128, 4, 4)
print(mha._dot_product_equation) # abcfe,abcde->abcdf
See this gist:
https://gist.github.com/AmedeoBiolatti/45fdd330f06d94fbe0dad51f4a3d0f2e
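If the cause is indeed in the equation construction, a possible fix (a sketch only; where exactly keras would apply it is an assumption on my part) would be to normalize attention_axes against the query rank before _build_attention_equation builds the einsum strings:
# hypothetical helper: map negative attention axes to their positive positions
# for a query tensor of the given rank, so the einsum equation is built over
# the intended axis
def normalize_attention_axes(attention_axes, rank):
    return tuple(ax if ax >= 0 else ax + rank for ax in attention_axes)

print(normalize_attention_axes((-2,), rank=4))  # (2,)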