@ErwinTerpstra

Proposed changes

This PR adds support for running grouped gemm operations on RDNA3/4 using WMMA instructions. The PR contains:

  • Device struct implementation for grouped gemm on WMMA using GridwiseGemm_wmma_cshuffle_v3
  • Examples to compile a single instance for FP16 and BF16
  • Instances for all combinations of:
    • Input tensor layout
    • FP16 and BF16
    • Full padding and no padding
    • Pipeline V1 and V3
  • Also two specific F8+F16/F16+F8 instances that were included in the original XDL version
  • GTests for all instances
    • Note: I've changed the profile_grouped_gemm_impl function to accept a parameter that makes the test fail if no supported instances are found. Previously the test would silently pass. The parameter is optional and defaults to the old behaviour, so existing tests are not affected
  • Other changes:
    • Some changes to the GridwiseGemm_wmma_cshuffle_v3::Run() interface to allow passing in a custom Block2CTile map, which was necessary to handle the non-uniform dimensions of grouped gemm
    • Add extra CMake options to fully disable compiling certain instances, even if they are supported by the current GPU targets (FORCE_DISABLE_XDL and FORCE_DISABLE_WMMA)
    • Some minor changes required since the grouped gemm implementation copies argument structs:
      • Made the base argument destructor available in device code (required in this specific case because the SplitK functionality works by adjusting the argument struct)
      • Made the element-wise operator fields in the argument struct non-const (to allow using compiler generated move assignment operator)
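The last point can be illustrated with a minimal sketch. The type names below (PassThrough, ArgumentWithConstOp, ArgumentWithMutableOp) are hypothetical stand-ins and not the actual CK argument structs; the sketch only shows the general C++ rule that a const data member suppresses the compiler-generated assignment operators.

```cpp
#include <type_traits>
#include <utility>

// Hypothetical stand-in for an element-wise operator; not a CK type.
struct PassThrough
{
    float scale = 1.0f;
};

// A const data member makes the implicit copy/move assignment operators deleted.
struct ArgumentWithConstOp
{
    const PassThrough c_element_op{};
    int M = 0;
};

// With a non-const member the compiler generates move assignment as usual.
struct ArgumentWithMutableOp
{
    PassThrough c_element_op{};
    int M = 0;
};

static_assert(!std::is_move_assignable<ArgumentWithConstOp>::value,
              "const member blocks the generated move assignment");
static_assert(std::is_move_assignable<ArgumentWithMutableOp>::value,
              "non-const member keeps the generated move assignment");
// Construction is unaffected: only assignment is lost.
static_assert(std::is_move_constructible<ArgumentWithConstOp>::value,
              "const member does not block move construction");

// Usage sketch: overwrite an existing argument with another one.
inline int move_assign_demo()
{
    ArgumentWithMutableOp a{PassThrough{2.0f}, 64};
    ArgumentWithMutableOp b{};
    b = std::move(a); // would not compile for ArgumentWithConstOp
    return b.M;
}
```

The static_asserts document the difference at compile time; move_assign_demo() shows the kind of in-place overwrite that needs the assignment operator.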

Other algorithm variants for grouped gemm will be added as follow-up PRs.
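To illustrate why the non-uniform dimensions of grouped gemm call for a custom block-to-tile mapping, here is a rough host-side sketch. It is not the Block2CTile map used in this PR: GroupDims, TileCoord and map_block_to_tile are hypothetical names, and the row-major tile order inside a group is an assumption. The idea is that each group contributes a different number of output tiles, so a flat block id first has to be resolved to a group and only then to a tile inside that group.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical per-group problem size; not a CK type.
struct GroupDims
{
    int M;
    int N;
};

// Group-local output tile that a flat block id maps to.
struct TileCoord
{
    int group;  // group index, or -1 if the block id is out of range
    int m_tile; // tile index along M inside the group
    int n_tile; // tile index along N inside the group
};

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Walks the groups in order, accumulating each group's tile count until the
// block id falls inside one group, then splits the local index row-major.
inline TileCoord map_block_to_tile(const std::vector<GroupDims>& groups,
                                   int m_per_block,
                                   int n_per_block,
                                   int block_id)
{
    if(block_id < 0)
        return {-1, -1, -1};

    int first_block = 0; // flat id of the first tile of the current group
    for(std::size_t g = 0; g < groups.size(); ++g)
    {
        const int m_tiles = ceil_div(groups[g].M, m_per_block);
        const int n_tiles = ceil_div(groups[g].N, n_per_block);
        const int count   = m_tiles * n_tiles;
        if(block_id < first_block + count)
        {
            const int local = block_id - first_block;
            return {static_cast<int>(g), local / n_tiles, local % n_tiles};
        }
        first_block += count;
    }
    return {-1, -1, -1}; // block id beyond the last group
}
```

With 128x128 tiles and groups of size 256x128, 64x64 and 100x300, the groups own 2, 1 and 3 tiles respectively, so flat block 4 resolves to the second tile of the third group.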

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which helps the maintainers understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

@illsilin illsilin self-assigned this Nov 20, 2025
  BaseArgument& operator=(const BaseArgument&) = default;

- virtual ~BaseArgument() {}
+ virtual __host__ __device__ ~BaseArgument() {}
Collaborator

Can we use CK_TILE_HOST_DEVICE macro?

Author

This is "old" CK, not CK Tile. I'd imagine we don't want to include the header for CK tile here?

I looked for a similar macro in old CK, but there doesn't seem to be one. Other code also defines the attributes inline.
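For reference, a minimal sketch of what defining the attributes inline looks like. BaseArgumentSketch, GroupArgumentSketch and adjusted_split_k are hypothetical names, not CK code. The empty fallback definitions of `__host__` and `__device__` are an assumption added only so the sketch also builds with a host-only compiler; under a HIP or CUDA compiler the real attributes are used.

```cpp
// Fallback for host-only compilers (an assumption of this sketch, not CK code).
#if !defined(__HIPCC__) && !defined(__CUDACC__)
#define __host__
#define __device__
#endif

// Base class whose destructor is callable from both host and device code.
struct BaseArgumentSketch
{
    BaseArgumentSketch()                                     = default;
    BaseArgumentSketch(const BaseArgumentSketch&)            = default;
    BaseArgumentSketch& operator=(const BaseArgumentSketch&) = default;

    virtual __host__ __device__ ~BaseArgumentSketch() {}
};

// Hypothetical derived argument carrying a split-K factor.
struct GroupArgumentSketch : BaseArgumentSketch
{
    int split_k = 1;
};

// Takes the argument by value and adjusts the copy. Destroying that copy at
// the end of the function is what requires a destructor usable in this context.
inline __host__ __device__ int adjusted_split_k(GroupArgumentSketch arg, int k_batch)
{
    arg.split_k = k_batch;
    return arg.split_k;
}
```

The caller's argument is left untouched; only the by-value copy is modified and then destroyed.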

Collaborator

I see! Thank you for the response!

@illsilin
Collaborator

Looks good.
Just one small request. Aviral is currently fixing the copyright headers in all files to the new date-less format:

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

Could you please add these to every file that's missing copyright headers and replace the old ones with this new one?

@ErwinTerpstra
Author

ErwinTerpstra commented Nov 24, 2025

Done, added it to all new/moved files. Will add this in the future as well.

3 participants