[Advanced Compiler] Baddbmm #1051
base: master
Conversation
Summary of Changes: Hello @AdvancedCompiler, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request integrates the baddbmm operator.
Code Review
This pull request introduces the baddbmm operator, complete with a Triton kernel, forward and backward passes, benchmarks, and accuracy tests. The overall implementation is solid, but I've identified a critical bug in the backward pass concerning gradient calculation for the bias tensor. Additionally, there are some issues in the benchmark code and opportunities for code simplification and improved test coverage. My review provides specific suggestions to address these points.
def compute_bias_grad(d_output, beta):
    batch, M, N = d_output.shape
    d_bias = torch.zeros((M, N), device=d_output.device, dtype=d_output.dtype)
    d_bias = mul(d_output, beta)
    return d_bias
The compute_bias_grad function is incorrect as it doesn't handle gradient accumulation for broadcasted bias tensors. It should sum the gradient over the broadcasted dimensions to match the original bias shape. Currently, it returns a gradient with the same shape as grad_output, which is only correct if bias was not broadcasted.
This is a critical correctness issue. The existing backward test does not catch this because it uses a non-broadcasted bias tensor.
To fix this, you first need to pass the original bias tensor from BaddbmmFunction.backward to this function (at line 178):
# in BaddbmmFunction.backward
grad_bias = compute_bias_grad(grad_output, ctx.beta, bias)

Then, update compute_bias_grad to handle the gradient reduction. Also, line 189 is dead code and can be removed.
def compute_bias_grad(d_output, beta, bias):
    grad_bias = mul(d_output, beta)
    if grad_bias.shape != bias.shape:
        # Sum over broadcasted dimensions
        while grad_bias.dim() > bias.dim():
            grad_bias = grad_bias.sum(dim=0)
        for i in range(bias.dim()):
            if bias.shape[i] == 1 and grad_bias.shape[i] > 1:
                grad_bias = grad_bias.sum(dim=i, keepdim=True)
    return grad_bias.view(bias.shape)

# shape(b,m,n)(b,n,p)
# total_flops bxmxpx(2n+1)
elif self.op_name == "baddbmm":
    total_flops = (
        args[0].shape[0]
        * args[0].shape[1]
        * args[1].shape[2]
        * (args[1].shape[1] * 2 + 1)
    )
The TFLOPS calculation for baddbmm is incorrect. It uses the wrong input tensors for its calculation (args[0] is bias, but it should be using args[1] and args[2] which are mat1 and mat2). The formula itself is also incorrect.
For baddbmm with mat1 of shape (b, m, k) and mat2 of shape (b, k, n), the total floating point operations should be b * m * n * (2 * k + 1). The comment should also be updated to reflect the correct shapes and formula.
Suggested change:

# mat1: (b, m, k), mat2: (b, k, n)
# total_flops = b * m * n * (2 * k + 1)
elif self.op_name == "baddbmm":
    total_flops = (
        args[1].shape[0]  # b
        * args[1].shape[1]  # m
        * args[2].shape[2]  # n
        * (args[1].shape[2] * 2 + 1)  # 2k+1
    )
@pytest.mark.parametrize("M, N, K", MNK_SHAPES)
@pytest.mark.parametrize("scalar", SCALARS)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_baddbmm_backward(M, N, K, scalar, dtype):
The backward test for baddbmm only covers the case where the bias tensor is not broadcasted. This is insufficient to catch the bug in compute_bias_grad. Please add test cases for different bias shapes that require broadcasting, such as (M, N), (N,), and (1,), to ensure the backward pass is correct for all supported scenarios. You can achieve this by parameterizing the test over different bias shapes.
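As a rough illustration (plain PyTorch rather than the project's test harness; the shapes and beta/alpha values below are arbitrary), the behaviour the added cases should check is that the bias gradient always comes back in the bias's own shape:

```python
import pytest
import torch


@pytest.mark.parametrize(
    "bias_shape_fn",
    [
        lambda b, m, n: (b, m, n),  # no broadcasting
        lambda b, m, n: (m, n),     # broadcast over the batch dimension
        lambda b, m, n: (n,),       # broadcast over batch and rows
        lambda b, m, n: (1,),       # fully broadcast
    ],
)
def test_baddbmm_bias_grad_shape(bias_shape_fn):
    b, m, k, n = 4, 32, 16, 64
    bias = torch.randn(bias_shape_fn(b, m, n), requires_grad=True)
    mat1 = torch.randn(b, m, k)
    mat2 = torch.randn(b, k, n)
    out = torch.baddbmm(bias, mat1, mat2, beta=0.5, alpha=2.0)
    out.backward(torch.ones_like(out))
    # Autograd reduces the gradient over broadcast dimensions, so grad_bias
    # always has the bias's own shape; the Triton backward should match this.
    assert bias.grad.shape == bias.shape
```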
filtered_model_shaps = []
for shape in model_shaps:
    if shape not in skip_shapes:
        filtered_model_shaps.append(shape)
This loop can be expressed more concisely and efficiently using a list comprehension.
Suggested change:

filtered_model_shaps = [shape for shape in model_shaps if shape not in skip_shapes]
def baddbmm_heur_divisible_m(args):
    return args["M"] % args["TILE_M"] == 0


def baddbmm_heur_divisible_n(args):
    return args["N"] % args["TILE_N"] == 0


def baddbmm_heur_divisible_k(args):
    return args["K"] % args["TILE_K"] == 0
| "baddbmm": { | ||
| "DIVISIBLE_M": baddbmm_heur_divisible_m, | ||
| "DIVISIBLE_N": baddbmm_heur_divisible_n, | ||
| "DIVISIBLE_K": baddbmm_heur_divisible_k, | ||
| }, |
To reduce code duplication, you can reuse the existing bmm_heur_divisible_* functions here, as they are identical to the newly added baddbmm_heur_divisible_* functions. This change should be made after removing the new baddbmm_heur_divisible_* functions.
| "baddbmm": { | |
| "DIVISIBLE_M": baddbmm_heur_divisible_m, | |
| "DIVISIBLE_N": baddbmm_heur_divisible_n, | |
| "DIVISIBLE_K": baddbmm_heur_divisible_k, | |
| }, | |
| "baddbmm": { | |
| "DIVISIBLE_M": bmm_heur_divisible_m, | |
| "DIVISIBLE_N": bmm_heur_divisible_n, | |
| "DIVISIBLE_K": bmm_heur_divisible_k, | |
| }, |
@triton.autotune(
    configs=runtime.get_tuned_config("baddbmm"),
    key=["M", "N", "K"],
)
Please introduce a strategy parameter here to avoid triggering autotuning too frequently.
You can refer to how the strategy parameter is used in mm.py.
model_shaps = model_shapes()
return large_k_shapes + model_shaps
if self.op_name == "baddbmm":
You can define a custom BaddbmmBenchmark class, add your benchmark shapes in core_shapes.yaml, and override the set_more_shapes method to customize the shape setting.
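For example, a minimal sketch of that structure (the base-class name and import path below are assumptions; mirror whatever the existing benchmark classes in this suite actually use):

```python
# Hypothetical sketch: the base class and import path are assumptions; follow
# the existing benchmark classes in this suite (e.g. the one used for bmm).
from benchmark.performance_utils import GenericBenchmark  # assumed location


class BaddbmmBenchmark(GenericBenchmark):
    """baddbmm benchmark with its own shape handling."""

    def set_more_shapes(self):
        # Extra (batch, M, K, N) shapes on top of those listed in core_shapes.yaml;
        # unsupported shapes can be dropped here instead of being filtered inline.
        return [
            (4, 1024, 1024, 1024),
            (8, 4096, 4096, 4096),
        ]
```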
“ph0375” seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account. You have signed the CLA already but the status is still pending? Let us recheck it.
PR Category
Operator
Type of Change
New Feature
Description
Add the baddbmm operator.
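For reference, a minimal sketch (plain PyTorch, arbitrary shapes and scalars) of the semantics the operator follows, i.e. out = beta * input + alpha * bmm(batch1, batch2):

```python
import torch

b, m, k, n = 4, 32, 16, 64
bias = torch.randn(b, m, n)
mat1 = torch.randn(b, m, k)
mat2 = torch.randn(b, k, n)

# Reference semantics: out = beta * bias + alpha * (mat1 @ mat2), per batch.
ref = 0.5 * bias + 2.0 * torch.bmm(mat1, mat2)
out = torch.baddbmm(bias, mat1, mat2, beta=0.5, alpha=2.0)
assert torch.allclose(out, ref, atol=1e-5)
```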
Issue
Progress
Performance
baddbmm算子性能测试情况(1).xlsx (baddbmm operator performance test results)