
Conversation


@Iamleos commented Sep 30, 2025

No description provided.

@gemini-code-assist

Summary of Changes

Hello @Iamleos, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant feature to the Flash Attention mechanism: the ability to apply custom attention masks. This enhancement provides greater flexibility for advanced attention patterns, such as those required for speculative decoding verification. The changes involve modifying the underlying Pallas kernel, updating the reference implementation, integrating the new parameters into the attention backend, and adding comprehensive test coverage to ensure correctness.

Highlights

  • Custom Mask Support in Flash Attention Kernel: The core Flash Attention kernel (_ragged_paged_attention_kernel) has been enhanced to accept and utilize a custom_mask array and a causal boolean flag. This allows for flexible attention masking beyond standard causal patterns.
  • Reference Implementation and Validation: The reference implementation (ref_ragged_paged_attention) now supports custom_mask and causal parameters, including validation logic to prevent conflicting usage (e.g., using a custom mask when causal masking is enabled).
  • Integration with Pallas Call: The ragged_paged_attention function, which orchestrates the Pallas kernel call, has been updated to pass the new causal and custom_mask parameters, modifying static_argnames, input specifications (in_specs), and scratch memory allocations accordingly.
  • Attention Backend Updates: The FlashAttentionMetadata class and the FlashAttentionBackend now include support for custom_mask, enabling its propagation through the attention pipeline, particularly when ForwardMode.TARGET_VERIFY is active.
  • New Test Coverage for Custom Masks: New test cases, including a create_custom_mask helper function and test_mha_prefill_with_custom_mask, have been added to verify the correctness and functionality of the custom mask feature.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist bot left a comment


Code Review

This pull request introduces support for custom attention masks in the flash attention kernel, which is a valuable feature for non-causal attention patterns like those in speculative decoding. The changes are comprehensive, affecting the reference implementation, the optimized Pallas kernel, and the testing suite. While the core logic for the custom mask in the kernel seems sound, I've identified several issues, including critical bugs in the reference implementation and test setup that will prevent the code from running as intended. There are also some inconsistencies in mask handling between the reference and kernel implementations that should be reconciled. My review provides specific suggestions to address these points and improve the overall quality and correctness of the implementation.

mask_start = cu_kv_lens[i]
mask = custom_mask[mask_start : mask_start + kv_len]
if sliding_window is not None:
    mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)


critical

This line will cause a NameError when causal=False and sliding_window is not None, because q_span and kv_span are only defined within the if causal: block.

Since sliding window attention is a form of causal masking, it probably doesn't make sense to use it with a custom non-causal mask. You should consider raising a ValueError if causal=False and sliding_window is provided, or define q_span and kv_span for the non-causal case if this combination is intended to be supported.
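
A minimal sketch of the first option (the guard placement and variable names are assumptions based on the snippet above, not the PR's actual code):

if not causal and sliding_window is not None:
    raise ValueError(
        "sliding_window assumes causal attention and cannot be combined "
        "with a custom non-causal mask"
    )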

Comment on lines 118 to 121
if custom_mask == None or custom_mask.size() < jnp.cumsum(kv_lens)[-1]:
    raise ValueError(
        f"use custom_mask, custom_mask length must larger than total kv length"
    )


high

There's a bug here: .size is an attribute of a JAX array, not a method, so calling custom_mask.size() will raise a TypeError. It should be custom_mask.size.

Additionally, for style and correctness, use is None and is not None when checking against None rather than == None and != None. The error message could also be tightened: the f prefix is unnecessary since the string has no interpolation, and the grammar should be fixed.

Suggested change
if custom_mask == None or custom_mask.size() < jnp.cumsum(kv_lens)[-1]:
    raise ValueError(
        f"use custom_mask, custom_mask length must larger than total kv length"
    )
if custom_mask is None or custom_mask.size < jnp.cumsum(kv_lens)[-1]:
    raise ValueError(
        "when using custom_mask, its length must be at least the total kv length"
    )
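
For reference, a quick REPL check of the attribute-vs-method distinction (illustration only, not part of the PR):

import jax.numpy as jnp

x = jnp.ones((4, 8))
print(x.size)  # 32: .size is a plain int attribute on JAX arrays
x.size()       # TypeError: 'int' object is not callable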

Comment on lines 255 to 257
spec_info = EagleVerifyInput(
    custom_mask=custom_mask,
)


high

Instantiating EagleVerifyInput with only custom_mask will raise a TypeError because the EagleVerifyInput dataclass has other required fields that do not have default values. The test will fail to run. You need to provide all required arguments to instantiate EagleVerifyInput correctly.
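
A minimal illustration of why this fails, using a stand-in dataclass (the required fields shown here are hypothetical, not EagleVerifyInput's actual ones):

from dataclasses import dataclass
import jax
import jax.numpy as jnp

@dataclass
class VerifyInput:  # stand-in for EagleVerifyInput with hypothetical fields
    draft_tokens: jax.Array  # required: no default value
    positions: jax.Array     # required: no default value
    custom_mask: jax.Array

# Raises TypeError: __init__() missing 2 required positional arguments:
# 'draft_tokens' and 'positions'
VerifyInput(custom_mask=jnp.zeros(0, dtype=bool))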

    cu_q_lens: jax.Array,  # i32[padded_batch_size + 1]
    cu_kv_lens: jax.Array,  # i32[padded_batch_size + 1]
    distribution: jax.Array,  # i32[3]
    custom_mask: jax.Array,  # if causal is True, custom_mask shape is [patten_total_kv_len], else [0]


medium

The comment for custom_mask is confusing and seems to have the logic reversed. The custom mask is used when causal is False (i.e., causal=0), not when it's True. Please correct the comment to avoid confusion.

Suggested change
    custom_mask: jax.Array,  # if causal is True, custom_mask shape is [patten_total_kv_len], else [0]
    custom_mask: jax.Array,  # if causal is False, custom_mask shape is [patten_total_kv_len], else [0]

Comment on lines 1409 to 1411
print(
    f"{cu_q_lens.shape=} {cu_kv_lens.shape=} {bq_sz=} {bkv_p=} {distribution=} {chunk_prefill_size=}"
)


medium

This print statement appears to be for debugging purposes. It should be removed before merging.
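
If the diagnostic is still useful during development, a hedged alternative is to route it through the logging module so it can be silenced outside of debugging (a sketch; the variable names come from the print above):

import logging

logger = logging.getLogger(__name__)
logger.debug(
    "cu_q_lens.shape=%s cu_kv_lens.shape=%s bq_sz=%s bkv_p=%s "
    "distribution=%s chunk_prefill_size=%s",
    cu_q_lens.shape, cu_kv_lens.shape, bq_sz, bkv_p,
    distribution, chunk_prefill_size,
)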

Comment on lines +318 to +322
if len(mode_args) == 5:
    num_heads, head_dim, num_kv_heads, page_size, dtype = mode_args
    causal = True
else:
    num_heads, head_dim, num_kv_heads, page_size, dtype, causal = mode_args


medium

Unpacking mode_args based on its length is fragile. This can easily break if more arguments are added in the future. Consider refactoring run_test to accept causal as a keyword argument, or use a more structured approach for passing these test parameters.
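
A sketch of the keyword-argument approach (the run_test signature shown is an assumption; only the causal handling is the point here):

# Make causal an explicit keyword with a default instead of inferring it
# from the tuple length; existing call sites remain valid.
def run_test(self, num_heads, head_dim, num_kv_heads, page_size, dtype,
             causal=True):
    ...

# Call sites:
self.run_test(8, 128, 8, 16, jnp.float32)                 # causal by default
self.run_test(8, 128, 8, 16, jnp.float32, causal=False)   # custom-mask path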

Comment on lines +658 to +665
def test_mha_decode_with_custom_mask(self):
    pass

def test_gqa_prefill_with_custom_mask(self):
    pass

def test_gqa_decode_with_custom_mask(self):
    pass


medium

These test cases are currently empty placeholders. To ensure complete test coverage for the new custom mask functionality, these tests should be implemented. They should cover both MHA/GQA attention and prefill/decode modes with custom masks.
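
As a hedged starting point, one placeholder might mirror the existing prefill test (the helper name, parameters, and values here are assumptions based on test_mha_prefill_with_custom_mask described above):

def test_gqa_prefill_with_custom_mask(self):
    # GQA: fewer KV heads than query heads; causal=False exercises the
    # custom-mask path. All parameter values are illustrative.
    self.run_test(
        num_heads=8,
        head_dim=128,
        num_kv_heads=2,
        page_size=16,
        dtype=jnp.float32,
        causal=False,
    )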

@Iamleos force-pushed the fix/eagle-kkx branch 7 times, most recently from 6054f6b to f5b4347 on September 30, 2025 at 10:14
@Iamleos force-pushed the fix/eagle-kkx branch 2 times, most recently from 3bd665d to 4d18b12 on September 30, 2025 at 10:22
@Iamleos force-pushed the fix/eagle-kkx branch 2 times, most recently from 115b2da to f79dd69 on September 30, 2025 at 12:44