-
Notifications
You must be signed in to change notification settings - Fork 152
Fix bool type indices #1023
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Fix bool type indices #1023
Conversation
Summary of ChangesHello @botbigeyes, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a bug where boolean type indices were not properly handled by the Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request fixes an issue with boolean type indices in index_put and index_put_. The approach of converting the boolean mask to indices using torch.where is correct. However, I've found a couple of issues in the implementation for handling the values tensor that could lead to runtime errors. I've also suggested a refactoring to reduce code duplication between the two functions. Overall, this is a good fix, but some improvements are needed for robustness and maintainability.
src/flag_gems/ops/index_put.py
Outdated
| if values.numel() == 1 and values.shape == torch.Size([]): | ||
| values = torch.full((K,), values.item(), dtype=inp.dtype, device=inp.device) | ||
| else: | ||
| if values.numel() != K: | ||
| pass | ||
| values = values.reshape((K,)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic for handling values when using a boolean mask has a couple of issues:
- The condition
values.numel() == 1 and values.shape == torch.Size([])is too restrictive. It only handles 0-dimensional scalar tensors, but not 1-element tensors liketorch.tensor([5]). This will cause an error in theelsebranch ifK != 1. The check should just beif values.numel() == 1:. - The
if values.numel() != K: passis misleading. It suggests the case where the number of elements doesn't match is handled, but the followingvalues.reshape((K,))will fail with aRuntimeError. It's better to raise an explicitRuntimeErrorwith a clear error message.
if values.numel() == 1:
values = torch.full((K,), values.item(), dtype=inp.dtype, device=inp.device)
else:
if values.numel() != K:
raise RuntimeError(
f"Number of values to put ({values.numel()}) does not match "
f"the number of elements to put ({K})"
)
values = values.reshape((K,))
src/flag_gems/ops/index_put.py
Outdated
| if values.numel() == 1 and values.shape == torch.Size([]): | ||
| values = torch.full((K,), values.item(), dtype=inp.dtype, device=inp.device) | ||
| else: | ||
| if values.numel() != K: | ||
| pass | ||
| values = values.reshape((K,)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block has the same issues as in the index_put function:
- The scalar check
values.numel() == 1 and values.shape == torch.Size([])is too restrictive. - The
if values.numel() != K: passis misleading and will lead to a crash if the number of elements don't match.
Please see the comment on the index_put function for a detailed explanation and suggested fix.
if values.numel() == 1:
values = torch.full((K,), values.item(), dtype=inp.dtype, device=inp.device)
else:
if values.numel() != K:
raise RuntimeError(
f"Number of values to put ({values.numel()}) does not match "
f"the number of elements to put ({K})"
)
values = values.reshape((K,))| if len(indices) == 1 and indices[0].dtype == torch.bool: | ||
| mask = indices[0] | ||
|
|
||
| if mask.device != inp.device: | ||
| mask = mask.to(inp.device) | ||
|
|
||
| indices = list(torch.where(mask)) | ||
|
|
||
| K = indices[0].numel() | ||
|
|
||
| if values.numel() == 1 and values.shape == torch.Size([]): | ||
| values = torch.full((K,), values.item(), dtype=inp.dtype, device=inp.device) | ||
| else: | ||
| if values.numel() != K: | ||
| pass | ||
| values = values.reshape((K,)) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block of logic for handling boolean indices is duplicated in the index_put_ function (lines 295-311). To improve maintainability and adhere to the DRY (Don't Repeat Yourself) principle, consider extracting this logic into a private helper function. This would also mean that bug fixes, like the ones for values handling, only need to be applied in one place.
src/flag_gems/ops/index_put.py
Outdated
| else: | ||
| if values.numel() != K: | ||
| pass | ||
| values = values.reshape((K,)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we rewrite this as
| else: | |
| if values.numel() != K: | |
| pass | |
| values = values.reshape((K,)) | |
| elif values.numel() == K: | |
| values = values.reshape((K,)) |
?
src/flag_gems/ops/index_put.py
Outdated
| else: | ||
| if values.numel() != K: | ||
| pass | ||
| values = values.reshape((K,)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The same
| else: | |
| if values.numel() != K: | |
| pass | |
| values = values.reshape((K,)) | |
| elif values.numel() == K: | |
| values = values.reshape((K,)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok
|
Add bool indices cases to benchmark |
|
plz fix op-test-quick-cpu ci |

PR Category
Operator
Type of Change
Bug Fix
Description
The issue with unsupported bool type indices has been fixed.
Issue
Progress
Performance
accumulate=False:
accumulate=False and bool:
accumulate=True