@henrylhtsang henrylhtsang commented Jan 7, 2026

Note: work done by Claude Code.

Replace deprecated cute DSL function calls with their new equivalents:

  • `cute.make_fragment` → `cute.make_rmem_tensor`
  • `cute.make_fragment_like` → `cute.make_rmem_tensor_like` [Note: this one doesn't contribute to the deprecation warning; see https://github.com/NVIDIA/cutlass/blob/f86feb0aa8a9490a7ab27bc991e36d7b5bf300e3/media/docs/pythonDSL/cute_dsl_api/changelog.rst#L22]
  • `cute.arch.exp2(x)` → `cute.math.exp2(x, fastmath=True)`
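
The replacements above can be sketched as a small compatibility shim (the helper names `make_rmem_tensor_compat` and `exp2_compat` are hypothetical, not part of this PR) that prefers the new CuTe DSL API when it exists and falls back to the deprecated one otherwise:

```python
def make_rmem_tensor_compat(cute, *args, **kwargs):
    # Prefer the new cute.make_rmem_tensor; fall back to the
    # deprecated cute.make_fragment on older CUTLASS DSL versions.
    fn = getattr(cute, "make_rmem_tensor", None) or cute.make_fragment
    return fn(*args, **kwargs)

def exp2_compat(cute, x):
    # Prefer cute.math.exp2(x, fastmath=True), which replaces the
    # deprecated cute.arch.exp2(x).
    math_mod = getattr(cute, "math", None)
    if math_mod is not None and hasattr(math_mod, "exp2"):
        return math_mod.exp2(x, fastmath=True)
    return cute.arch.exp2(x)
```

In the PR itself the call sites are rewritten directly rather than shimmed, since the new names are already available in the pinned CUTLASS version.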

Before:

cute/test_flash_attn.py: 1500 warnings
  /home/henrylhtsang/.conda/envs/flash/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py:60: DeprecationWarning: `make_fragment` is deprecated, use `make_rmem_tensor` instead
    res_or_list = opFunc(*args, **kwargs, loc=loc)

cute/test_flash_attn.py: 9440 warnings
  /home/henrylhtsang/.conda/envs/flash/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py:60: DeprecationWarning: cute.arch.exp2 is deprecated, use cute.math.exp2 with `fastmath=True` instead
    res_or_list = opFunc(*args, **kwargs, loc=loc)

After: the deprecation warnings no longer appear. Tested with:

cd ~/flash-attention/tests/cute
pytest .
cd ~/flash-attention/tests/cute
pytest --collect-only -q 2>/dev/null | grep "::" | sed 's|^cute/||' | shuf | head -100 | xargs pytest -x
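
The random-subset one-liner above collects test node IDs, strips the `cute/` prefix, shuffles, and runs 100 of them fail-fast. Its text-filtering stages can be exercised on synthetic input (an illustration only; the real run needs the flash-attention checkout):

```shell
# Feed fake `pytest --collect-only -q` output through the same filter stages;
# `sort` is appended here only to make the shuffled output deterministic.
printf 'cute/test_a.py::t1\ncute/test_b.py::t2\n== warnings ==\n' \
  | grep '::' | sed 's|^cute/||' | shuf | head -100 | sort
# prints: test_a.py::t1
#         test_b.py::t2
```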

Latest test run after rebase

cd ~/flash-attention/tests/cute
pytest .
================================================================================================== warnings summary ==================================================================================================
cute/test_mask_mod.py: 71 warnings
  /home/henrylhtsang/flash-attention/flash_attn/cute/mask.py:367: DSLOptimizationWarning: This static loop has 128 iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.
    for i in cutlass.range_constexpr(ncol):

cute/test_mask_mod.py::test_mask_mod_ima_partial_block
  /home/henrylhtsang/.conda/envs/flash/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py:1687: UserWarning: flex_attention called without torch.compile() - this will use an unfused implementation that materializes the full scores matrix instead of generating a fused kernel.
  
  SOLUTION: Use torch.compile(flex_attention)(...)
  
  If you want to debug your score_mod/mask_mod, you can set:
  torch.nn.attention.flex_attention._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True
  
  This will allow you to use print statements or breakpoints. Note: This doesn't work with the backwards pass and may produce incorrect results.
    _warn_once(

cute/test_mask_mod.py: 46 warnings
  /home/henrylhtsang/flash-attention/flash_attn/cute/mask.py:510: DSLOptimizationWarning: This static loop has 64 iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.
    for i in cutlass.range_constexpr(ncol):

cute/test_mask_mod.py: 46 warnings
  /home/henrylhtsang/flash-attention/flash_attn/cute/mask.py:533: DSLOptimizationWarning: This static loop has 64 iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.
    for i in cutlass.range_constexpr(ncol):

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================================== 49957 passed, 34769 skipped, 164 warnings in 20350.92s (5:39:10) ==========================================================================

@henrylhtsang henrylhtsang marked this pull request as draft January 7, 2026 23:27
@henrylhtsang henrylhtsang marked this pull request as ready for review January 7, 2026 23:46
Replace deprecated cute DSL function calls with their new equivalents:
- `cute.make_fragment` → `cute.make_rmem_tensor`
- `cute.make_fragment_like` → `cute.make_rmem_tensor_like`
- `cute.arch.exp2(x)` → `cute.math.exp2(x, fastmath=True)`

This fixes ~11k deprecation warnings when running the cute tests.
@henrylhtsang

maybe @jayhshah? I just finished testing.
