
@chichun-charlie-liu
Collaborator

Description of the change

In addition to the CUTLASS kernel, this PR adds a new Triton matmul kernel that supports FP32, FP16, BF16, FP8, and INT8. The Triton kernel is more flexible and easier to modify than CUTLASS. Although its INT8 performance is only on par with FP16 torch.matmul (CUTLASS is ~2x faster), Triton provides a valuable path for studying hardware behavior and running detailed simulations. For example, we can truncate the accumulator more efficiently than with a serial torch.matmul implementation, and with much cleaner code than an equivalent CUTLASS implementation.
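To make the accumulator-truncation idea concrete, here is a minimal, dependency-free sketch of one possible scheme. It is not the Triton kernel from this PR. The function names, the chunked accumulation along the reduction dimension, and the choice to zero the lowest bits after each chunk are all assumptions made for illustration.

```python
def truncate_lsb(value, num_bits):
    """Zero the lowest `num_bits` bits of an integer.

    Python's >> is an arithmetic shift, so negative values round toward
    negative infinity. This is an assumed truncation rule, not the PR's.
    """
    return (value >> num_bits) << num_bits


def dot_with_truncated_accumulator(a, b, chunk_size, num_bits):
    """Integer dot product with the accumulator truncated after every chunk.

    Hypothetical reference model: each chunk of the reduction dimension is
    summed exactly, added to the running accumulator, and the accumulator
    then loses its lowest `num_bits` bits.
    """
    if len(a) != len(b):
        raise ValueError("a and b must have the same length")
    acc = 0
    for start in range(0, len(a), chunk_size):
        stop = start + chunk_size
        partial = sum(x * y for x, y in zip(a[start:stop], b[start:stop]))
        acc = truncate_lsb(acc + partial, num_bits)
    return acc


exact = dot_with_truncated_accumulator([3, 5, 7, 8], [1, 1, 1, 1], chunk_size=2, num_bits=0)
truncated = dot_with_truncated_accumulator([3, 5, 7, 8], [1, 1, 1, 1], chunk_size=2, num_bits=2)
```

With `num_bits=0` the function reduces to an ordinary dot product, which gives a simple baseline for measuring how much error a given truncation setting introduces.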

Related issue number

There is a compatibility issue between our existing CUTLASS kernel and torch.compile(..., mode="reduce-overhead"), which is blocking us from moving from PyTorch 2.3 to PyTorch 2.4. With the new Triton kernel, there is at least an alternative run path for the entire INT8 QAT example (including the lowering part) on PyTorch 2.4.

How to verify the PR

The INT8 QAT example has a lowering option that previously supported only CUTLASS. With the newly added Triton kernel, there is now a second option for running the quantized model on a real INT engine.

Note that in this model lowering experiment, the quantized model passes real INT8 tensors to the matmul operator. Therefore, the kernel needs to be able to:

  1. accept INT8 tensors (torch.matmul does not accept INT8 inputs yet),
  2. use the INT engine as efficiently as it can,
  3. return an INT32 tensor from the INT8 matmul; dequantization is performed afterward.
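The three requirements above can be illustrated with a small reference model in plain Python. This is only a sketch of the data contract (INT8 in, INT32 out, dequantize afterward), not the Triton kernel itself. The helper names and the symmetric per-tensor scaling are assumptions made for illustration.

```python
def quantize_int8(matrix, scale):
    """Symmetric per-tensor quantization, clamped to the INT8 range [-128, 127]."""
    return [[max(-128, min(127, round(v / scale))) for v in row] for row in matrix]


def int8_matmul_int32(a_q, b_q):
    """Multiply two INT8 matrices and accumulate in integers.

    Python integers are unbounded, so the INT32 range that real hardware
    would use is checked explicitly.
    """
    rows, inner, cols = len(a_q), len(b_q), len(b_q[0])
    out = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            acc = 0
            for k in range(inner):
                acc += a_q[i][k] * b_q[k][j]
            if not -2**31 <= acc < 2**31:
                raise OverflowError("accumulator exceeds the INT32 range")
            out[i][j] = acc
    return out


def dequantize(acc, scale_a, scale_b):
    """Map the INT32 accumulator back to real values after the matmul."""
    return [[v * scale_a * scale_b for v in row] for row in acc]


# Hypothetical inputs and scales, chosen so that quantization is exact.
a = [[1.0, -2.0], [0.5, 4.0]]
b = [[1.0, 0.0], [0.0, 1.0]]
a_q = quantize_int8(a, scale=0.5)
b_q = quantize_int8(b, scale=1.0)
acc_int32 = int8_matmul_int32(a_q, b_q)
result = dequantize(acc_int32, 0.5, 1.0)
```

Keeping dequantization as a separate step mirrors the description above: the matmul kernel only has to produce the integer accumulator, and the scales are applied afterward.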

Was the PR tested

  • I have added >=1 unit test(s) for every new method I have added.
  • I have ensured all unit tests pass

Collaborator

@tharapalanivel tharapalanivel left a comment

LGTM, thanks @chichun-charlie-liu!

@chichun-charlie-liu chichun-charlie-liu merged commit 9301123 into foundation-model-stack:main Jan 30, 2025
11 checks passed
@chichun-charlie-liu chichun-charlie-liu deleted the triton-kernel branch January 30, 2025 21:33