[XPU][Fix] Fix large maxpool index #2362
base: main
Conversation
Pull Request Overview
This PR fixes integer overflow issues in XPU max pooling operations when handling large tensors. It addresses a segmentation fault that occurred when output tensors exceeded INT_MAX elements by introducing index-type templating and automatic memory-format selection.
Key Changes:
- Introduced an `index_t` template parameter (`int32_t` or `int64_t`) for kernel functors and functions to handle both small and large tensor sizes
- Added validation functions `can_use_int32_nhwc` and `can_use_int32_nchw` to determine when int32 indexing is safe to use
- Automatically switches to ChannelsLast memory format when the contiguous format would exceed int32 limits
EikanWang left a comment:
LGTM. Please address the copilot's comments.
```diff
  const vec_t* grad_output_vec = reinterpret_cast<const vec_t*>(gradOutput); \
  vec_t* grad_input_vec = reinterpret_cast<vec_t*>(gradInput); \
- auto kfn = MaxPool2dBackwardChannelLastVec<scalar_t, vec_t, vec_size>( \
+ auto kfn = MaxPool2dBackwardChannelLastVec<scalar_t, vec_t, vec_size, index_t>( \
```
Pls. fix the code style.
Thanks! Added in 474ac9d
Performance outliers, please check!
This PR fixes pytorch/pytorch#167253. It does the following:
- Uses `index_t` instead of `int` and dispatches kernels accordingly (follows [CUDA] Large max pool fix pytorch/pytorch#167427).
- Promotes `num_wg` to `index_t` to avoid overflow.

Details
Test case:
Without this fix, it throws the following error:
```
[MaxPool2d] Input shape: [74, 32, 30090, 81] output: [74, 32, 30090, 40]
[MaxPool2d] Strides: n=77993280 c=1 h=2592 w=32
[MaxPool2d] Memory format: ChannelsLast
[MaxPool2d Forward] ChannelsLast path: numBatch=74 numPlane=32 inputH=30090 inputW=81 outputH=30090 outputW=40 index_t=int64
[MaxPool2d Forward] Using vec_size=1 num_wg=-72057583935701024
Segmentation fault from GPU at 0xff00000c04e33000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
Segmentation fault from GPU at 0xff00000c04e33000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
Abort was called at 279 line in file: ./shared/source/os_interface/linux/drm_neo.cpp
[1] 77805 IOT instruction (core dumped) python
```
From the above log, `num_wg` overflows to a negative value, which causes the segfault.