Conversation

@botbigeyes (Contributor) commented Oct 30, 2025

PR Category
Operator

Type of Change
Bug Fix

Description
Fixes unsupported bool-type indices in index_put and index_put_: a single boolean mask index is now converted to coordinate indices via torch.where, and the values tensor is expanded or reshaped to match the number of selected elements.
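
A minimal before/after illustration (shapes and device are illustrative; assumes the usual flag_gems.enable() global patch):

    import torch
    import flag_gems

    flag_gems.enable()

    inp = torch.zeros(4, 4, device="cuda")
    mask = inp > -1.0                        # a torch.bool index selecting all 16 elements
    values = torch.ones(16, device="cuda")   # one value per True entry

    # A bool tensor as the sole index previously hit the unsupported path;
    # it is now converted to coordinate indices internally.
    out = torch.index_put(inp, (mask,), values)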

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance
accumulate=False:

Operator: index_put  Performance Test (dtype=torch.float16, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.370992            0.374848               0.990          [torch.Size([268435456]), [torch.Size([65536])], torch.Size([65536]), False]
SUCCESS               0.009920            0.009504               1.044          [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])], torch.Size([8]), False]
SUCCESS               0.012768            0.009664               1.321          [torch.Size([32, 32]), [torch.Size([8]), torch.Size([2, 8])], torch.Size([8]), False]
SUCCESS               0.010592            0.009888               1.071          [torch.Size([32, 32]), [torch.Size([2, 8])], torch.Size([32]), False]
SUCCESS               0.011936            0.011296               1.057          [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])], torch.Size([64]), False]
SUCCESS               0.016768            0.011520               1.456          [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([4, 64])], torch.Size([64]), False]
SUCCESS               0.013280            0.012208               1.088          [torch.Size([1024, 1024]), [torch.Size([4, 64])], torch.Size([1024]), False]
SUCCESS               0.188864            0.186736               1.011          [torch.Size([512, 512, 512]), [torch.Size([128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), False]
SUCCESS               0.195584            0.186304               1.050          [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), False]
SUCCESS               0.453216            0.228400               1.984          [torch.Size([512, 512, 512]), [torch.Size([2, 128])], torch.Size([512]), False]
Operator: index_put  Performance Test (dtype=torch.float32, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.719808            0.724256               0.994          [torch.Size([268435456]), [torch.Size([65536])], torch.Size([65536]), False]
SUCCESS               0.009984            0.009536               1.047          [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])], torch.Size([8]), False]
SUCCESS               0.012544            0.009440               1.329          [torch.Size([32, 32]), [torch.Size([8]), torch.Size([2, 8])], torch.Size([8]), False]
SUCCESS               0.010976            0.010720               1.024          [torch.Size([32, 32]), [torch.Size([2, 8])], torch.Size([32]), False]
SUCCESS               0.013216            0.012800               1.033          [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])], torch.Size([64]), False]
SUCCESS               0.017152            0.012864               1.333          [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([4, 64])], torch.Size([64]), False]
SUCCESS               0.014560            0.013728               1.061          [torch.Size([1024, 1024]), [torch.Size([4, 64])], torch.Size([1024]), False]
SUCCESS               0.366272            0.364512               1.005          [torch.Size([512, 512, 512]), [torch.Size([128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), False]
SUCCESS               0.373088            0.363904               1.025          [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), False]
SUCCESS               0.625728            0.447040               1.400          [torch.Size([512, 512, 512]), [torch.Size([2, 128])], torch.Size([512]), False]
Operator: index_put  Performance Test (dtype=torch.bfloat16, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.369008            0.374592               0.985          [torch.Size([268435456]), [torch.Size([65536])], torch.Size([65536]), False]
SUCCESS               0.009952            0.009536               1.044          [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])], torch.Size([8]), False]
SUCCESS               0.013888            0.009408               1.476          [torch.Size([32, 32]), [torch.Size([8]), torch.Size([2, 8])], torch.Size([8]), False]
SUCCESS               0.010912            0.009648               1.131          [torch.Size([32, 32]), [torch.Size([2, 8])], torch.Size([32]), False]
SUCCESS               0.011808            0.011424               1.034          [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])], torch.Size([64]), False]
SUCCESS               0.015712            0.011680               1.345          [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([4, 64])], torch.Size([64]), False]
SUCCESS               0.013312            0.012384               1.075          [torch.Size([1024, 1024]), [torch.Size([4, 64])], torch.Size([1024]), False]
SUCCESS               0.188736            0.186624               1.011          [torch.Size([512, 512, 512]), [torch.Size([128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), False]
SUCCESS               0.195600            0.186336               1.050          [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), False]
SUCCESS               0.453248            0.227648               1.991          [torch.Size([512, 512, 512]), [torch.Size([2, 128])], torch.Size([512]), False]

accumulate=False and bool:

Operator: index_put  Performance Test (dtype=torch.bool, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.192096            0.196032               0.980          [torch.Size([268435456]), [torch.Size([65536])], torch.Size([65536]), False]
SUCCESS               0.009600            0.009392               1.022          [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])], torch.Size([8]), False]
SUCCESS               0.012384            0.009376               1.321          [torch.Size([32, 32]), [torch.Size([8]), torch.Size([2, 8])], torch.Size([8]), False]
SUCCESS               0.010560            0.009856               1.071          [torch.Size([32, 32]), [torch.Size([2, 8])], torch.Size([32]), False]
SUCCESS               0.010720            0.010464               1.024          [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])], torch.Size([64]), False]
SUCCESS               0.014688            0.010752               1.366          [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([4, 64])], torch.Size([64]), False]
SUCCESS               0.012448            0.010944               1.137          [torch.Size([1024, 1024]), [torch.Size([4, 64])], torch.Size([1024]), False]
SUCCESS               0.100544            0.098816               1.017          [torch.Size([512, 512, 512]), [torch.Size([128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), False]
SUCCESS               0.107808            0.098144               1.098          [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), False]
SUCCESS               0.361824            0.125248               2.889          [torch.Size([512, 512, 512]), [torch.Size([2, 128])], torch.Size([512]), False]

accumulate=True:

Operator: index_put  Performance Test (dtype=torch.float16, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.498336            0.376224               1.325          [torch.Size([268435456]), [torch.Size([65536])], torch.Size([65536]), True]
SUCCESS               0.074368            0.010144               7.331          [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])], torch.Size([8]), True]
SUCCESS               0.080256            0.012000               6.688          [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])], torch.Size([64]), True]
SUCCESS               0.286240            0.187104               1.530          [torch.Size([512, 512, 512]), [torch.Size([128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), True]
SUCCESS               0.286640            0.187136               1.532          [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([2, 128]), torch.Size([2, 128])], torch.Size([2, 128]), True]
Operator: index_put  Performance Test (dtype=torch.float32, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.852768            0.723968               1.178          [torch.Size([268435456]), [torch.Size([65536])], torch.Size([65536]), True]
SUCCESS               0.074016            0.009440               7.841          [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])], torch.Size([8]), True]
SUCCESS               0.081376            0.012896               6.310          [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])], torch.Size([64]), True]
SUCCESS               0.465376            0.364608               1.276          [torch.Size([512, 512, 512]), [torch.Size([128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), True]
SUCCESS               0.464640            0.365104               1.273          [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([2, 128]), torch.Size([2, 128])], torch.Size([2, 128]), True]

@gemini-code-assist (Contributor)

Summary of Changes

Hello @botbigeyes, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a bug where boolean type indices were not properly handled by the index_put and index_put_ operations. The changes introduce robust logic to correctly interpret and apply boolean masks, converting them into explicit coordinate indices. This ensures that these essential tensor manipulation functions now fully support boolean indexing, enhancing their flexibility and preventing errors when users provide boolean tensors for selection.

Highlights

  • Boolean Indexing Support: Implemented support for boolean type indices in the index_put and index_put_ functions. This involves converting a single boolean mask into coordinate indices using torch.where and adjusting the values tensor to match the number of elements identified by the mask.
  • Value Tensor Reshaping: Added logic to handle the values tensor when boolean indexing is used. If values is a scalar, it's expanded to match the number of masked elements; otherwise, it's reshaped to ensure compatibility with the new indices (see the sketch below).
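
As a quick illustration of the mask-to-coordinates step (plain PyTorch, not the PR's code):

    import torch

    mask = torch.tensor([[True, False], [False, True]])
    rows, cols = torch.where(mask)  # coordinate indices of the True entries
    # rows -> tensor([0, 1]), cols -> tensor([0, 1]); one index tensor per dimension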

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request fixes an issue with boolean type indices in index_put and index_put_. The approach of converting the boolean mask to indices using torch.where is correct. However, I've found a couple of issues in the implementation for handling the values tensor that could lead to runtime errors. I've also suggested a refactoring to reduce code duplication between the two functions. Overall, this is a good fix, but some improvements are needed for robustness and maintainability.

Comment on lines 268 to 273

        if values.numel() == 1 and values.shape == torch.Size([]):
            values = torch.full((K,), values.item(), dtype=inp.dtype, device=inp.device)
        else:
            if values.numel() != K:
                pass
            values = values.reshape((K,))
high

The logic for handling values when using a boolean mask has a couple of issues:

  1. The condition values.numel() == 1 and values.shape == torch.Size([]) is too restrictive. It only handles 0-dimensional scalar tensors, but not 1-element tensors like torch.tensor([5]). This will cause an error in the else branch if K != 1. The check should just be if values.numel() == 1:.
  2. The if values.numel() != K: pass is misleading. It suggests the case where the number of elements doesn't match is handled, but the following values.reshape((K,)) will fail with a RuntimeError. It's better to raise an explicit RuntimeError with a clear error message.
        if values.numel() == 1:
            values = torch.full((K,), values.item(), dtype=inp.dtype, device=inp.device)
        else:
            if values.numel() != K:
                raise RuntimeError(
                    f"Number of values to put ({values.numel()}) does not match "
                    f"the number of elements to put ({K})"
                )
            values = values.reshape((K,))

Comment on lines 305 to 310

        if values.numel() == 1 and values.shape == torch.Size([]):
            values = torch.full((K,), values.item(), dtype=inp.dtype, device=inp.device)
        else:
            if values.numel() != K:
                pass
            values = values.reshape((K,))
high

This block has the same issues as in the index_put function:

  1. The scalar check values.numel() == 1 and values.shape == torch.Size([]) is too restrictive.
  2. The if values.numel() != K: pass is misleading and will lead to a crash if the number of elements doesn't match.

Please see the comment on the index_put function for a detailed explanation and suggested fix.

        if values.numel() == 1:
            values = torch.full((K,), values.item(), dtype=inp.dtype, device=inp.device)
        else:
            if values.numel() != K:
                raise RuntimeError(
                    f"Number of values to put ({values.numel()}) does not match "
                    f"the number of elements to put ({K})"
                )
            values = values.reshape((K,))

Comment on lines 258 to 274

    if len(indices) == 1 and indices[0].dtype == torch.bool:
        mask = indices[0]

        if mask.device != inp.device:
            mask = mask.to(inp.device)

        indices = list(torch.where(mask))

        K = indices[0].numel()

        if values.numel() == 1 and values.shape == torch.Size([]):
            values = torch.full((K,), values.item(), dtype=inp.dtype, device=inp.device)
        else:
            if values.numel() != K:
                pass
            values = values.reshape((K,))

medium

This block of logic for handling boolean indices is duplicated in the index_put_ function (lines 295-311). To improve maintainability and adhere to the DRY (Don't Repeat Yourself) principle, consider extracting this logic into a private helper function. This would also mean that bug fixes, like the ones for values handling, only need to be applied in one place.
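
One possible shape for such a helper, folding in the values fixes suggested above (the function name and exact placement are illustrative, not the PR's code):

    import torch

    def _canonicalize_bool_index(inp, indices, values):
        # Shared preprocessing for index_put and index_put_: turn a single
        # bool mask index into per-dimension coordinate indices.
        if len(indices) == 1 and indices[0].dtype == torch.bool:
            mask = indices[0].to(inp.device)
            indices = list(torch.where(mask))
            K = indices[0].numel()
            if values.numel() == 1:
                # Broadcast a scalar (0-d or 1-element) value to every selected slot.
                values = torch.full((K,), values.item(), dtype=inp.dtype, device=inp.device)
            elif values.numel() != K:
                raise RuntimeError(
                    f"Number of values to put ({values.numel()}) does not match "
                    f"the number of elements to put ({K})"
                )
            else:
                values = values.reshape((K,))
        return indices, values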

Comment on lines 270 to 273

        else:
            if values.numel() != K:
                pass
            values = values.reshape((K,))
Collaborator
Can we rewrite this as

Suggested change

    -        else:
    -            if values.numel() != K:
    -                pass
    -            values = values.reshape((K,))
    +        elif values.numel() == K:
    +            values = values.reshape((K,))

?

Comment on lines 307 to 310

        else:
            if values.numel() != K:
                pass
            values = values.reshape((K,))
Collaborator
The same applies here:

Suggested change

    -        else:
    -            if values.numel() != K:
    -                pass
    -            values = values.reshape((K,))
    +        elif values.numel() == K:
    +            values = values.reshape((K,))

Contributor Author
ok

@botbigeyes botbigeyes requested a review from sgjzfzzf November 4, 2025 07:19
sgjzfzzf previously approved these changes Nov 5, 2025
sgjzfzzf previously approved these changes Nov 6, 2025
@0x45f (Collaborator) commented Nov 11, 2025

Please add bool indices cases to the benchmark.

@0x45f (Collaborator) commented Nov 17, 2025

Please fix the op-test-quick-cpu CI.

@botbigeyes (Contributor Author)

> Please fix the op-test-quick-cpu CI.

But the tests passed on my local machine. Where could the issue be?
