From d7a1d637fd92a87cf7c647ddc10e14c1df0d79b9 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sat, 5 Apr 2025 10:02:56 -0700 Subject: [PATCH 01/10] add moe_kernel_utils (cdiv etc) --- .../kernels/moe_sorting/moe_kernel_utils.h | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 torchtitan/experiments/kernels/moe_sorting/moe_kernel_utils.h diff --git a/torchtitan/experiments/kernels/moe_sorting/moe_kernel_utils.h b/torchtitan/experiments/kernels/moe_sorting/moe_kernel_utils.h new file mode 100644 index 000000000..0c7bf30ce --- /dev/null +++ b/torchtitan/experiments/kernels/moe_sorting/moe_kernel_utils.h @@ -0,0 +1,72 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD 3-Clause license found in the +// LICENSE file in the root directory of this source tree. + +/* + * Cuda kernel utils file for MoE related kernels + * basically let's not reinvent the wheel for core functions... + * ====================== + * cdiv + * grid_1d + * grid_2d + * calc_shared_memory_size + * ======================= + + */ + +#pragma once + +#include +#include + +namespace moe_kernel_utils { +/** + * cdiv - Ceiling division - grid and block size calc support + * + * @param numerator Number of elements to process + * @param denominator Number of elements per thread/block + * @return Ceiling of the division (usually number of blocks needed) + */ +inline int cdiv(int numerator, int denominator) { + return (numerator + denominator - 1) / denominator; +} + +/** + * grid_1d - calculate 1D grid size with upper limit + * + * @param elements Number of elements to process + * @param threads_per_block Number of threads per block + * @param max_blocks Upper limit of blocks (default to 256 for now) + * @return optimal number of blocks for the 1d grid + */ +inline int grid_1d(int elements, int threads_per_block, int max_blocks = 256) { + return std::min(max_blocks, cdiv(elements, threads_per_block)); +} + +/** + * grid_2d - calcuate 2d grid based on input dimensions (x,y) + * @param dim_x 1st dimension size - usually rows + * @param dim_y 2nd dimension (usually features/columns) + * @param block_dim_x Number of threads per block in x dimension + * @param block_dim_y Number of threads per block in y dimension + * @return dim3 with grid dimensions + */ +inline dim3 grid_2d(int dim_x, int dim_y, int block_dim_x, int block_dim_y) { + return dim3(cdiv(dim_x, block_dim_x), cdiv(dim_y, block_dim_y)); +} + +/** +* calc_shared_memory_size - calculate shared memory size needed for given type +and count +* +* @param T Type to allocate for +* @param count Num elements +* @return Size in bytes for shared memory allocation + + */ +template inline size_t calc_shared_memory_size(int count) { + return count * sizeof(T); +} +} // namespace moe_kernel_utils From 30d215b60d2b4271318b1ecf78ea6bd611280aef Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sat, 5 Apr 2025 11:58:25 -0700 Subject: [PATCH 02/10] start sorting kernels --- .../moe_sorting/token_sorting_kernels.cu | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 torchtitan/experiments/kernels/moe_sorting/token_sorting_kernels.cu diff --git a/torchtitan/experiments/kernels/moe_sorting/token_sorting_kernels.cu b/torchtitan/experiments/kernels/moe_sorting/token_sorting_kernels.cu new file mode 100644 index 000000000..f8475754b --- /dev/null +++ b/torchtitan/experiments/kernels/moe_sorting/token_sorting_kernels.cu @@ -0,0 +1,34 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD 3-Clause license found in the +// LICENSE file in the root directory of this source tree. + +/* + * Token sorting kernels + * sequential and parallel scans + */ + +#include +#include +#include +#include + +#include "moe_kernel_utils.h" + +// our utility namespace +using namespace moe_kernel_utils; + +// +// kernels for sorting tokens by expert assignment +// + +__global__ void sort_tokens_by_expert_kernel( + + ) + + // gather kernel - move tokens to sorted indices + template