diff --git a/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp b/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp index 658ca60cc..5a9b19ab9 100644 --- a/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp +++ b/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp @@ -36,17 +36,58 @@ static inline int p_end(int size, int pad, int pooled_size, int stride) { return std::min((size + pad) / stride + 1, pooled_size); } -template +static inline bool can_use_int32_nhwc( + int64_t nbatch, int64_t channels, + int64_t height, int64_t width, + int64_t pooled_height, int64_t pooled_width, + int64_t in_stride_n, int64_t in_stride_c, + int64_t in_stride_h, int64_t in_stride_w) +{ + constexpr int64_t int_max = std::numeric_limits::max(); + + int64_t max_intra_batch = + (height ? (height - 1) * in_stride_h : 0) + + (width ? (width - 1) * in_stride_w : 0) + + (channels ? (channels - 1) * in_stride_c : 0); + + int64_t max_input_offset = (nbatch ? (nbatch - 1) * in_stride_n : 0) + max_intra_batch; + + if (max_input_offset > int_max) return false; + + int64_t out_batch_stride = pooled_height * pooled_width * channels; + if ((nbatch ? (nbatch - 1) * out_batch_stride : 0) > int_max) return false; + + if (height * width > int_max) return false; + + return true; +} + +static inline bool can_use_int32_nchw( + int64_t nbatch, int64_t channels, + int64_t height, int64_t width, + int64_t pooled_height, int64_t pooled_width) { + int64_t hw = height * width; + return can_use_int32_nhwc( + nbatch, channels, height, width, + pooled_height, pooled_width, + channels * hw, // in_stride_n + hw, // in_stride_c + width, // in_stride_h + 1 // in_stride_w + ); +} + +template struct MaxPool2dKernelFunctor { void operator()(sycl::nd_item<2> item) const { auto desc = cfg_.get_item_desc(item); do { if (desc.glb_problem < cfg_.problem_) { - int outputIndex = desc.glb_problem; - int batch = outputIndex / stride_; - int plane, outputH, outputW; - int64_t load_offset, store_offset; + index_t outputIndex = desc.glb_problem; + index_t batch = outputIndex / stride_; + index_t plane, outputH, outputW; + index_t load_offset, store_offset; if constexpr (is_channels_last) { plane = outputIndex % numPlane_; outputH = outputIndex / numPlane_ / outputSizeW_ % outputSizeH_; @@ -62,19 +103,19 @@ struct MaxPool2dKernelFunctor { outputW; } scalar_t maxVal = at::numeric_limits::lower_bound(); - int maxIndex = -1; - int StartH = outputH * dH_ - padH_; - int StartW = outputW * dW_ - padW_; - int EndH = std::min(StartH + (kH_ - 1) * dilationH_ + 1, inputSizeH_); - int EndW = std::min(StartW + (kW_ - 1) * dilationW_ + 1, inputSizeW_); + index_t maxIndex = -1; + index_t StartH = outputH * dH_ - padH_; + index_t StartW = outputW * dW_ - padW_; + index_t EndH = std::min(StartH + (kH_ - 1) * dilationH_ + 1, inputSizeH_); + index_t EndW = std::min(StartW + (kW_ - 1) * dilationW_ + 1, inputSizeW_); while (StartH < 0) StartH += dilationH_; while (StartW < 0) StartW += dilationW_; #pragma unroll - for (int h = StartH; h < EndH; h += dilationH_) { + for (index_t h = StartH; h < EndH; h += dilationH_) { #pragma unroll - for (int w = StartW; w < EndW; w += dilationW_) { + for (index_t w = StartW; w < EndW; w += dilationW_) { if constexpr (is_channels_last) { load_offset = batch * inputSizeH_ * inputSizeW_ * numPlane_ + plane + h * inputSizeW_ * numPlane_ + w * numPlane_; @@ -98,11 +139,11 @@ struct MaxPool2dKernelFunctor { scalar_t* output, int64_t* indices, const scalar_t* input, - int numPlane, - int inputSizeH, - int inputSizeW, - int outputSizeH, - int outputSizeW, + index_t numPlane, + index_t inputSizeH, + index_t inputSizeW, + index_t outputSizeH, + index_t outputSizeW, int kH, int kW, int dH, @@ -111,7 +152,7 @@ struct MaxPool2dKernelFunctor { int padW, int dilationH, int dilationW, - int stride, + index_t stride, BatchKernelConfig cfg) : output_(output), indices_(indices), @@ -136,11 +177,11 @@ struct MaxPool2dKernelFunctor { scalar_t* output_; int64_t* indices_; const scalar_t* input_; - int numPlane_; - int inputSizeH_; - int inputSizeW_; - int outputSizeH_; - int outputSizeW_; + index_t numPlane_; + index_t inputSizeH_; + index_t inputSizeW_; + index_t outputSizeH_; + index_t outputSizeW_; int kH_; int kW_; int dH_; @@ -149,18 +190,18 @@ struct MaxPool2dKernelFunctor { int padW_; int dilationH_; int dilationW_; - int stride_; + index_t stride_; BatchKernelConfig cfg_; }; -template +template struct MaxPool2dChannelLastVec { void operator()(sycl::nd_item<1> item) const { for (auto outputIndex = item.get_global_linear_id(); outputIndex < numBatch_ * stride_ / vec_size; outputIndex += item.get_local_range(0) * item.get_group_range(0)) { - int batch = outputIndex / (stride_ / vec_size); - int plane, outputH, outputW; + index_t batch = outputIndex / (stride_ / vec_size); + index_t plane, outputH, outputW; int64_t load_offset, store_offset; plane = outputIndex % (numPlane_ / vec_size); outputH = @@ -177,16 +218,16 @@ struct MaxPool2dChannelLastVec { for (int i = 0; i < vec_size; i++) { maxIndex[i] = int64_t(-1); } - int StartH = outputH * dH_ - padH_; - int StartW = outputW * dW_ - padW_; - int EndH = std::min(StartH + (kH_ - 1) * dilationH_ + 1, inputSizeH_); - int EndW = std::min(StartW + (kW_ - 1) * dilationW_ + 1, inputSizeW_); + index_t StartH = outputH * dH_ - padH_; + index_t StartW = outputW * dW_ - padW_; + index_t EndH = std::min(StartH + (kH_ - 1) * dilationH_ + 1, inputSizeH_); + index_t EndW = std::min(StartW + (kW_ - 1) * dilationW_ + 1, inputSizeW_); while (StartH < 0) StartH += dilationH_; while (StartW < 0) StartW += dilationW_; - for (int h = StartH; h < EndH; h += dilationH_) { - for (int w = StartW; w < EndW; w += dilationW_) { + for (index_t h = StartH; h < EndH; h += dilationH_) { + for (index_t w = StartW; w < EndW; w += dilationW_) { load_offset = batch * inputSizeH_ * inputSizeW_ * numPlane_ / vec_size + plane + h * inputSizeW_ * numPlane_ / vec_size + w * numPlane_ / vec_size; @@ -212,12 +253,12 @@ struct MaxPool2dChannelLastVec { vec_t* output_vec, int64_t* indices, const vec_t* input_vec, - int numBatch, - int numPlane, - int inputSizeH, - int inputSizeW, - int outputSizeH, - int outputSizeW, + index_t numBatch, + index_t numPlane, + index_t inputSizeH, + index_t inputSizeW, + index_t outputSizeH, + index_t outputSizeW, int kH, int kW, int dH, @@ -226,7 +267,7 @@ struct MaxPool2dChannelLastVec { int padW, int dilationH, int dilationW, - int stride) + index_t stride) : output_vec_(output_vec), indices_(indices), input_vec_(input_vec), @@ -250,12 +291,12 @@ struct MaxPool2dChannelLastVec { vec_t* output_vec_; int64_t* indices_; const vec_t* input_vec_; - int numBatch_; - int numPlane_; - int inputSizeH_; - int inputSizeW_; - int outputSizeH_; - int outputSizeW_; + index_t numBatch_; + index_t numPlane_; + index_t inputSizeH_; + index_t inputSizeW_; + index_t outputSizeH_; + index_t outputSizeW_; int kH_; int kW_; int dH_; @@ -264,29 +305,29 @@ struct MaxPool2dChannelLastVec { int padW_; int dilationH_; int dilationW_; - int stride_; + index_t stride_; }; -template +template struct MaxPool2dBackwardKernelFunctor { void operator()(sycl::nd_item<2> item) const { auto desc = cfg_.get_item_desc(item); do { if (desc.glb_problem < cfg_.problem_) { - int batch = desc.glb_problem / out_n_stride_; - int outputIndex = desc.glb_problem; + index_t batch = desc.glb_problem / out_n_stride_; + index_t outputIndex = desc.glb_problem; if constexpr (is_channels_last) { - int plane = outputIndex % numPlane_; + index_t plane = outputIndex % numPlane_; int64_t index = indices_[outputIndex]; - int64_t gI_offset = batch * in_n_stride_ + plane + index * numPlane_; + index_t gI_offset = batch * in_n_stride_ + plane + index * numPlane_; atomicAdd( (sycl_global_ptr)&gradInput_[gI_offset], gradOutput_[outputIndex]); } else { - int plane = outputIndex / out_cf_c_stride_ % numPlane_; + index_t plane = outputIndex / out_cf_c_stride_ % numPlane_; int64_t index = indices_[outputIndex]; - int64_t gI_offset = + index_t gI_offset = batch * in_n_stride_ + plane * in_cf_c_stride_ + index; atomicAdd( (sycl_global_ptr)&gradInput_[gI_offset], @@ -299,16 +340,16 @@ struct MaxPool2dBackwardKernelFunctor { scalar_t* gradInput, const scalar_t* gradOutput, const int64_t* indices, - int numPlane, - int gradInputSizeH, - int gradInputSizeW, - int gradOutputSizeH, - int gradOutputSizeW, + index_t numPlane, + index_t gradInputSizeH, + index_t gradInputSizeW, + index_t gradOutputSizeH, + index_t gradOutputSizeW, int64_t gradOutputSize, - int out_cf_c_stride, - int in_cf_c_stride, - int out_n_stride, - int in_n_stride, + index_t out_cf_c_stride, + index_t in_cf_c_stride, + index_t out_n_stride, + index_t in_n_stride, BatchKernelConfig cfg) : gradInput_(gradInput), gradOutput_(gradOutput), @@ -329,29 +370,29 @@ struct MaxPool2dBackwardKernelFunctor { scalar_t* gradInput_; const scalar_t* gradOutput_; const int64_t* indices_; - int numPlane_; - int gradInputSizeH_; - int gradInputSizeW_; - int gradOutputSizeH_; - int gradOutputSizeW_; + index_t numPlane_; + index_t gradInputSizeH_; + index_t gradInputSizeW_; + index_t gradOutputSizeH_; + index_t gradOutputSizeW_; int64_t gradOutputSize_; - int out_cf_c_stride_; - int in_cf_c_stride_; - int out_n_stride_; - int in_n_stride_; + index_t out_cf_c_stride_; + index_t in_cf_c_stride_; + index_t out_n_stride_; + index_t in_n_stride_; BatchKernelConfig cfg_; }; -template +template struct MaxPool2dBackwardDeterministicKernelFunctor { void operator()(sycl::nd_item<2> item) const { auto desc = cfg_.get_item_desc(item); do { if (desc.glb_problem < cfg_.problem_) { - int inputIndex = desc.glb_problem; - int batch = inputIndex / in_n_stride_; - int plane; - int64_t input_hw_index; + index_t inputIndex = desc.glb_problem; + index_t batch = inputIndex / in_n_stride_; + index_t plane; + index_t input_hw_index; if constexpr (is_channels_last) { plane = inputIndex % numPlane_; input_hw_index = ((inputIndex % in_n_stride_) - plane) / numPlane_; @@ -359,8 +400,8 @@ struct MaxPool2dBackwardDeterministicKernelFunctor { plane = inputIndex / in_cf_c_stride_ % numPlane_; input_hw_index = ((inputIndex % in_n_stride_)) % in_cf_c_stride_; } - int inputW = input_hw_index % gradInputSizeW_; - int inputH = input_hw_index / gradInputSizeW_; + index_t inputW = input_hw_index % gradInputSizeW_; + index_t inputH = input_hw_index / gradInputSizeW_; int phstart = p_start(inputH, pad_h_, kernel_h_, dilation_h_, stride_h_); int phend = p_end(inputH, pad_h_, gradOutputSizeH_, stride_h_); @@ -369,7 +410,7 @@ struct MaxPool2dBackwardDeterministicKernelFunctor { int pwend = p_end(inputW, pad_w_, gradOutputSizeW_, stride_w_); scalar_t grad = 0; if constexpr (is_channels_last) { - int offset = batch * out_n_stride_ + plane; + index_t offset = batch * out_n_stride_ + plane; for (int ph = phstart; ph < phend; ++ph) { for (int pw = pwstart; pw < pwend; ++pw) { if (indices_[offset + (ph * gradOutputSizeW_ + pw) * numPlane_] == @@ -381,7 +422,7 @@ struct MaxPool2dBackwardDeterministicKernelFunctor { } } } else { - int offset = batch * out_n_stride_ + plane * out_cf_c_stride_; + index_t offset = batch * out_n_stride_ + plane * out_cf_c_stride_; for (int ph = phstart; ph < phend; ++ph) { for (int pw = pwstart; pw < pwend; ++pw) { if (indices_[offset + ph * gradOutputSizeW_ + pw] == @@ -400,16 +441,16 @@ struct MaxPool2dBackwardDeterministicKernelFunctor { scalar_t* gradInput, const scalar_t* gradOutput, const int64_t* indices, - int numPlane, - int gradInputSizeH, - int gradInputSizeW, - int gradOutputSizeH, - int gradOutputSizeW, + index_t numPlane, + index_t gradInputSizeH, + index_t gradInputSizeW, + index_t gradOutputSizeH, + index_t gradOutputSizeW, int64_t gradInputSize, - int out_cf_c_stride, - int in_cf_c_stride, - int out_n_stride, - int in_n_stride, + index_t out_cf_c_stride, + index_t in_cf_c_stride, + index_t out_n_stride, + index_t in_n_stride, int kernel_h, int kernel_w, int stride_h, @@ -446,16 +487,16 @@ struct MaxPool2dBackwardDeterministicKernelFunctor { scalar_t* gradInput_; const scalar_t* gradOutput_; const int64_t* indices_; - int numPlane_; - int gradInputSizeH_; - int gradInputSizeW_; - int gradOutputSizeH_; - int gradOutputSizeW_; + index_t numPlane_; + index_t gradInputSizeH_; + index_t gradInputSizeW_; + index_t gradOutputSizeH_; + index_t gradOutputSizeW_; int64_t gradInputSize_; - int out_cf_c_stride_; - int in_cf_c_stride_; - int out_n_stride_; - int in_n_stride_; + index_t out_cf_c_stride_; + index_t in_cf_c_stride_; + index_t out_n_stride_; + index_t in_n_stride_; int kernel_h_; int kernel_w_; int stride_h_; @@ -468,14 +509,14 @@ struct MaxPool2dBackwardDeterministicKernelFunctor { }; -template +template struct MaxPool2dBackwardChannelLastVec { void operator()(sycl::nd_item<1> item) const { for (auto inputIndex = item.get_global_linear_id(); inputIndex < gradInputSize_ / vec_size; inputIndex += item.get_local_range(0) * item.get_group_range(0)) { - int batch = inputIndex / (in_n_stride_ / vec_size); - int plane; + index_t batch = inputIndex / (in_n_stride_ / vec_size); + index_t plane; int64_t input_hw_index; plane = inputIndex % (numPlane_ / vec_size); @@ -483,8 +524,8 @@ struct MaxPool2dBackwardChannelLastVec { input_hw_index = ((inputIndex % (in_n_stride_ / vec_size)) - plane) / (numPlane_ / vec_size); - int inputW = input_hw_index % gradInputSizeW_; - int inputH = input_hw_index / gradInputSizeW_; + index_t inputW = input_hw_index % gradInputSizeW_; + index_t inputH = input_hw_index / gradInputSizeW_; int phstart = p_start(inputH, pad_h_, kernel_h_, dilation_h_, stride_h_); int phend = p_end(inputH, pad_h_, gradOutputSizeH_, stride_h_); int pwstart = p_start(inputW, pad_w_, kernel_w_, dilation_w_, stride_w_); @@ -519,14 +560,14 @@ struct MaxPool2dBackwardChannelLastVec { vec_t* gradInput, const vec_t* gradOutput, const int64_t* indices, - int numPlane, - int gradInputSizeH, - int gradInputSizeW, - int gradOutputSizeH, - int gradOutputSizeW, + index_t numPlane, + index_t gradInputSizeH, + index_t gradInputSizeW, + index_t gradOutputSizeH, + index_t gradOutputSizeW, int64_t gradInputSize, - int out_n_stride, - int in_n_stride, + index_t out_n_stride, + index_t in_n_stride, int kernel_h, int kernel_w, int stride_h, @@ -559,14 +600,14 @@ struct MaxPool2dBackwardChannelLastVec { vec_t* gradInput_; const vec_t* gradOutput_; const int64_t* indices_; - int numPlane_; - int gradInputSizeH_; - int gradInputSizeW_; - int gradOutputSizeH_; - int gradOutputSizeW_; + index_t numPlane_; + index_t gradInputSizeH_; + index_t gradInputSizeW_; + index_t gradOutputSizeH_; + index_t gradOutputSizeW_; int64_t gradInputSize_; - int out_n_stride_; - int in_n_stride_; + index_t out_n_stride_; + index_t in_n_stride_; int kernel_h_; int kernel_w_; int stride_h_; @@ -581,6 +622,7 @@ struct MaxPool2dBackwardChannelLastVec { #define LAUNCH_MAXPOOL_CHANNEL_LAST_VEC( \ scalar_t, \ vec_size, \ + index_t, \ num_wg, \ wg_size, \ queue, \ @@ -603,43 +645,44 @@ struct MaxPool2dBackwardChannelLastVec { dilationW, \ stride) \ { \ - using vec_t = memory::aligned_vector; \ - vec_t* output_vec = reinterpret_cast(output); \ - const vec_t* input_vec = reinterpret_cast(input); \ - auto kfn = MaxPool2dChannelLastVec( \ - output_vec, \ - indices, \ - input_vec, \ - numBatch, \ - numPlane, \ - inputSizeH, \ - inputSizeW, \ - outputSizeH, \ - outputSizeW, \ - kH, \ - kW, \ - dH, \ - dW, \ - padH, \ - padW, \ - dilationH, \ - dilationW, \ - stride); \ - sycl_kernel_submit(num_wg* wg_size, wg_size, queue, kfn); \ + using vec_t = memory::aligned_vector; \ + vec_t* output_vec = reinterpret_cast(output); \ + const vec_t* input_vec = reinterpret_cast(input); \ + auto kfn = \ + MaxPool2dChannelLastVec( \ + output_vec, \ + indices, \ + input_vec, \ + numBatch, \ + numPlane, \ + inputSizeH, \ + inputSizeW, \ + outputSizeH, \ + outputSizeW, \ + kH, \ + kW, \ + dH, \ + dW, \ + padH, \ + padW, \ + dilationH, \ + dilationW, \ + stride); \ + sycl_kernel_submit(num_wg* wg_size, wg_size, queue, kfn); \ } -template +template void launch_max_pool2d_kernel( scalar_t* output, int64_t* indices, const scalar_t* input, - int numBatch, - int numPlane, - int inputSizeH, - int inputSizeW, - int outputSizeH, - int outputSizeW, + index_t numBatch, + index_t numPlane, + index_t inputSizeH, + index_t inputSizeW, + index_t outputSizeH, + index_t outputSizeW, int kH, int kW, int dH, @@ -649,11 +692,11 @@ void launch_max_pool2d_kernel( int dilationH, int dilationW) { auto& queue = at::xpu::getCurrentSYCLQueue(); - int outputSize = numBatch * numPlane * outputSizeH * outputSizeW; - int stride = numPlane * outputSizeH * outputSizeW; + int64_t outputSize = static_cast(numBatch) * static_cast(numPlane) * static_cast(outputSizeH) * static_cast(outputSizeW); + int64_t stride = static_cast(numPlane) * static_cast(outputSizeH) * static_cast(outputSizeW); int vec_size = 1; int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU(); - int num_sub_wg; + int64_t num_sub_wg; auto wg_size = syclDeviceMaxWorkGroupSize(); int64_t num_wg; if constexpr (is_channels_last) { @@ -676,6 +719,7 @@ void launch_max_pool2d_kernel( LAUNCH_MAXPOOL_CHANNEL_LAST_VEC( scalar_t, 8, + index_t, num_wg, wg_size, queue, @@ -702,6 +746,7 @@ void launch_max_pool2d_kernel( LAUNCH_MAXPOOL_CHANNEL_LAST_VEC( scalar_t, 4, + index_t, num_wg, wg_size, queue, @@ -728,6 +773,7 @@ void launch_max_pool2d_kernel( LAUNCH_MAXPOOL_CHANNEL_LAST_VEC( scalar_t, 2, + index_t, num_wg, wg_size, queue, @@ -754,7 +800,7 @@ void launch_max_pool2d_kernel( break; }; } - using KernelClass = MaxPool2dKernelFunctor; + using KernelClass = MaxPool2dKernelFunctor; BatchKernelConfig cfg = BatchKernelConfig::make_config( 1, outputSize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive); @@ -783,6 +829,7 @@ void launch_max_pool2d_kernel( #define LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC( \ scalar_t, \ vec_size, \ + index_t, \ num_wg, \ wg_size, \ queue, \ @@ -806,43 +853,44 @@ void launch_max_pool2d_kernel( dilation_h, \ dilation_w) \ { \ - using vec_t = memory::aligned_vector; \ - const vec_t* grad_output_vec = reinterpret_cast(gradOutput); \ - vec_t* grad_input_vec = reinterpret_cast(gradInput); \ - auto kfn = MaxPool2dBackwardChannelLastVec( \ - grad_input_vec, \ - grad_output_vec, \ - indices, \ - numPlane, \ - gradInputSizeH, \ - gradInputSizeW, \ - gradOutputSizeH, \ - gradOutputSizeW, \ - gradInputSize, \ - out_n_stride, \ - in_n_stride, \ - kernel_h, \ - kernel_w, \ - stride_h, \ - stride_w, \ - pad_h, \ - pad_w, \ - dilation_h, \ - dilation_w); \ - sycl_kernel_submit(num_wg* wg_size, wg_size, queue, kfn); \ + using vec_t = memory::aligned_vector; \ + const vec_t* grad_output_vec = reinterpret_cast(gradOutput); \ + vec_t* grad_input_vec = reinterpret_cast(gradInput); \ + auto kfn = \ + MaxPool2dBackwardChannelLastVec( \ + grad_input_vec, \ + grad_output_vec, \ + indices, \ + numPlane, \ + gradInputSizeH, \ + gradInputSizeW, \ + gradOutputSizeH, \ + gradOutputSizeW, \ + gradInputSize, \ + out_n_stride, \ + in_n_stride, \ + kernel_h, \ + kernel_w, \ + stride_h, \ + stride_w, \ + pad_h, \ + pad_w, \ + dilation_h, \ + dilation_w); \ + sycl_kernel_submit(num_wg* wg_size, wg_size, queue, kfn); \ } -template +template void launch_max_pool2d_backward_kernel( scalar_t* gradInput, const scalar_t* gradOutput, const int64_t* indices, - int numBatch, - int numPlane, - int gradInputSizeH, - int gradInputSizeW, - int gradOutputSizeH, - int gradOutputSizeW, + index_t numBatch, + index_t numPlane, + index_t gradInputSizeH, + index_t gradInputSizeW, + index_t gradOutputSizeH, + index_t gradOutputSizeW, int kernel_h, int kernel_w, int stride_h, @@ -895,6 +943,7 @@ void launch_max_pool2d_backward_kernel( LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC( scalar_t, 8, + index_t, num_wg, wg_size, queue, @@ -922,6 +971,7 @@ void launch_max_pool2d_backward_kernel( LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC( scalar_t, 4, + index_t, num_wg, wg_size, queue, @@ -949,6 +999,7 @@ void launch_max_pool2d_backward_kernel( LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC( scalar_t, 2, + index_t, num_wg, wg_size, queue, @@ -977,7 +1028,7 @@ void launch_max_pool2d_backward_kernel( }; } using KernelClass = - MaxPool2dBackwardDeterministicKernelFunctor; + MaxPool2dBackwardDeterministicKernelFunctor; BatchKernelConfig cfg = BatchKernelConfig::make_config( 1, gradInputSize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive); cfg.template build(); @@ -1009,7 +1060,7 @@ void launch_max_pool2d_backward_kernel( int64_t gradOutputSize = numBatch * numPlane * gradOutputSizeH * gradOutputSizeW; using KernelClass = - MaxPool2dBackwardKernelFunctor; + MaxPool2dBackwardKernelFunctor; BatchKernelConfig cfg = BatchKernelConfig::make_config( 1, gradOutputSize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive); cfg.template build(); @@ -1052,9 +1103,12 @@ void max_pool2d_with_indices_kernel( return; } - auto smf = input_.suggest_memory_format(); + auto memory_format = input_.suggest_memory_format(); + if (memory_format == MemoryFormat::Contiguous && input_.numel() > static_cast(std::numeric_limits::max())) { + memory_format = MemoryFormat::ChannelsLast; + } - Tensor input = input_.contiguous(smf); + Tensor input = input_.contiguous(memory_format); const int kH = safe_downcast(kernel_size[0]); const int kW = kernel_size.size() == 1 @@ -1083,49 +1137,102 @@ void max_pool2d_with_indices_kernel( const int64_t outputHeight = output.size(-2); const int64_t outputWidth = output.size(-1); + const int64_t in_stride_n = input_.ndimension() == 4 ? input.stride(-4) : 0; + const int64_t in_stride_c = input.stride(-3); + const int64_t in_stride_h = input.stride(-2); + const int64_t in_stride_w = input.stride(-1); + AT_DISPATCH_FLOATING_TYPES_AND2( kHalf, kBFloat16, input.scalar_type(), "max_pool2d_xpu", [&] { - switch (smf) { + switch (memory_format) { case MemoryFormat::ChannelsLast: { - launch_max_pool2d_kernel( - output.mutable_data_ptr(), - indices.mutable_data_ptr(), - input.const_data_ptr(), - nbatch, - nInputPlane, - inputHeight, - inputWidth, - outputHeight, - outputWidth, - kH, - kW, - dH, - dW, - padH, - padW, - dilationH, - dilationW); + bool use_int32 = can_use_int32_nhwc( + nbatch, nInputPlane, inputHeight, inputWidth, + outputHeight, outputWidth, + in_stride_n, in_stride_c, in_stride_h, in_stride_w); + if (use_int32) { + launch_max_pool2d_kernel( + output.mutable_data_ptr(), + indices.mutable_data_ptr(), + input.const_data_ptr(), + static_cast(nbatch), + static_cast(nInputPlane), + static_cast(inputHeight), + static_cast(inputWidth), + static_cast(outputHeight), + static_cast(outputWidth), + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW); + } else { + launch_max_pool2d_kernel( + output.mutable_data_ptr(), + indices.mutable_data_ptr(), + input.const_data_ptr(), + nbatch, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW); + } break; } case MemoryFormat::Contiguous: { - launch_max_pool2d_kernel( - output.mutable_data_ptr(), - indices.mutable_data_ptr(), - input.const_data_ptr(), - nbatch, - nInputPlane, - inputHeight, - inputWidth, - outputHeight, - outputWidth, - kH, - kW, - dH, - dW, - padH, - padW, - dilationH, - dilationW); + bool use_int32 = can_use_int32_nchw( + nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth); + if (use_int32) { + launch_max_pool2d_kernel( + output.mutable_data_ptr(), + indices.mutable_data_ptr(), + input.const_data_ptr(), + static_cast(nbatch), + static_cast(nInputPlane), + static_cast(inputHeight), + static_cast(inputWidth), + static_cast(outputHeight), + static_cast(outputWidth), + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW); + } else { + launch_max_pool2d_kernel( + output.mutable_data_ptr(), + indices.mutable_data_ptr(), + input.const_data_ptr(), + nbatch, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW); + } break; } default: @@ -1154,12 +1261,15 @@ void max_pool2d_with_indices_backward_kernel( checkAllSameGPU( __func__, {gradInput_arg, gradOutput_arg, input_arg, indices_arg}); - auto smf = input_.suggest_memory_format(); + auto memory_format = input_.suggest_memory_format(); + if (memory_format == MemoryFormat::Contiguous && input_.numel() > static_cast(std::numeric_limits::max())) { + memory_format = MemoryFormat::ChannelsLast; + } Tensor input, gradOutput, indices; - input = input_.contiguous(smf); - gradOutput = gradOutput_.contiguous(smf); - indices = indices_.contiguous(smf); + input = input_.contiguous(memory_format); + gradOutput = gradOutput_.contiguous(memory_format); + indices = indices_.contiguous(memory_format); gradInput.zero_(); const int kH = safe_downcast(kernel_size[0]); @@ -1188,53 +1298,109 @@ void max_pool2d_with_indices_backward_kernel( inputHeight, kH, padH, dH, dilationH, ceil_mode); int64_t outputWidth = pooling_output_shape( inputWidth, kW, padW, dW, dilationW, ceil_mode); + + const int64_t in_stride_n = input.ndimension() == 4 ? input.stride(-4) : 0; + const int64_t in_stride_c = input.stride(-3); + const int64_t in_stride_h = input.stride(-2); + const int64_t in_stride_w = input.stride(-1); + AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, gradOutput.scalar_type(), "max_pool2d_backward_xpu", [&] { - switch (smf) { - case at::MemoryFormat::ChannelsLast: - launch_max_pool2d_backward_kernel( - gradInput.mutable_data_ptr(), - gradOutput.const_data_ptr(), - indices.const_data_ptr(), - nbatch, - nInputPlane, - inputHeight, - inputWidth, - outputHeight, - outputWidth, - kH, - kW, - dH, - dW, - padH, - padW, - dilationH, - dilationW); + switch (memory_format) { + case at::MemoryFormat::ChannelsLast: { + bool use_int32 = can_use_int32_nhwc( + nbatch, nInputPlane, inputHeight, inputWidth, + outputHeight, outputWidth, + in_stride_n, in_stride_c, in_stride_h, in_stride_w); + if (use_int32) { + launch_max_pool2d_backward_kernel( + gradInput.mutable_data_ptr(), + gradOutput.const_data_ptr(), + indices.const_data_ptr(), + static_cast(nbatch), + static_cast(nInputPlane), + static_cast(inputHeight), + static_cast(inputWidth), + static_cast(outputHeight), + static_cast(outputWidth), + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW); + } else { + launch_max_pool2d_backward_kernel( + gradInput.mutable_data_ptr(), + gradOutput.const_data_ptr(), + indices.const_data_ptr(), + nbatch, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW); + } break; - case at::MemoryFormat::Contiguous: - launch_max_pool2d_backward_kernel( - gradInput.mutable_data_ptr(), - gradOutput.const_data_ptr(), - indices.const_data_ptr(), - nbatch, - nInputPlane, - inputHeight, - inputWidth, - outputHeight, - outputWidth, - kH, - kW, - dH, - dW, - padH, - padW, - dilationH, - dilationW); + } + case at::MemoryFormat::Contiguous: { + bool use_int32 = can_use_int32_nchw( + nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth); + if (use_int32) { + launch_max_pool2d_backward_kernel( + gradInput.mutable_data_ptr(), + gradOutput.const_data_ptr(), + indices.const_data_ptr(), + static_cast(nbatch), + static_cast(nInputPlane), + static_cast(inputHeight), + static_cast(inputWidth), + static_cast(outputHeight), + static_cast(outputWidth), + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW); + } else { + launch_max_pool2d_backward_kernel( + gradInput.mutable_data_ptr(), + gradOutput.const_data_ptr(), + indices.const_data_ptr(), + nbatch, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW); + } break; + } default: TORCH_CHECK( false,