Binary file added BinaryBackwardKernel分析与优化.pdf
Binary file not shown.
257 changes: 203 additions & 54 deletions infini_train/src/kernels/cuda/elementwise.cu
@@ -32,6 +32,46 @@ __device__ inline int64_t CalcOffset(int64_t idx, int ndim, const int64_t *strid
return offset;
}

// Updated offset computation for b: walks only the effective (non-degenerate) broadcast dimensions
__device__ __forceinline__
int64_t CalcBOffset_lvl0(uint64_t idx,
int eff_ndim_b,
const int64_t* __restrict__ eff_out_strides_b,
const int64_t* __restrict__ eff_b_strides_b,
const int64_t* __restrict__ eff_shape_b) {
int64_t off = 0;
#pragma unroll
for (int j = 0; j < eff_ndim_b; ++j) {
int64_t out_index = (idx / eff_out_strides_b[j]) % eff_shape_b[j];
off += out_index * eff_b_strides_b[j];
}
return off; // if eff_ndim_b == 0, off stays 0
}
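
// Illustrative sketch (hypothetical shapes, not taken from this file): for an output of shape
// {4, 3, 5} with strides {15, 5, 1} and a b of shape {1, 3, 1} with strides {3, 1, 1}, only the
// middle dimension survives as "effective", so eff_ndim_b == 1, eff_shape_b == {3},
// eff_out_strides_b == {5} and eff_b_strides_b == {1}. For linear output index idx == 23:
//   out_index = (23 / 5) % 3 == 1  =>  b offset = 1 * 1 == 1
// i.e. a single div/mod pair instead of a loop over all three output dimensions.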

template <typename T, typename FuncA, typename FuncB>
__global__ void BinaryBackwardKernelNoReduce(
T *output_a, T *output_b, FuncA fn_a, FuncB fn_b,
int ndim, size_t num_elements,
const int64_t *a_strides, const int64_t *a_shape,
const int64_t * /*b_strides*/, const int64_t * /*b_shape*/, const int64_t *out_strides,
const int64_t *eff_out_strides_b, const int64_t *eff_b_strides_b, const int64_t *eff_shape_b, int eff_ndim_b,
const T *grad_output, const T *input_a, const T *input_b)
{
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_elements) return;

// Neither a nor b is broadcast here, so a, b and the output share the same linear layout
const int64_t a_offset = idx;
const int64_t b_offset = idx;

const T a_val = input_a ? input_a[a_offset] : T(0);
const T b_val = input_b ? input_b[b_offset] : T(0);

output_a[a_offset] = Mul<T>(grad_output[idx], fn_a(a_val, b_val));
const T db = common::cuda::Cast<T>(Mul<T>(grad_output[idx], fn_b(a_val, b_val)));
output_b[b_offset] = db;
}

template <typename T, typename Func>
__global__ void BinaryForwardKernel(T *output, Func fn, int ndim, const int64_t *a_strides, const int64_t *a_shape,
const int64_t *b_strides, const int64_t *b_shape, const int64_t *out_strides,
@@ -148,10 +188,13 @@ __global__ void UnaryBackwardKernel(T *output, Func fn, size_t num_elements, siz
// Backward kernel for binary operators
// TODO(lzm): determining and passing b_is_broadcasted from the caller; optimize further
template <typename T, typename FuncA, typename FuncB>
__global__ void BinaryBackwardKernel(T *output_a, T *output_b, FuncA fn_a, FuncB fn_b, int ndim, size_t num_elements,
const int64_t *a_strides, const int64_t *a_shape, const int64_t *b_strides,
const int64_t *b_shape, const int64_t *out_strides, const T *grad_output,
const T *input_a, const T *input_b) {
__global__ void BinaryBackwardKernel_opt(
T *output_a, T *output_b, FuncA fn_a, FuncB fn_b,
int ndim, size_t num_elements,
const int64_t *a_strides, const int64_t *a_shape,
const int64_t * /*b_strides*/, const int64_t * /*b_shape*/, const int64_t *out_strides,
const int64_t *eff_out_strides_b, const int64_t *eff_b_strides_b, const int64_t *eff_shape_b, int eff_ndim_b,
const T *grad_output, const T *input_a, const T *input_b) {
extern __shared__ char shared_memory[];
const int tid = threadIdx.x;
const int warp_id = tid / 32;
@@ -168,8 +211,8 @@ __global__ void BinaryBackwardKernel(T *output_a, T *output_b, FuncA fn_a, FuncB
float grad_val = 0.0f;

if (in_bounds) {
a_offset = CalcOffset(idx, ndim, a_strides, a_shape, out_strides);
b_offset = CalcOffset(idx, ndim, b_strides, b_shape, out_strides);
a_offset = idx;
b_offset = CalcBOffset_lvl0(idx, eff_ndim_b, eff_out_strides_b, eff_b_strides_b, eff_shape_b);
a_val = input_a ? input_a[a_offset] : T(0);
b_val = input_b ? input_b[b_offset] : T(0);
output_a[a_offset] = Mul<T>(grad_output[idx], fn_a(a_val, b_val));
@@ -213,64 +256,94 @@ struct SharedElem {
int64_t offset;
float grad;
};

// NOTE(dcj): Specialized BinaryBackwardKernel for low-precision types (__half / bfloat16)
template <typename T, typename FuncA, typename FuncB>
__global__ void BinaryBackwardKernel(T *output_a, T *output_b, FuncA fn_a, FuncB fn_b, int ndim, size_t num_elements,
size_t b_num_elements, const int64_t *a_strides, const int64_t *a_shape,
const int64_t *b_strides, const int64_t *b_shape, const int64_t *out_strides,
const T *grad_output, const T *input_a, const T *input_b, bool fast_atomics) {
extern __shared__ char shared_memory[];
// Shared memory stores b_offset and grad_val
SharedElem *smem = reinterpret_cast<SharedElem *>(shared_memory);
__global__ void BinaryBackwardKernelNoReduce(T *output_a, T *output_b, FuncA fn_a, FuncB fn_b, int ndim, size_t num_elements,
const int64_t *a_strides, const int64_t *a_shape,
const int64_t * /*b_strides*/, const int64_t * /*b_shape*/, const int64_t *out_strides,
const int64_t *eff_out_strides_b, const int64_t *eff_b_strides_b, const int64_t *eff_shape_b, int eff_ndim_b,
const T *grad_output, const T *input_a, const T *input_b, bool fast_atomics)
{
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_elements) return;

// Neither a nor b is broadcast here, so a, b and the output share the same linear layout
const int64_t a_offset = idx;
const int64_t b_offset = idx;

const T a_val = input_a ? input_a[a_offset] : T(0);
const T b_val = input_b ? input_b[b_offset] : T(0);

output_a[a_offset] = Mul<T>(grad_output[idx], fn_a(a_val, b_val));
const T db = common::cuda::Cast<T>(Mul<T>(grad_output[idx], fn_b(a_val, b_val)));
output_b[b_offset] = db;
}


template <typename T, typename FuncA, typename FuncB>
__global__ void BinaryBackwardKernel_opt( T *output_a, T *output_b, FuncA fn_a, FuncB fn_b,
int ndim, size_t num_elements,
const int64_t *a_strides, const int64_t *a_shape,
const int64_t * /*b_strides*/, const int64_t * /*b_shape*/, const int64_t *out_strides,
const int64_t *eff_out_strides_b, const int64_t *eff_b_strides_b, const int64_t *eff_shape_b, int eff_ndim_b,
const T *grad_output, const T *input_a, const T *input_b, bool fast_atomics) {
extern __shared__ char shared_memory[];
const int tid = threadIdx.x;
const int block_threads = blockDim.x;
const int global_idx = blockIdx.x * blockDim.x + tid;
bool in_bounds = (global_idx < num_elements);
const int warp_id = tid / 32;
const int lane_id = tid % 32;

using WarpReduce = cub::WarpReduce<float>;
WarpReduce::TempStorage *temp_storage = reinterpret_cast<WarpReduce::TempStorage *>(shared_memory);

size_t idx = blockIdx.x * blockDim.x + tid;
bool in_bounds = (idx < num_elements);

// Each thread calculates its own a_offset and b_offset
int64_t a_offset = 0, b_offset = 0;
float grad_val = 0.0f;
T a_val = T(0), b_val = T(0);
float grad_val = 0.0f;

if (in_bounds) {
a_offset = CalcOffset(global_idx, ndim, a_strides, a_shape, out_strides);
b_offset = CalcOffset(global_idx, ndim, b_strides, b_shape, out_strides);

a_offset = idx;
b_offset = CalcBOffset_lvl0(idx, eff_ndim_b, eff_out_strides_b, eff_b_strides_b, eff_shape_b);
a_val = input_a ? input_a[a_offset] : T(0);
b_val = input_b ? input_b[b_offset] : T(0);

// Compute gradient contribution for output_a
output_a[a_offset] = Mul<T>(grad_output[global_idx], fn_a(a_val, b_val));
// Store gradient contribution for output_b in float for accumulation
grad_val = common::cuda::Cast<float>(Mul<T>(grad_output[global_idx], fn_b(a_val, b_val)));
output_a[a_offset] = Mul<T>(grad_output[idx], fn_a(a_val, b_val));
grad_val = common::cuda::Cast<float>(Mul<T>(grad_output[idx], fn_b(a_val, b_val)));
}

// Write each thread's b_offset and grad_val into shared memory
smem[tid].offset = in_bounds ? b_offset : -1;
smem[tid].grad = grad_val;
unsigned active_mask = __ballot_sync(0xFFFFFFFF, in_bounds);
if (!active_mask) {
return;
}

__syncthreads();
int leader = __ffs(active_mask) - 1;
int64_t common_offset = __shfl_sync(active_mask, b_offset, leader);

// Block-level reduction: threads check if they can accumulate
for (int stride = 1; stride < block_threads; stride *= 2) {
__syncthreads();
if (tid % (2 * stride) == 0 && (tid + stride) < block_threads) {
if (smem[tid].offset == smem[tid + stride].offset && smem[tid].offset != -1) {
smem[tid].grad += smem[tid + stride].grad;
smem[tid + stride].offset = -1;
}
// Check if all active threads share common b_offset
bool warp_uniform = true;
for (int i = 0; i < 32; ++i) {
if (!(active_mask & (1u << i))) {
continue;
}
int64_t offset_i = __shfl_sync(active_mask, b_offset, i);
if (offset_i != common_offset) {
warp_uniform = false;
break;
}
}
__syncthreads();

// Write final result back to global memory
if (in_bounds && smem[tid].offset != -1) {
fastAtomicAdd<T, size_t>(output_b, smem[tid].offset, b_num_elements, common::cuda::Cast<T>(smem[tid].grad),
fast_atomics);
if (warp_uniform) {
float reduced = WarpReduce(temp_storage[warp_id]).Sum(grad_val);
if (lane_id == leader) {
// FIXME(lzm): atomicAdd is much slower for bf16 and half compared to float, needs further optimization
atomicAdd(&output_b[common_offset], common::cuda::Cast<T>(reduced));
}
} else if (in_bounds) {
// FIXME(lzm): atomicAdd is much slower for bf16 and half compared to float, needs further optimization
atomicAdd(&output_b[b_offset], common::cuda::Cast<T>(grad_val));
}
}
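
// Note on the fast path above (hypothetical shapes for illustration): when b of shape {C, 1, 1}
// is broadcast over an output of shape {N, C, H, W}, each element of b covers H * W consecutive
// output indices, so with H * W >= 32 nearly every warp sees one common b_offset and issues a
// single atomicAdd after the cub::WarpReduce instead of 32 separate ones.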

// launch unary operator's backward kernel
template <size_t BLOCK_SIZE, typename T, typename Func, typename... Inputs>
void LaunchBackward(Func func, const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &grad_output,
Expand All @@ -287,6 +360,33 @@ void LaunchBackward(Func func, const std::shared_ptr<Tensor> &output, const std:
output, inputs...);
}

// Effective ("squeezed") broadcast metadata for b: keeps only the dimensions that actually
// contribute to b's offset.
struct BEffMeta {
std::vector<int64_t> shape; // keeps only dims where b_shape[i] > 1 and out_shape[i] > 1
std::vector<int64_t> out_strides; // matching output strides
std::vector<int64_t> b_strides; // matching strides of b
bool b_broadcast = false;
};

static inline BEffMeta BuildBEffMeta(const std::vector<int64_t>& out_shape,
const std::vector<int64_t>& out_strides,
const std::vector<int64_t>& b_shape,
const std::vector<int64_t>& b_strides) {
BEffMeta m;
const int ndim = static_cast<int>(out_shape.size());
m.shape.reserve(ndim); m.out_strides.reserve(ndim); m.b_strides.reserve(ndim);

for (int i = 0; i < ndim; ++i) {
if (out_shape[i] > 1 && b_shape[i] == 1) m.b_broadcast = true; // b is broadcast along this dim
if (out_shape[i] > 1 && b_shape[i] > 1) { // this dim is effective for b
m.shape.push_back(b_shape[i]);
m.out_strides.push_back(out_strides[i]);
m.b_strides.push_back(b_strides[i]);
}
}
return m;
}
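
// Minimal host-side usage sketch of BuildBEffMeta (the shapes below are hypothetical examples):
//   std::vector<int64_t> out_shape   = {4, 3, 5};
//   std::vector<int64_t> out_strides = {15, 5, 1};
//   std::vector<int64_t> b_shape     = {1, 3, 1};
//   std::vector<int64_t> b_strides   = {3, 1, 1};
//   BEffMeta m = BuildBEffMeta(out_shape, out_strides, b_shape, b_strides);
//   // m.shape == {3}, m.out_strides == {5}, m.b_strides == {1}, m.b_broadcast == true,
//   // i.e. the kernel needs only eff_ndim_b == 1 div/mod per element and knows it must
//   // accumulate (atomicAdd) into output_b.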

// launch binary operator's backward kernel
template <size_t BLOCK_SIZE, typename T, typename FuncA, typename FuncB, typename... Inputs>
void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr<Tensor> &output_a,
@@ -329,28 +429,77 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr<Tensor> &out
host_buffer.insert(host_buffer.end(), b_shape.begin(), b_shape.end());

cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, stream);

auto b_eff = BuildBEffMeta(out_shape, out_stride_host, b_shape, b_stride_host);
int eff_ndim_b = static_cast<int>(b_eff.shape.size());
const bool b_eff_broadcast = b_eff.b_broadcast;

int64_t *d_b_eff_shape = nullptr, *d_b_eff_out_strides = nullptr, *d_b_eff_b_strides = nullptr;
if (eff_ndim_b > 0) {
cudaMallocAsync(&d_b_eff_shape, eff_ndim_b * sizeof(int64_t), stream);
cudaMallocAsync(&d_b_eff_out_strides, eff_ndim_b * sizeof(int64_t), stream);
cudaMallocAsync(&d_b_eff_b_strides, eff_ndim_b * sizeof(int64_t), stream);

cudaMemcpyAsync(d_b_eff_shape, b_eff.shape.data(), eff_ndim_b*sizeof(int64_t), cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync(d_b_eff_out_strides, b_eff.out_strides.data(), eff_ndim_b*sizeof(int64_t), cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync(d_b_eff_b_strides, b_eff.b_strides.data(), eff_ndim_b*sizeof(int64_t), cudaMemcpyHostToDevice, stream);
}

const size_t num_elements = grad_output->NumElements();

if constexpr (std::is_same_v<T, float>) {
LaunchKernel<BLOCK_SIZE, T>(
[=](dim3 grid, dim3 block, size_t offset, auto... ptrs) {
const int NUM_WARPS = BLOCK_SIZE / 32;
size_t smem_size = NUM_WARPS * sizeof(cub::WarpReduce<float>::TempStorage);
BinaryBackwardKernel<<<grid, block, smem_size, stream>>>(
output_a_ptr, output_b_ptr, fun_a, fun_b, ndim, num_elements, device_a_strides, device_a_shape,
device_b_strides, device_b_shape, device_out_strides, grad_output_ptr, ptrs...);
if (b_eff_broadcast) {
// --- b is broadcast: needs reduction + atomicAdd ---
BinaryBackwardKernel_opt<<<grid, block, smem_size, stream>>>(
output_a_ptr, output_b_ptr, fun_a, fun_b,
ndim, num_elements,
device_a_strides, device_a_shape,
device_b_strides, device_b_shape,
device_out_strides,
d_b_eff_out_strides, d_b_eff_b_strides, d_b_eff_shape, eff_ndim_b,
grad_output_ptr, ptrs...);
} else {
// --- b is not broadcast: write back directly ---
BinaryBackwardKernelNoReduce<<<grid, block, 0, stream>>>(
output_a_ptr, output_b_ptr, fun_a, fun_b,
ndim, num_elements,
device_a_strides, device_a_shape,
device_b_strides, device_b_shape,
device_out_strides,
d_b_eff_out_strides, d_b_eff_b_strides, d_b_eff_shape, eff_ndim_b,
grad_output_ptr, ptrs...);
}
},
output_a, inputs...);
} else if constexpr (std::is_same_v<T, __half> || std::is_same_v<T, __nv_bfloat16>) {
LaunchKernel<BLOCK_SIZE, T>(
[=](dim3 grid, dim3 block, size_t offset, auto... ptrs) {
size_t smem_size = BLOCK_SIZE * sizeof(SharedElem);
BinaryBackwardKernel<<<grid, block, smem_size, stream>>>(
output_a_ptr, output_b_ptr, fun_a, fun_b, ndim, num_elements, output_b->NumElements(),
device_a_strides, device_a_shape, device_b_strides, device_b_shape, device_out_strides,
grad_output_ptr, ptrs...,
/*fast_atomics=*/true);
if (b_eff_broadcast) {
// --- b is broadcast: needs reduction + atomicAdd ---
BinaryBackwardKernel_opt<<<grid, block, smem_size, stream>>>(
output_a_ptr, output_b_ptr, fun_a, fun_b,
ndim, num_elements,
device_a_strides, device_a_shape,
device_b_strides, device_b_shape,
device_out_strides,
d_b_eff_out_strides, d_b_eff_b_strides, d_b_eff_shape, eff_ndim_b,
grad_output_ptr, ptrs...,
/*fast_atomics=*/true);
} else {
// --- b is not broadcast: write back directly ---
BinaryBackwardKernelNoReduce<<<grid, block, 0, stream>>>(
output_a_ptr, output_b_ptr, fun_a, fun_b,
ndim, num_elements,
device_a_strides, device_a_shape,
device_b_strides, device_b_shape,
device_out_strides,
d_b_eff_out_strides, d_b_eff_b_strides, d_b_eff_shape, eff_ndim_b,
grad_output_ptr, ptrs...,
/*fast_atomics=*/true);
}
},
output_a, inputs...);
}
@@ -843,4 +992,4 @@ REGISTER_CUDA_ELEMENTWISE_KERNEL(DivBackward)
REGISTER_CUDA_ELEMENTWISE_KERNEL(SigmoidForward)
REGISTER_CUDA_ELEMENTWISE_KERNEL(SigmoidBackward)

#undef REGISTER_CUDA_ELEMENTWISE_KERNEL