
Conversation

@AdvancedCompiler
Contributor

PR Category

Operator

Type of Change

New Feature

Description

add baddbmm

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

baddbmm算子性能测试情况(1).xlsx (baddbmm operator performance test results)

@gemini-code-assist
Contributor

Summary of Changes

Hello @AdvancedCompiler, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the baddbmm operator into the system, enhancing its capabilities for batched matrix operations with bias. The implementation leverages Triton for high-performance GPU execution and includes full autograd support, ensuring it can be used effectively in deep learning models. The changes also cover performance benchmarking and thorough testing to validate the new operator's correctness and efficiency.

Highlights

  • New Operator: baddbmm: Introduced the baddbmm operator, which performs a batched matrix multiplication with an added bias, implemented using a custom Triton kernel for optimized performance (reference semantics are sketched after this list).
  • Autograd Support: Implemented a torch.autograd.Function for baddbmm, providing a custom backward pass for efficient gradient computation.
  • Performance Benchmarking and Tuning: Added baddbmm to the BLAS performance benchmark suite, including FLOPs calculation, input generation, and specific tuning configurations for NVIDIA GPUs.
  • Comprehensive Testing: Included new accuracy tests for both the forward and backward passes of the baddbmm operator to ensure correctness.
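
For reviewers less familiar with the operator, the semantics match torch.baddbmm: out = beta * input + alpha * (batch1 @ batch2), computed per batch. A minimal reference example in stock PyTorch (not the Triton kernel from this PR) follows; shapes and scalars are illustrative only.

import torch

# Reference semantics of baddbmm in stock PyTorch, shown only for orientation;
# this PR provides a Triton-backed implementation of the same contract.
bias = torch.randn(4, 8, 16)   # broadcastable to the (b, m, n) result
mat1 = torch.randn(4, 8, 32)   # (b, m, k)
mat2 = torch.randn(4, 32, 16)  # (b, k, n)

out = torch.baddbmm(bias, mat1, mat2, beta=0.5, alpha=2.0)
expected = 0.5 * bias + 2.0 * torch.bmm(mat1, mat2)
assert torch.allclose(out, expected, atol=1e-5)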

@gemini-code-assist bot left a comment

Code Review

This pull request introduces the baddbmm operator, complete with a Triton kernel, forward and backward passes, benchmarks, and accuracy tests. The overall implementation is solid, but I've identified a critical bug in the backward pass concerning gradient calculation for the bias tensor. Additionally, there are some issues in the benchmark code and opportunities for code simplification and improved test coverage. My review provides specific suggestions to address these points.

Comment on lines +187 to +191
def compute_bias_grad(d_output, beta):
    batch, M, N = d_output.shape
    d_bias = torch.zeros((M, N), device=d_output.device, dtype=d_output.dtype)
    d_bias = mul(d_output, beta)
    return d_bias

critical

The compute_bias_grad function is incorrect as it doesn't handle gradient accumulation for broadcasted bias tensors. It should sum the gradient over the broadcasted dimensions to match the original bias shape. Currently, it returns a gradient with the same shape as grad_output, which is only correct if bias was not broadcasted.

This is a critical correctness issue. The existing backward test does not catch this because it uses a non-broadcasted bias tensor.

To fix this, you first need to pass the original bias tensor from BaddbmmFunction.backward to this function (at line 178):

# in BaddbmmFunction.backward
grad_bias = compute_bias_grad(grad_output, ctx.beta, bias)

Then, update compute_bias_grad to handle the gradient reduction. Also, line 189 is dead code and can be removed.

def compute_bias_grad(d_output, beta, bias):
    grad_bias = mul(d_output, beta)
    if grad_bias.shape != bias.shape:
        # Sum over broadcasted dimensions
        while grad_bias.dim() > bias.dim():
            grad_bias = grad_bias.sum(dim=0)
        for i in range(bias.dim()):
            if bias.shape[i] == 1 and grad_bias.shape[i] > 1:
                grad_bias = grad_bias.sum(dim=i, keepdim=True)
    return grad_bias.view(bias.shape)
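
As a quick sanity check of the reduction logic above against PyTorch's own broadcasting rules, the following standalone sketch substitutes torch.mul for the project's mul kernel and compares against autograd for a broadcasted bias; the shapes and beta value are arbitrary.

import torch

def compute_bias_grad_ref(d_output, beta, bias):
    # Same reduction logic as the suggestion above, with torch.mul standing in
    # for the project's mul kernel.
    grad_bias = torch.mul(d_output, beta)
    while grad_bias.dim() > bias.dim():
        grad_bias = grad_bias.sum(dim=0)
    for i in range(bias.dim()):
        if bias.shape[i] == 1 and grad_bias.shape[i] > 1:
            grad_bias = grad_bias.sum(dim=i, keepdim=True)
    return grad_bias.view(bias.shape)

bias = torch.randn(8, 1, requires_grad=True)  # broadcast over batch and N
mat1, mat2 = torch.randn(4, 8, 3), torch.randn(4, 3, 5)
out = torch.baddbmm(bias, mat1, mat2, beta=0.7)
out.backward(torch.ones_like(out))
grad_ref = compute_bias_grad_ref(torch.ones_like(out), 0.7, bias)
assert torch.allclose(bias.grad, grad_ref, atol=1e-5)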

Comment on lines +78 to +86
# shape(b,m,n)(b,n,p)
# total_flops bxmxpx(2n+1)
elif self.op_name == "baddbmm":
    total_flops = (
        args[0].shape[0]
        * args[0].shape[1]
        * args[1].shape[2]
        * (args[1].shape[1] * 2 + 1)
    )

high

The TFLOPS calculation for baddbmm is incorrect. It uses the wrong input tensors for its calculation (args[0] is bias, but it should be using args[1] and args[2] which are mat1 and mat2). The formula itself is also incorrect.

For baddbmm with mat1 of shape (b, m, k) and mat2 of shape (b, k, n), the total floating point operations should be b * m * n * (2 * k + 1). The comment should also be updated to reflect the correct shapes and formula.

Suggested change
# shape(b,m,n)(b,n,p)
# total_flops bxmxpx(2n+1)
elif self.op_name == "baddbmm":
    total_flops = (
        args[0].shape[0]
        * args[0].shape[1]
        * args[1].shape[2]
        * (args[1].shape[1] * 2 + 1)
    )
# mat1: (b, m, k), mat2: (b, k, n)
# total_flops = b * m * n * (2 * k + 1)
elif self.op_name == "baddbmm":
    total_flops = (
        args[1].shape[0]  # b
        * args[1].shape[1]  # m
        * args[2].shape[2]  # n
        * (args[1].shape[2] * 2 + 1)  # 2k+1
    )

@pytest.mark.parametrize("M, N, K", MNK_SHAPES)
@pytest.mark.parametrize("scalar", SCALARS)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_baddbmm_backward(M, N, K, scalar, dtype):

high

The backward test for baddbmm only covers the case where the bias tensor is not broadcasted. This is insufficient to catch the bug in compute_bias_grad. Please add test cases for different bias shapes that require broadcasting, such as (M, N), (N,), and (1,), to ensure the backward pass is correct for all supported scenarios. You can achieve this by parameterizing the test over different bias shapes.
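
A possible shape-parameterized sketch follows; it uses eager PyTorch as the reference and a hypothetical test name, since the real test would compare the project's baddbmm against torch.baddbmm over the same bias shapes.

import pytest
import torch

# Hypothetical extension of the existing backward test: exercise bias shapes
# that require broadcasting, so an unreduced bias gradient would be caught.
BIAS_SHAPE_FNS = [
    lambda B, M, N: (B, M, N),  # no broadcasting (current coverage)
    lambda B, M, N: (M, N),     # broadcast over batch
    lambda B, M, N: (N,),       # broadcast over batch and M
    lambda B, M, N: (1,),       # full broadcast
]

@pytest.mark.parametrize("bias_shape_fn", BIAS_SHAPE_FNS)
def test_baddbmm_backward_bias_broadcast(bias_shape_fn):
    B, M, N, K = 4, 8, 16, 32
    bias = torch.randn(bias_shape_fn(B, M, N), requires_grad=True)
    mat1 = torch.randn(B, M, K, requires_grad=True)
    mat2 = torch.randn(B, K, N, requires_grad=True)
    out = torch.baddbmm(bias, mat1, mat2, alpha=2.0, beta=0.5)
    out.sum().backward()
    # The real test would compare these reference gradients against the
    # Triton-backed implementation; here we only assert the expected shape.
    assert bias.grad.shape == bias.shape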

Comment on lines +49 to +52
filtered_model_shaps = []
for shape in model_shaps:
    if shape not in skip_shapes:
        filtered_model_shaps.append(shape)

medium

This loop can be expressed more concisely and efficiently using a list comprehension.

Suggested change
filtered_model_shaps = []
for shape in model_shaps:
    if shape not in skip_shapes:
        filtered_model_shaps.append(shape)
filtered_model_shaps = [shape for shape in model_shaps if shape not in skip_shapes]

Comment on lines +37 to +47
def baddbmm_heur_divisible_m(args):
    return args["M"] % args["TILE_M"] == 0


def baddbmm_heur_divisible_n(args):
    return args["N"] % args["TILE_N"] == 0


def baddbmm_heur_divisible_k(args):
    return args["K"] % args["TILE_K"] == 0


medium

These new heuristic functions are identical to the existing bmm_heur_divisible_* functions. To avoid code duplication, you can remove these functions and reuse the bmm versions in the HEURISTICS_CONFIGS map for baddbmm.

Comment on lines +261 to +265
"baddbmm": {
"DIVISIBLE_M": baddbmm_heur_divisible_m,
"DIVISIBLE_N": baddbmm_heur_divisible_n,
"DIVISIBLE_K": baddbmm_heur_divisible_k,
},

medium

To reduce code duplication, you can reuse the existing bmm_heur_divisible_* functions here, as they are identical to the newly added baddbmm_heur_divisible_* functions. This change should be made after removing the new baddbmm_heur_divisible_* functions.

Suggested change
"baddbmm": {
"DIVISIBLE_M": baddbmm_heur_divisible_m,
"DIVISIBLE_N": baddbmm_heur_divisible_n,
"DIVISIBLE_K": baddbmm_heur_divisible_k,
},
"baddbmm": {
"DIVISIBLE_M": bmm_heur_divisible_m,
"DIVISIBLE_N": bmm_heur_divisible_n,
"DIVISIBLE_K": bmm_heur_divisible_k,
},

@triton.autotune(
    configs=runtime.get_tuned_config("baddbmm"),
    key=["M", "N", "K"],
)
Collaborator

Please introduce a strategy parameter here to avoid triggering autotuning too frequently.
You can refer to how the strategy parameter is used in mm.py.


model_shaps = model_shapes()
return large_k_shapes + model_shaps
if self.op_name == "baddbmm":
Collaborator

You can define a custom BaddbmmBenchmark class, add your benchmark shapes in core_shapes.yaml, and override the set_more_shapes method to customize the shape setting.
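
A rough sketch of that structure follows; the Benchmark base class, its import path, and the set_more_shapes hook are assumptions based on the existing benchmark code and should be checked against it (shape values are placeholders).

# Hypothetical sketch only: base class location and hook name assumed.
from benchmark.performance_utils import Benchmark

class BaddbmmBenchmark(Benchmark):
    """Benchmark baddbmm with its own shape list instead of the generic BLAS shapes."""

    def set_more_shapes(self):
        # Shapes as (batch, M, N, K); placeholder values, real ones would live
        # alongside the entries added to core_shapes.yaml.
        large_k_shapes = [(2, 4096, 4096, 16384)]
        model_shapes_ = [(4, 1024, 1024, 1024), (8, 512, 512, 2048)]
        return large_k_shapes + model_shapes_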

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you all sign our Contributor License Agreement before we can accept your contribution.
1 out of 2 committers have signed the CLA.

✅ AdvancedCompiler
❌ “ph0375”


"ph0375" does not appear to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account.
If you have already signed the CLA but the status is still pending, let us recheck it.
