-
Notifications
You must be signed in to change notification settings - Fork 540
[Common] Split cast/gated kernels by scaling mode #2248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from 26 commits
36533fe
d7d8e43
30484be
1bac1f8
40cc697
bcdbee0
67e32c3
a72d53a
3e5edd2
bcbdf8a
3dea6bd
4a05df0
2433529
3526e54
6c0ed53
aa5ec2e
bcd55f6
7158b9d
0c4314e
41416b3
91ea154
e14ae55
f7225e9
37747dd
9afdba1
b764dea
703556c
ec58510
b84301b
ebcfafe
80df2fd
deb012b
6142ff7
25e9b48
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| /*! \file common.cuh | ||
| * \brief Common functions in quantize. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. style: file comment says "in quantize" but this lives in |
||
| */ | ||
|
|
||
| #ifndef TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. syntax: header guard mismatch – file is |
||
| #define TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_ | ||
|
|
||
| #include <cuda.h> | ||
| #include <cudaTypedefs.h> | ||
| #include <cuda_runtime.h> | ||
| #include <transformer_engine/transformer_engine.h> | ||
|
|
||
| #include "../../common.h" | ||
| #include "../../utils.cuh" | ||
|
|
||
| namespace transformer_engine { | ||
| namespace dispatch { | ||
| namespace common { | ||
| inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) { | ||
| const size_t N = product(t->data.shape); | ||
| const bool isFullTile = (N % elems_per_block == 0); | ||
| return isFullTile; | ||
| } | ||
|
|
||
| inline bool dimensions_supported_by_TMA(const Tensor *const t) { | ||
| const size_t cols = t->flat_last_dim(); | ||
| constexpr size_t TMA_bytes = 16; | ||
| const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); | ||
| return cols % alignment_requirement == 0; | ||
| } | ||
|
|
||
| namespace kernel { | ||
|
|
||
| constexpr size_t THREADS_PER_BLOCK = 256; | ||
| template <int nvec, typename OType> | ||
| __global__ void __launch_bounds__(THREADS_PER_BLOCK) | ||
| reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, | ||
| const size_t rows, const size_t cols) { | ||
| using ComputeVec = Vec<float, nvec>; | ||
| using OutputVec = Vec<OType, nvec>; | ||
|
|
||
| const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; | ||
|
|
||
| if (thread_id * nvec >= cols) { | ||
| return; | ||
| } | ||
|
|
||
| const float *const thread_in_base = dbias_partial + thread_id * nvec; | ||
| OType *const thread_out_base = dbias_output + thread_id * nvec; | ||
|
|
||
| ComputeVec ldg_vec; | ||
| ComputeVec acc_vec; | ||
| acc_vec.clear(); | ||
| for (int i = 0; i < rows; ++i) { | ||
| ldg_vec.load_from(thread_in_base + i * cols); | ||
| #pragma unroll | ||
| for (int e = 0; e < nvec; ++e) { | ||
| acc_vec.data.elt[e] += ldg_vec.data.elt[e]; | ||
| } | ||
| } | ||
|
|
||
| OutputVec stg_vec; | ||
| #pragma unroll | ||
| for (int e = 0; e < nvec; ++e) { | ||
| stg_vec.data.elt[e] = static_cast<OType>(acc_vec.data.elt[e]); | ||
| } | ||
| stg_vec.store_to(thread_out_base); | ||
| } | ||
| } // namespace kernel | ||
|
|
||
| template <typename IType> | ||
| void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, | ||
|
Comment on lines
+77
to
+78
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. style: template parameter |
||
| cudaStream_t stream) { | ||
| using namespace kernel; | ||
| constexpr size_t reduce_dbias_store_bytes = 8; // stg.64 | ||
| constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); | ||
|
|
||
| NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); | ||
| const size_t reduce_dbias_num_blocks = DIVUP(cols, THREADS_PER_BLOCK * reduce_dbias_nvec); | ||
|
|
||
| reduce_dbias_kernel<reduce_dbias_nvec, IType> | ||
| <<<reduce_dbias_num_blocks, THREADS_PER_BLOCK, 0, stream>>>( | ||
| reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, rows, cols); | ||
| NVTE_CHECK_CUDA(cudaGetLastError()); | ||
| } | ||
|
|
||
| } // namespace common | ||
| } // namespace dispatch | ||
| } // namespace transformer_engine | ||
|
|
||
| #endif // TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| /*! \file dequantize.cuh | ||
| * \brief Dequantize dispatcher. | ||
| */ | ||
|
|
||
| #ifndef TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_ | ||
| #define TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_ | ||
|
|
||
| #include <transformer_engine/transformer_engine.h> | ||
|
|
||
| #include "../../common.h" | ||
| #include "../fp8/dequantize_fp8.cuh" | ||
| #include "../mxfp8/dequantize_mxfp8.cuh" | ||
| #include "../nvfp4/dequantize_nvfp4.cuh" | ||
|
|
||
| namespace transformer_engine { | ||
| namespace dispatch { | ||
|
|
||
| inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { | ||
| CheckInputTensor(input, "cast_input"); | ||
| CheckOutputTensor(*output, "cast_output"); | ||
|
|
||
| switch (input.scaling_mode) { | ||
| case NVTE_DELAYED_TENSOR_SCALING: { | ||
| NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); | ||
| NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); | ||
| NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); | ||
| fp8::dequantize(input, output, stream); | ||
Oleg-Goncharov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| break; | ||
| } | ||
| case NVTE_MXFP8_1D_SCALING: { | ||
| if (is_supported_by_CC_100()) { | ||
| mxfp8::dequantize(input, output, stream); | ||
| } else { | ||
| NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); | ||
| } | ||
| break; | ||
| } | ||
| case NVTE_NVFP4_1D_SCALING: { | ||
| nvfp4::dequantize(input, output, stream); | ||
| break; | ||
| } | ||
| default: | ||
| NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); | ||
| } | ||
| } | ||
|
|
||
| } // namespace dispatch | ||
| } // namespace transformer_engine | ||
|
|
||
| #endif // TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_ | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: indentation of the continuation line changed from 4 to 1 space, which deviates from the project's `.clang-format` setting (`ContinuationIndentWidth: 4`)