attn entropy calculation should not look at future tokens #62

Open
stillmatic opened this issue Oct 10, 2024 · 3 comments

@stillmatic

    attention_probs = F.softmax(attention_scores, dim=-1)
    attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1)
    attn_varentropy = torch.var(attn_entropy, dim=-1)

This computes entropy over the attention scores at every position, including future positions. Because the causal mask zeroes out the attention scores for future positions, the resulting distribution is heavily skewed and the varentropy calculation comes out as NaN.
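To make the skew concrete, here is a minimal plain-Python sketch of one attention row. It assumes future positions hold a score of exactly 0.0 (which is what the workaround further down keys on), and the helper names `softmax`, `entropy_bits` and `row_entropy` are made up for illustration, not taken from the repo. The zeros still get weight under softmax, so the entropy over the full row is inflated compared with the entropy over the positions the token can actually attend to.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a plain list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def entropy_bits(probs, eps=1e-10):
    # Mirrors -sum(p * log2(clamp(p, eps, 1.0))) from the snippet above.
    return -sum(p * math.log2(min(max(p, eps), 1.0)) for p in probs)

def row_entropy(scores, valid_len=None):
    # valid_len is a hypothetical parameter: keep only the first
    # valid_len scores (the non-future positions) before the softmax.
    if valid_len is not None:
        scores = scores[:valid_len]
    return entropy_bits(softmax(scores))

real = [2.0, 1.0, 0.5]        # scores for positions seen so far
padded = real + [0.0] * 5     # assumed: future positions are exactly 0.0

full = row_entropy(padded)                         # includes future zeros
causal = row_entropy(padded, valid_len=len(real))  # past positions only
print(f"full row: {full:.3f} bits, causal prefix: {causal:.3f} bits")
```

With three real scores the causal entropy cannot exceed log2(3) bits, while the full-row value can approach log2 of the padded length, so the number mostly reflects how much padding there is.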

I changed some of the logic to make the calculation more sensible. It's horribly inefficient; my guess is that passing cur_pos into the metrics calculation and only using the attn scores up to cur_pos would be better.

    attention_scores = torch.where(attention_scores != 0.0, attention_scores, torch.full_like(attention_scores, float('-inf')))
    non_inf_attn_scores = torch.where(attention_scores != float('-inf'), attention_scores, torch.full_like(attention_scores, torch.nan))
    interaction_strength = torch.nanmean(torch.abs(non_inf_attn_scores), dim=(1, 2, 3))

Does Jax treat this differently?

@samefarrar

Beat me to it! #63

Jax does the same. I was playing with this today because the frog branch has some new thresholds, but the attention_entropy and attention_varentropy are oddly distributed and really affected by the current position. As the sequence gets longer, you have fewer 0s, so you end up with lower entropy but higher varentropy.

image

Masking the scores (in a very quick and dirty way) before putting them into metrics makes attention_entropy more stable. It increases over time, but that's more like what you'd expect.

    logits, kvcache, scores, stats = xfmr(xfmr_weights, model_params, next_token, cur_pos, freqs_cis[cur_pos:cur_pos+1], kvcache)
    mask = jnp.arange(scores.shape[-1]) >= cur_pos
    # Reshape mask to (1, 1, 1, 4096) so it broadcasts against scores of shape (1, 32, 1, 4096)
    mask = mask.reshape(1, 1, 1, -1)
    scores = jnp.where(mask, DEFAULT_MASK_VALUE, scores)
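As a rough check on why entropy that rises with position is the expected behaviour once the mask is applied, the plain-Python sketch below masks a row the same way (`index >= cur_pos`, following the snippet above) and uses flat scores so attention is uniform over the visible positions. `MASK_VALUE` and `masked_entropy_bits` are hypothetical stand-ins, not the repo's `DEFAULT_MASK_VALUE` or its metrics code. Uniform attention over n positions has entropy log2(n), so the ceiling grows as the sequence gets longer.

```python
import math

MASK_VALUE = -1e30  # hypothetical large negative stand-in for the mask value

def masked_entropy_bits(scores, cur_pos):
    # Mask every index >= cur_pos, same convention as the snippet above.
    masked = [s if i < cur_pos else MASK_VALUE for i, s in enumerate(scores)]
    m = max(masked)
    exps = [math.exp(s - m) for s in masked]  # masked entries underflow to 0.0
    total = sum(exps)
    probs = [e / total for e in exps]
    # Skip exact zeros so log2 is never evaluated at 0.
    return -sum(p * math.log2(p) for p in probs if p > 0.0)

scores = [0.0] * 16  # flat scores: uniform attention over visible positions
positions = (1, 2, 4, 8)
entropies = [masked_entropy_bits(scores, n) for n in positions]
for n, h in zip(positions, entropies):
    print(f"cur_pos={n}: {h:.3f} bits (log2(n) = {math.log2(n):.3f})")
```

The sketch assumes cur_pos is at least 1; with cur_pos of 0 every entry is masked and the softmax falls back to uniform over the whole row.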

@stillmatic
Author

nice! yeah should def be masking vs doing what i did haha

this is what my runs look like now

image

@LeonEricsson

Are you sure masking works here? I suspect this ruins the interaction strength calculation as it's computed on the raw attention scores as opposed to the probs

    interaction_strength = jnp.mean(jnp.abs(attention_scores), axis=(1, 2, 3))

can confirm this tomorrow.
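A toy plain-Python sketch of the concern, not a confirmation of what happens in the repo: if masked entries carry a large negative value, a plain mean of absolute scores is dominated by the mask. `MASK_VALUE`, `mean_abs` and `mean_abs_valid` are hypothetical names for illustration, and averaging over only the visible positions is just one possible way around it.

```python
MASK_VALUE = -1e30  # hypothetical large negative stand-in for the mask value

def mean_abs(scores):
    # Analogue of mean(abs(attention_scores)) over one row.
    return sum(abs(s) for s in scores) / len(scores)

def mean_abs_valid(scores, cur_pos):
    # Hypothetical alternative: average only over visible positions.
    valid = scores[:cur_pos]
    return sum(abs(s) for s in valid) / len(valid)

raw = [1.5, -0.5, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0]
cur_pos = 3
# Same masking convention as the snippet earlier in the thread.
masked = [s if i < cur_pos else MASK_VALUE for i, s in enumerate(raw)]

naive = mean_abs(masked)                    # dominated by MASK_VALUE entries
restricted = mean_abs_valid(raw, cur_pos)   # reflects only the real scores
print(f"naive: {naive:.3e}, restricted to visible positions: {restricted:.3f}")
```

If this carries over to the real tensors, the entropy metrics would want the masked scores while interaction strength would want the raw scores restricted to visible positions.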


3 participants