diff --git "a/BinaryBackwardKernel\345\210\206\346\236\220\344\270\216\344\274\230\345\214\226.pdf" "b/BinaryBackwardKernel\345\210\206\346\236\220\344\270\216\344\274\230\345\214\226.pdf" new file mode 100644 index 00000000..adbce5ad Binary files /dev/null and "b/BinaryBackwardKernel\345\210\206\346\236\220\344\270\216\344\274\230\345\214\226.pdf" differ diff --git a/infini_train/src/kernels/cuda/elementwise.cu b/infini_train/src/kernels/cuda/elementwise.cu index 88e51659..9d380626 100644 --- a/infini_train/src/kernels/cuda/elementwise.cu +++ b/infini_train/src/kernels/cuda/elementwise.cu @@ -32,6 +32,46 @@ __device__ inline int64_t CalcOffset(int64_t idx, int ndim, const int64_t *strid return offset; } +//更新版的offset计算 +__device__ __forceinline__ +int64_t CalcBOffset_lvl0(uint64_t idx, + int eff_ndim_b, + const int64_t* __restrict__ eff_out_strides_b, + const int64_t* __restrict__ eff_b_strides_b, + const int64_t* __restrict__ eff_shape_b) { + int64_t off = 0; + #pragma unroll + for (int j = 0; j < eff_ndim_b; ++j) { + int64_t out_index = (idx / eff_out_strides_b[j]) % eff_shape_b[j]; + off += out_index * eff_b_strides_b[j]; + } + return off; // 若 eff_ndim_b==0,off 恒 0 +} + +template +__global__ void BinaryBackwardKernelNoReduce( + T *output_a, T *output_b, FuncA fn_a, FuncB fn_b, + int ndim, size_t num_elements, + const int64_t *a_strides, const int64_t *a_shape, + const int64_t * /*b_strides*/, const int64_t * /*b_shape*/, const int64_t *out_strides, + const int64_t *eff_out_strides_b, const int64_t *eff_b_strides_b, const int64_t *eff_shape_b, int eff_ndim_b, + const T *grad_output, const T *input_a, const T *input_b) +{ + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= num_elements) return; + + //考虑到a和b都没有进行广播,那么a和b和out_put就都相同 + const int64_t a_offset = idx; + const int64_t b_offset = idx; + + const T a_val = input_a ? input_a[a_offset] : T(0); + const T b_val = input_b ? 
+
+    output_a[a_offset] = Mul(grad_output[idx], fn_a(a_val, b_val));
+    const T db = common::cuda::Cast<T>(Mul(grad_output[idx], fn_b(a_val, b_val)));
+    output_b[b_offset] = db;
+}
+
 template <typename T, typename Func>
 __global__ void BinaryForwardKernel(T *output, Func fn, int ndim, const int64_t *a_strides, const int64_t *a_shape,
                                     const int64_t *b_strides, const int64_t *b_shape, const int64_t *out_strides,
@@ -148,10 +188,13 @@ __global__ void UnaryBackwardKernel(T *output, Func fn, size_t num_elements, siz
 // Backward kernel for binary operators
 // TODO(lzm): determining and passing b_is_broadcasted from the caller; optimize further
 template <typename T, typename FuncA, typename FuncB>
-__global__ void BinaryBackwardKernel(T *output_a, T *output_b, FuncA fn_a, FuncB fn_b, int ndim, size_t num_elements,
-                                     const int64_t *a_strides, const int64_t *a_shape, const int64_t *b_strides,
-                                     const int64_t *b_shape, const int64_t *out_strides, const T *grad_output,
-                                     const T *input_a, const T *input_b) {
+__global__ void BinaryBackwardKernel_opt(
+    T *output_a, T *output_b, FuncA fn_a, FuncB fn_b,
+    int ndim, size_t num_elements,
+    const int64_t *a_strides, const int64_t *a_shape,
+    const int64_t * /*b_strides*/, const int64_t * /*b_shape*/, const int64_t *out_strides,
+    const int64_t *eff_out_strides_b, const int64_t *eff_b_strides_b, const int64_t *eff_shape_b, int eff_ndim_b,
+    const T *grad_output, const T *input_a, const T *input_b) {
     extern __shared__ char shared_memory[];
     const int tid = threadIdx.x;
     const int warp_id = tid / 32;
@@ -168,8 +211,8 @@
     float grad_val = 0.0f;
 
     if (in_bounds) {
-        a_offset = CalcOffset(idx, ndim, a_strides, a_shape, out_strides);
-        b_offset = CalcOffset(idx, ndim, b_strides, b_shape, out_strides);
+        a_offset = idx;
+        b_offset = CalcBOffset_lvl0(idx, eff_ndim_b, eff_out_strides_b, eff_b_strides_b, eff_shape_b);
         a_val = input_a ? input_a[a_offset] : T(0);
         b_val = input_b ? input_b[b_offset] : T(0);
         output_a[a_offset] = Mul(grad_output[idx], fn_a(a_val, b_val));
@@ -213,64 +256,95 @@ struct SharedElem {
     int64_t offset;
     float grad;
 };
+
 // NOTE(dcj): Specialized BinaryBackwardKernel for low-precision types (__half / bfloat16)
 template <typename T, typename FuncA, typename FuncB>
-__global__ void BinaryBackwardKernel(T *output_a, T *output_b, FuncA fn_a, FuncB fn_b, int ndim, size_t num_elements,
-                                     size_t b_num_elements, const int64_t *a_strides, const int64_t *a_shape,
-                                     const int64_t *b_strides, const int64_t *b_shape, const int64_t *out_strides,
-                                     const T *grad_output, const T *input_a, const T *input_b, bool fast_atomics) {
-    extern __shared__ char shared_memory[];
-    // Shared memory stores b_offset and grad_val
-    SharedElem *smem = reinterpret_cast<SharedElem *>(shared_memory);
+__global__ void BinaryBackwardKernelNoReduce(
+    T *output_a, T *output_b, FuncA fn_a, FuncB fn_b,
+    int ndim, size_t num_elements,
+    const int64_t *a_strides, const int64_t *a_shape,
+    const int64_t * /*b_strides*/, const int64_t * /*b_shape*/, const int64_t *out_strides,
+    const int64_t *eff_out_strides_b, const int64_t *eff_b_strides_b, const int64_t *eff_shape_b, int eff_ndim_b,
+    const T *grad_output, const T *input_a, const T *input_b, bool /*fast_atomics*/)
+{
+    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= num_elements) return;
+
+    // Neither a nor b is broadcast here, so a, b, and the output all share the same flat layout.
+    const int64_t a_offset = idx;
+    const int64_t b_offset = idx;
+
+    const T a_val = input_a ? input_a[a_offset] : T(0);
+    const T b_val = input_b ? input_b[b_offset] : T(0);
+
+    output_a[a_offset] = Mul(grad_output[idx], fn_a(a_val, b_val));
+    const T db = common::cuda::Cast<T>(Mul(grad_output[idx], fn_b(a_val, b_val)));
+    output_b[b_offset] = db;
+}
+
+template <typename T, typename FuncA, typename FuncB>
+__global__ void BinaryBackwardKernel_opt(
+    T *output_a, T *output_b, FuncA fn_a, FuncB fn_b,
+    int ndim, size_t num_elements,
+    const int64_t *a_strides, const int64_t *a_shape,
+    const int64_t * /*b_strides*/, const int64_t * /*b_shape*/, const int64_t *out_strides,
+    const int64_t *eff_out_strides_b, const int64_t *eff_b_strides_b, const int64_t *eff_shape_b, int eff_ndim_b,
+    const T *grad_output, const T *input_a, const T *input_b, bool /*fast_atomics*/) { // fast_atomics kept for call-site symmetry; plain atomicAdd is used below
+    extern __shared__ char shared_memory[];
     const int tid = threadIdx.x;
-    const int block_threads = blockDim.x;
-    const int global_idx = blockIdx.x * blockDim.x + tid;
-    bool in_bounds = (global_idx < num_elements);
+    const int warp_id = tid / 32;
+    const int lane_id = tid % 32;
+
+    using WarpReduce = cub::WarpReduce<float>;
+    WarpReduce::TempStorage *temp_storage = reinterpret_cast<WarpReduce::TempStorage *>(shared_memory);
+
+    size_t idx = blockIdx.x * blockDim.x + tid;
+    bool in_bounds = (idx < num_elements);
 
-    // Each thread calculates its own a_offset and b_offset
     int64_t a_offset = 0, b_offset = 0;
-    float grad_val = 0.0f;
     T a_val = T(0), b_val = T(0);
+    float grad_val = 0.0f;
 
     if (in_bounds) {
-        a_offset = CalcOffset(global_idx, ndim, a_strides, a_shape, out_strides);
-        b_offset = CalcOffset(global_idx, ndim, b_strides, b_shape, out_strides);
-
+        a_offset = idx;
+        b_offset = CalcBOffset_lvl0(idx, eff_ndim_b, eff_out_strides_b, eff_b_strides_b, eff_shape_b);
         a_val = input_a ? input_a[a_offset] : T(0);
         b_val = input_b ? input_b[b_offset] : T(0);
-
-        // Compute gradient contribution for output_a
-        output_a[a_offset] = Mul(grad_output[global_idx], fn_a(a_val, b_val));
-        // Store gradient contribution for output_b in float for accumulation
-        grad_val = common::cuda::Cast<float>(Mul(grad_output[global_idx], fn_b(a_val, b_val)));
+        output_a[a_offset] = Mul(grad_output[idx], fn_a(a_val, b_val));
+        grad_val = common::cuda::Cast<float>(Mul(grad_output[idx], fn_b(a_val, b_val)));
     }
 
-    // Write each thread's b_offset and grad_val into shared memory
-    smem[tid].offset = in_bounds ? b_offset : -1;
-    smem[tid].grad = grad_val;
+    unsigned active_mask = __ballot_sync(0xFFFFFFFF, in_bounds);
+    if (!active_mask) {
+        return;
+    }
 
-    __syncthreads();
+    int leader = __ffs(active_mask) - 1;
+    int64_t common_offset = __shfl_sync(0xFFFFFFFF, b_offset, leader); // full mask: every lane of the warp is still alive here
 
-    // Block-level reduction: threads check if they can accumulate
-    for (int stride = 1; stride < block_threads; stride *= 2) {
-        __syncthreads();
-        if (tid % (2 * stride) == 0 && (tid + stride) < block_threads) {
-            if (smem[tid].offset == smem[tid + stride].offset && smem[tid].offset != -1) {
-                smem[tid].grad += smem[tid + stride].grad;
-                smem[tid + stride].offset = -1;
-            }
+    // Check whether all in-bounds threads of the warp share a common b_offset
+    bool warp_uniform = true;
+    for (int i = 0; i < 32; ++i) {
+        if (!(active_mask & (1u << i))) {
+            continue;
+        }
+        int64_t offset_i = __shfl_sync(0xFFFFFFFF, b_offset, i);
+        if (offset_i != common_offset) {
+            warp_uniform = false;
+            break;
         }
     }
-    __syncthreads();
 
-    // Write final result back to global memory
-    if (in_bounds && smem[tid].offset != -1) {
-        fastAtomicAdd(output_b, smem[tid].offset, b_num_elements, common::cuda::Cast<T>(smem[tid].grad),
-                      fast_atomics);
+    if (warp_uniform) {
+        float reduced = WarpReduce(temp_storage[warp_id]).Sum(grad_val);
+        if (lane_id == leader) {
+            // FIXME(lzm): atomicAdd is much slower for bf16 and half compared to float, needs further optimization
+            atomicAdd(&output_b[common_offset], common::cuda::Cast<T>(reduced));
+        }
+    } else if (in_bounds) {
+        // FIXME(lzm): atomicAdd is much slower for bf16 and half compared to float, needs further optimization
+        atomicAdd(&output_b[b_offset], common::cuda::Cast<T>(grad_val));
     }
 }
-
 // launch unary operator's backward kernel
 template <typename T, typename Func, typename... Inputs>
 void LaunchBackward(Func func, const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &grad_output,
@@ -287,6 +360,33 @@ void LaunchBackward(Func func, const std::shared_ptr<Tensor> &output, const std:
                  output, inputs...);
 }
 
+// Effective-dimension metadata for b: only the dims along which b actually varies in the output
+struct BEffMeta {
+    std::vector<int64_t> shape;       // keep only dims where b_shape[i] > 1 and out_shape[i] > 1
+    std::vector<int64_t> out_strides; // the corresponding output strides
+    std::vector<int64_t> b_strides;   // the corresponding strides of b
+    bool b_broadcast = false;
+};
+
+static inline BEffMeta BuildBEffMeta(const std::vector<int64_t> &out_shape,
+                                     const std::vector<int64_t> &out_strides,
+                                     const std::vector<int64_t> &b_shape,
+                                     const std::vector<int64_t> &b_strides) {
+    BEffMeta m;
+    const int ndim = static_cast<int>(out_shape.size());
+    m.shape.reserve(ndim); m.out_strides.reserve(ndim); m.b_strides.reserve(ndim);
+
+    for (int i = 0; i < ndim; ++i) {
+        if (out_shape[i] > 1 && b_shape[i] == 1) m.b_broadcast = true; // b is broadcast along this dim
+        if (out_shape[i] > 1 && b_shape[i] > 1) { // this dim is effective for b
+            m.shape.push_back(b_shape[i]);
+            m.out_strides.push_back(out_strides[i]);
+            m.b_strides.push_back(b_strides[i]);
+        }
+    }
+    return m;
+}
+
 // launch binary operator's backward kernel
 template <typename T, typename FuncA, typename FuncB, typename... Inputs>
 void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr<Tensor> &output_a,
@@ -329,7 +429,21 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr<Tensor> &out
     host_buffer.insert(host_buffer.end(), b_shape.begin(), b_shape.end());
 
     cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
-
+    auto b_eff = BuildBEffMeta(out_shape, out_stride_host, b_shape, b_stride_host);
+    int eff_ndim_b = static_cast<int>(b_eff.shape.size());
+    const bool b_eff_broadcast = b_eff.b_broadcast; // captured by the launch lambdas below
+
+    int64_t *d_b_eff_shape = nullptr, *d_b_eff_out_strides = nullptr, *d_b_eff_b_strides = nullptr;
+    if (eff_ndim_b > 0) {
+        cudaMallocAsync(&d_b_eff_shape, eff_ndim_b * sizeof(int64_t), stream);
+        cudaMallocAsync(&d_b_eff_out_strides, eff_ndim_b * sizeof(int64_t), stream);
+        cudaMallocAsync(&d_b_eff_b_strides, eff_ndim_b * sizeof(int64_t), stream);
+
+        cudaMemcpyAsync(d_b_eff_shape, b_eff.shape.data(), eff_ndim_b * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
+        cudaMemcpyAsync(d_b_eff_out_strides, b_eff.out_strides.data(), eff_ndim_b * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
+        cudaMemcpyAsync(d_b_eff_b_strides, b_eff.b_strides.data(), eff_ndim_b * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
+    }
+
     const size_t num_elements = grad_output->NumElements();
 
     if constexpr (std::is_same_v<T, float>) {
@@ -337,20 +450,56 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr<Tensor> &out
             [=](dim3 grid, dim3 block, size_t offset, auto... ptrs) {
                 const int NUM_WARPS = BLOCK_SIZE / 32;
                 size_t smem_size = NUM_WARPS * sizeof(cub::WarpReduce<float>::TempStorage);
-                BinaryBackwardKernel<<<grid, block, smem_size, stream>>>(
-                    output_a_ptr, output_b_ptr, fun_a, fun_b, ndim, num_elements, device_a_strides, device_a_shape,
-                    device_b_strides, device_b_shape, device_out_strides, grad_output_ptr, ptrs...);
+                if (b_eff_broadcast) {
+                    // --- b is broadcast: warp-level reduction + atomicAdd ---
+                    BinaryBackwardKernel_opt<<<grid, block, smem_size, stream>>>(
+                        output_a_ptr, output_b_ptr, fun_a, fun_b,
+                        ndim, num_elements,
+                        device_a_strides, device_a_shape,
+                        device_b_strides, device_b_shape,
+                        device_out_strides,
+                        d_b_eff_out_strides, d_b_eff_b_strides, d_b_eff_shape, eff_ndim_b,
+                        grad_output_ptr, ptrs...);
+                } else {
+                    // --- b is not broadcast: write back directly ---
+                    BinaryBackwardKernelNoReduce<<<grid, block, 0, stream>>>(
+                        output_a_ptr, output_b_ptr, fun_a, fun_b,
+                        ndim, num_elements,
+                        device_a_strides, device_a_shape,
+                        device_b_strides, device_b_shape,
+                        device_out_strides,
+                        d_b_eff_out_strides, d_b_eff_b_strides, d_b_eff_shape, eff_ndim_b,
+                        grad_output_ptr, ptrs...);
+                }
             },
             output_a, inputs...);
     } else if constexpr (std::is_same_v<T, nv_bfloat16> || std::is_same_v<T, half>) {
         LaunchKernel(
             [=](dim3 grid, dim3 block, size_t offset, auto... ptrs) {
                 size_t smem_size = BLOCK_SIZE * sizeof(SharedElem);
-                BinaryBackwardKernel<<<grid, block, smem_size, stream>>>(
-                    output_a_ptr, output_b_ptr, fun_a, fun_b, ndim, num_elements, output_b->NumElements(),
-                    device_a_strides, device_a_shape, device_b_strides, device_b_shape, device_out_strides,
-                    grad_output_ptr, ptrs...,
-                    /*fast_atomics=*/true);
+                if (b_eff_broadcast) {
+                    // --- b is broadcast: warp-level reduction + atomicAdd ---
+                    BinaryBackwardKernel_opt<<<grid, block, smem_size, stream>>>(
+                        output_a_ptr, output_b_ptr, fun_a, fun_b,
+                        ndim, num_elements,
+                        device_a_strides, device_a_shape,
+                        device_b_strides, device_b_shape,
+                        device_out_strides,
+                        d_b_eff_out_strides, d_b_eff_b_strides, d_b_eff_shape, eff_ndim_b,
+                        grad_output_ptr, ptrs...,
+                        /*fast_atomics=*/true);
+                } else {
+                    // --- b is not broadcast: write back directly ---
+                    BinaryBackwardKernelNoReduce<<<grid, block, 0, stream>>>(
+                        output_a_ptr, output_b_ptr, fun_a, fun_b,
+                        ndim, num_elements,
+                        device_a_strides, device_a_shape,
+                        device_b_strides, device_b_shape,
+                        device_out_strides,
+                        d_b_eff_out_strides, d_b_eff_b_strides, d_b_eff_shape, eff_ndim_b,
+                        grad_output_ptr, ptrs...,
+                        /*fast_atomics=*/true);
+                }
             },
             output_a, inputs...);
     }
@@ -843,4 +992,4 @@ REGISTER_CUDA_ELEMENTWISE_KERNEL(DivBackward)
 
 REGISTER_CUDA_ELEMENTWISE_KERNEL(SigmoidForward)
 REGISTER_CUDA_ELEMENTWISE_KERNEL(SigmoidBackward)
-#undef REGISTER_CUDA_ELEMENTWISE_KERNEL
+#undef REGISTER_CUDA_ELEMENTWISE_KERNEL
\ No newline at end of file
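
Reviewer note, not part of the patch: the effective-dimension index math introduced above can be sanity-checked on the host without CUDA. The sketch below is a hypothetical harness (the shapes, strides, and the check itself are illustrative only); it mirrors the logic of BuildBEffMeta and CalcBOffset_lvl0 for out_shape = {2, 3, 4} against b_shape = {1, 3, 1}, and verifies that the collapsed computation matches a naive walk over all dimensions.

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    int main() {
        // Row-major output {2,3,4}; b is {1,3,1}, i.e. broadcast along dims 0 and 2.
        const std::vector<int64_t> out_shape{2, 3, 4}, out_strides{12, 4, 1};
        const std::vector<int64_t> b_shape{1, 3, 1}, b_strides{3, 1, 1};

        // Mirror of BuildBEffMeta: keep only dims where both shapes exceed 1.
        std::vector<int64_t> eff_shape, eff_out_strides, eff_b_strides;
        for (std::size_t i = 0; i < out_shape.size(); ++i) {
            if (out_shape[i] > 1 && b_shape[i] > 1) {
                eff_shape.push_back(b_shape[i]);
                eff_out_strides.push_back(out_strides[i]);
                eff_b_strides.push_back(b_strides[i]);
            }
        }

        for (int64_t idx = 0; idx < 24; ++idx) {
            // Mirror of CalcBOffset_lvl0: walk the effective dims only.
            int64_t eff_off = 0;
            for (std::size_t j = 0; j < eff_shape.size(); ++j) {
                eff_off += ((idx / eff_out_strides[j]) % eff_shape[j]) * eff_b_strides[j];
            }
            // Naive reference over all dims; "% b_shape[i]" pins broadcast dims to index 0.
            int64_t ref_off = 0;
            for (std::size_t i = 0; i < out_shape.size(); ++i) {
                ref_off += ((idx / out_strides[i]) % b_shape[i]) * b_strides[i];
            }
            assert(eff_off == ref_off);
        }
        return 0;
    }

Only one dimension survives the collapse in this example (dim 1, output stride 4), so computing b's offset drops from a three-term divide/modulo loop to a single term, which is the point of the eff_* arrays.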