
Incorrect results when using MultiHeadAttention attention_axes with negative indexing #21714

Description

@AmedeoBiolatti

Using negative indexing (e.g. -2 for the second-to-last axis) for attention_axes in MultiHeadAttention returns incorrect results.

I think the issue is with _dot_product_equation, which is computed in _build_attention_equation.

import numpy as np
import keras

# (batch0, batch1, seq, features)
x = np.random.normal(size=(10, 5, 128, 16))

# 0) flatten the batch dimensions
x_flat = keras.ops.reshape(x, (x.shape[0] * x.shape[1],) + x.shape[2:])

mha = keras.layers.MultiHeadAttention(num_heads=4, key_dim=16)
z_flat, a_flat = mha(x_flat, x_flat, return_attention_scores=True)
# tuple() keeps the reshape target a plain tuple across backends
z = keras.ops.reshape(z_flat, x.shape[:2] + tuple(z_flat.shape[1:]))
a = keras.ops.reshape(a_flat, x.shape[:2] + tuple(a_flat.shape[1:]))

print(z.shape, a.shape) # (10, 5, 128, 16) (10, 5, 4, 128, 128)

# 1) pass attention_axes with a positive index
mha = keras.layers.MultiHeadAttention(num_heads=4, key_dim=16, attention_axes=2)
z, a = mha(x, x, return_attention_scores=True)

print(z.shape, a.shape) # (10, 5, 128, 16) (10, 5, 4, 128, 128)
print(mha._dot_product_equation) # abfde,abcde->abdcf

# 2) pass attention_axes with negative indexing <- NOT WORKING AS EXPECTED
mha = keras.layers.MultiHeadAttention(num_heads=4, key_dim=16, attention_axes=-2)
z, a = mha(x, x, return_attention_scores=True)

print(z.shape, a.shape) # (10, 5, 128, 16) (10, 5, 128, 4, 4) 
print(mha._dot_product_equation) # abcfe,abcde->abcdf
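
# Why case 2 is wrong: in the correct equation abfde,abcde->abdcf (case 1),
# the contraction is over the feature axis e, the heads axis d is kept, and
# the key/query sequence axes f and c become the last two score axes,
# giving (10, 5, 4, 128, 128). In case 2, abcfe,abcde->abcdf instead treats
# the sequence axis c as a batch axis and attends over the heads axis,
# hence the (10, 5, 128, 4, 4) scores above.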

See this gist:
https://gist.github.com/AmedeoBiolatti/45fdd330f06d94fbe0dad51f4a3d0f2e
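
Until this is fixed, a minimal workaround sketch, assuming attention_axes is meant to index into the rank of the unprojected query (so -2 on the rank-4 input above should resolve to axis 2, as in case 1): normalize negative axes to non-negative ones before constructing the layer.

import numpy as np
import keras

x = np.random.normal(size=(10, 5, 128, 16))

# Resolve negative axes against the query rank; -2 % 4 == 2.
rank = len(x.shape)
attention_axes = tuple(axis % rank for axis in (-2,))  # (-2,) -> (2,)

mha = keras.layers.MultiHeadAttention(
    num_heads=4, key_dim=16, attention_axes=attention_axes
)
z, a = mha(x, x, return_attention_scores=True)
print(z.shape, a.shape)  # (10, 5, 128, 16) (10, 5, 4, 128, 128)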
