Fix deterministic indexing with broadcast #1705
base: main
Conversation
Pull Request Overview
This PR enhances the `index_put` implementation on XPU by ensuring deterministic indexing, centralizing shape logic, and bolstering test coverage.

- Introduce a `valsShape` helper to compute expanded-value shapes.
- Extend `computeLinearIndex` and `makeLinearIndex` to return `dims_before` and `dims_indexed`.
- Simplify value expansion in `index_put_deterministic_kernel` via `valsShape`.
- Add new deterministic tests for `index_put` with optional tensors and shape-mismatch checks.
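For intuition, the behavior these tests exercise can be sketched in plain Python (a hypothetical 1D-only model, not the kernel's actual code): a deterministic `index_put` applies writes in a fixed serial order, so duplicate indices always resolve the same way.

```python
def index_put_deterministic(tensor, indices, values, accumulate=False):
    """Apply values at indices serially so duplicates resolve deterministically.

    Hypothetical sketch: real index_put operates on tensors, not lists.
    """
    out = list(tensor)
    # A scalar (0D) value broadcasts across all indices.
    if not isinstance(values, list):
        values = [values] * len(indices)
    if len(values) != len(indices):
        raise ValueError("shape mismatch: values must match indices")
    for i, v in zip(indices, values):
        if accumulate:
            out[i] += v  # serial accumulation: no atomic-add ordering races
        else:
            out[i] = v   # last write wins, in a well-defined index order
    return out
```

With `indices = [0, 1, 1]` the element at index 1 is written twice; serial application makes the outcome reproducible across runs, which is what the deterministic tests assert.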
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
File | Description |
---|---|
`test/xpu/test_indexing_xpu.py` | Added deterministic `index_put` tests, including 0D/1D/2D values and mismatch assertions |
`src/ATen/native/xpu/sycl/IndexingUtils.h` | Extended `computeLinearIndex`/`makeLinearIndex` to return two new dimension counts |
`src/ATen/native/xpu/sycl/Indexing.cpp` | Added `valsShape` helper and replaced manual expansion in the deterministic kernel |
Comments suppressed due to low confidence (2)
test/xpu/test_indexing_xpu.py:18

[nitpick] The helper names `func` and `func1` are ambiguous; consider renaming them to clearly reflect their purpose (e.g., `index_put_with_guard` and `simple_index_put`).

```python
def func(x, i, v):
```
test/xpu/test_indexing_xpu.py:35

[nitpick] Variable `values2d` does not match the `value0d`/`value1d` pattern; rename to `value2d` for consistency.

```python
values2d = torch.randn(n, 1)
```
```python
out_cpu = func(t, indices, value1d)
t = torch.zeros(2, 3, 4)
ind = torch.tensor([0, 1])
val = torch.randn(6, 2)
```
Tests use different error regexes (`"shape mismatch"` vs `"must match"`) for CPU vs XPU; consider unifying the expected message or adding a brief comment explaining the discrepancy to prevent brittleness.

Suggested change:

```python
val = torch.randn(6, 2)
# The error messages differ between CPU ("shape mismatch") and XPU ("must match")
# due to implementation-specific differences in error handling.
```
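Another way to avoid the brittleness (a hypothetical pattern, not code from this PR) is a single regex that accepts either backend's wording:

```python
import re

# Hypothetical pattern covering both backends' error wording:
# CPU raises "shape mismatch: ...", XPU raises "... must match ...".
SHAPE_ERROR_PATTERN = r"shape mismatch|must match"

def matches_shape_error(message):
    """Return True if the message matches either backend's shape error."""
    return re.search(SHAPE_ERROR_PATTERN, message) is not None
```

A test could then pass `SHAPE_ERROR_PATTERN` to `assertRaisesRegex` once, and the same assertion would hold on both devices.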
```
@@ -32,17 +32,38 @@ def func(x, i, v):
indices_dev = indices.to(device)
value0d = torch.tensor(10.0)
value1d = torch.tensor([1.0, 2.0])
values2d = torch.randn(n, 1)

for val in (value0d, value1d, values2d):
```
[nitpick] The loop reuses the same `t_dev`/`t` across multiple `func` calls, mutating them cumulatively; consider reinitializing `t` and `t_dev` inside the loop to make each subtest independent.

Suggested change:

```python
for val in (value0d, value1d, values2d):
    t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
    t_dev = t.to(device)
```
```cpp
@@ -70,8 +68,10 @@ static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex(
// are not being indexed.
Tensor linearIndex;
int64_t nElemBefore = 1, nElemAfter = 1, strideBefore = 0;
```
The new tuple fields `dims_before` and `dims_indexed` would benefit from a brief inline comment explaining their meaning and relationship to the indexing algorithm.

Suggested change:

```cpp
int64_t nElemBefore = 1, nElemAfter = 1, strideBefore = 0;
// `dims_before` counts the number of dimensions before the indexed dimensions.
// `dims_indexed` counts the number of dimensions that are being indexed.
```
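For intuition, the partition these fields describe can be sketched in Python (a hypothetical model of the bookkeeping, assuming the indexed dimensions are contiguous; the names mirror the C++ fields but this is not the kernel code):

```python
def partition_dims(shape, is_indexed):
    """Split dims into (before, indexed, after) blocks, assuming the
    indexed dims are contiguous, and count elements on each side."""
    dims_before = dims_indexed = 0
    n_elem_before = n_elem_after = 1
    seen_indexed = False
    for size, indexed in zip(shape, is_indexed):
        if indexed:
            seen_indexed = True
            dims_indexed += 1       # a dimension being indexed
        elif not seen_indexed:
            dims_before += 1        # un-indexed dim before the indexed block
            n_elem_before *= size
        else:
            n_elem_after *= size    # un-indexed dim after the indexed block
    return dims_before, dims_indexed, n_elem_before, n_elem_after
```

For example, indexing dimension 1 of a `(2, 3, 4)` tensor gives `dims_before = 1` and `dims_indexed = 1`, with 2 elements before and 4 after each indexed slice.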
```
@@ -609,6 +609,21 @@ void index_put_kernel(
  }
}

DimVector valsShape(
```
[nitpick] Consider marking `valsShape` as `static inline`, or moving its declaration to the header with a doc comment, so its purpose and usage are clearer and the compiler can inline it across translation units.
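Whatever the linkage, the shape such a helper computes can be sketched as follows (a hypothetical Python mirror of `valsShape`, not the C++ signature):

```python
def vals_shape(self_shape, dims_before, dims_indexed, num_index_elems):
    """Target shape for the expanded values: the leading un-indexed dims,
    one dim holding the number of index elements, then the trailing dims.

    Hypothetical sketch of the helper's logic, not the actual C++ code.
    """
    before = list(self_shape[:dims_before])
    after = list(self_shape[dims_before + dims_indexed:])
    return before + [num_index_elems] + after
```

For example, indexing dimension 1 of a `(2, 3, 4)` tensor with 2 index elements would yield a target value shape of `[2, 2, 4]`.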
@chunhuanMeng Pls remove |
done |
Introduces enhancements to the `index_put` implementation for XPU tensors, focusing on deterministic behavior, improved shape handling, and expanded test coverage. Key changes include adding new helper functions, extending the `makeLinearIndex` and `computeLinearIndex` methods, and updating the associated test suite.

Enhancements to `index_put` Implementation:

- New helper function for shape handling: added `valsShape` to compute the target shape for expanded values during `index_put` operations. This simplifies and centralizes shape manipulation logic. (`src/ATen/native/xpu/sycl/Indexing.cpp`)
- Extended `makeLinearIndex` and `computeLinearIndex`: added `dims_before` and `dims_indexed` to track dimensions before and during indexing. These are now returned as part of the tuple from `computeLinearIndex` and propagated through `makeLinearIndex`. (`src/ATen/native/xpu/sycl/IndexingUtils.h`)
- Simplified value expansion in `index_put_deterministic_kernel`: replaced the manual expansion logic with the `valsShape` helper, making the code more concise and reducing duplication. (`src/ATen/native/xpu/sycl/Indexing.cpp`)

Test Suite Enhancements:

- Added a new test, `test_index_put_deterministic_with_optional_tensors`, to validate the deterministic behavior of `index_put` across various tensor shapes and scenarios, including checks for shape mismatches and proper handling of 0D, 1D, and 2D values. (`test/xpu/test_indexing_xpu.py`)

These changes collectively improve the robustness, maintainability, and test coverage of the `index_put` functionality for XPU tensors.
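The broadcast rule the fix relies on can be sketched minimally (a hypothetical model of NumPy/PyTorch-style right-aligned broadcasting, not the kernel code): a values shape is compatible with the computed target shape when each trailing dimension either matches or is 1.

```python
def broadcastable_to(src, target):
    """Right-aligned broadcast check: src can expand to target iff each
    trailing dim of src equals the corresponding target dim or is 1.

    Hypothetical sketch of the standard broadcasting rule.
    """
    if len(src) > len(target):
        return False
    for s, t in zip(reversed(src), reversed(target)):
        if s != t and s != 1:
            return False
    return True
```

This mirrors the shape-mismatch case in the new tests: a value of shape `(6, 2)` cannot broadcast to, for example, a `(2, 2, 4)` target, while a 0D scalar broadcasts to any target shape.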