[common] Split cast/gated kernels by scaling mode #2248
Open: Oleg-Goncharov wants to merge 14 commits into NVIDIA:main from Oleg-Goncharov:pr_cast_kernels_cleanup.
+3,544 −3,237
Commits:
43fac88 Separated gated and dequantize kernels (Oleg-Goncharov)
b5c5a44 Separated quantize, dequantize and gated functions (Oleg-Goncharov)
4ef014d [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
a61e41a Fixed lint issues (Oleg-Goncharov)
b9bc847 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
591ffc2 Fixed persistent lint issues (Oleg-Goncharov)
b15d1d1 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
d4928b1 Added missing compute capability 10.0 check for Quantize FP8 TMA kernels (Oleg-Goncharov)
a5ccfa0 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
92d0973 Fixed the issue which was added again by autofix (Oleg-Goncharov)
018ab71 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
387cceb Merge branch 'main' into pr_cast_kernels_cleanup (Oleg-Goncharov)
0eef950 Changed files description. Completely removed non-identity activation… (Oleg-Goncharov)
fa92095 Merge branch 'main' into pr_cast_kernels_cleanup (Oleg-Goncharov)
New file: common.cuh (+97 lines)
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file common.cuh
 *  \brief Common functions in quantize.
 */

#ifndef TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
#define TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_

#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>

#include "../../common.h"
#include "../../utils.cuh"

namespace transformer_engine {
namespace dispatch {
namespace common {

inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) {
  const size_t N = product(t->data.shape);
  const bool isFullTile = (N % elems_per_block == 0);
  return isFullTile;
}

inline bool dimensions_supported_by_TMA(const Tensor *const t) {
  const size_t cols = t->flat_last_dim();
  constexpr size_t TMA_bytes = 16;
  const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
  return cols % alignment_requirement == 0;
}

namespace kernel {

constexpr size_t THREADS_PER_BLOCK = 256;

template <int nvec, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial,
                        const size_t rows, const size_t cols) {
  using ComputeVec = Vec<float, nvec>;
  using OutputVec = Vec<OType, nvec>;

  const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

  if (thread_id * nvec >= cols) {
    return;
  }

  const float *const thread_in_base = dbias_partial + thread_id * nvec;
  OType *const thread_out_base = dbias_output + thread_id * nvec;

  ComputeVec ldg_vec;
  ComputeVec acc_vec;
  acc_vec.clear();
  // Note: the loop index is size_t to match the type of rows and avoid a
  // signed/unsigned comparison.
  for (size_t i = 0; i < rows; ++i) {
    ldg_vec.load_from(thread_in_base + i * cols);
#pragma unroll
    for (int e = 0; e < nvec; ++e) {
      acc_vec.data.elt[e] += ldg_vec.data.elt[e];
    }
  }

  OutputVec stg_vec;
#pragma unroll
  for (int e = 0; e < nvec; ++e) {
    stg_vec.data.elt[e] = static_cast<OType>(acc_vec.data.elt[e]);
  }
  stg_vec.store_to(thread_out_base);
}

}  // namespace kernel

template <typename IType>
void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
                  cudaStream_t stream) {
  using namespace kernel;
  constexpr size_t reduce_dbias_store_bytes = 8;  // stg.64
  constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType);

  NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape.");
  const size_t reduce_dbias_num_blocks = DIVUP(cols, THREADS_PER_BLOCK * reduce_dbias_nvec);

  reduce_dbias_kernel<reduce_dbias_nvec, IType>
      <<<reduce_dbias_num_blocks, THREADS_PER_BLOCK, 0, stream>>>(
          reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, rows, cols);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

}  // namespace common
}  // namespace dispatch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
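For reference, the column-wise reduction that reduce_dbias_kernel performs on the rows x cols partial-dbias workspace can be sketched on the host as follows. This is a minimal CPU reference, not part of the PR; the name reduce_dbias_reference is ours, and it omits the vectorization (nvec) and output-type cast that the kernel adds.

```cpp
#include <cstddef>
#include <vector>

// CPU reference for reduce_dbias_kernel: sum a rows x cols row-major
// partial-dbias buffer along the row axis, producing one float per column.
// The CUDA kernel computes the same per-column sums, but each thread owns
// nvec adjacent columns and casts the result to the output type.
std::vector<float> reduce_dbias_reference(const std::vector<float> &partial,
                                          std::size_t rows, std::size_t cols) {
  std::vector<float> out(cols, 0.0f);
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < cols; ++c) {
      out[c] += partial[r * cols + c];  // accumulate column c across all rows
    }
  }
  return out;
}
```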
New file: dequantize.cuh (+53 lines)
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file dequantize.cuh
 *  \brief Dequantize dispatcher.
 */

#ifndef TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_

#include <transformer_engine/transformer_engine.h>

#include "../../common.h"
#include "../fp8/dequantize_fp8.cuh"
#include "../mxfp8/dequantize_mxfp8.cuh"
#include "../nvfp4/dequantize_nvfp4.cuh"

namespace transformer_engine {
namespace dispatch {

inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) {
  CheckInputTensor(input, "cast_input");
  CheckOutputTensor(*output, "cast_output");

  switch (input.scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
      fp8::dequantize(input, output, stream);
      break;
    }
    case NVTE_MXFP8_1D_SCALING: {
      if (is_supported_by_CC_100()) {
        mxfp8::dequantize(input, output, stream);
      } else {
        NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
      }
      break;
    }
    case NVTE_NVFP4_1D_SCALING: {
      nvfp4::dequantize(input, output, stream);
      break;
    }
    default:
      NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
  }
}

}  // namespace dispatch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_
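As a side note on dimensions_supported_by_TMA in common.cuh above: TMA requires the innermost dimension to cover a multiple of 16 bytes, so the element-count alignment requirement depends on the element width. A minimal sketch of that arithmetic, with tma_alignment_elems as our own illustrative name (the bit widths below are common values, not taken from the PR):

```cpp
#include <cstddef>

// Number of elements the innermost dimension must be a multiple of so that
// each row spans a whole number of 16-byte TMA-aligned chunks.
// Mirrors (TMA_bytes * 8) / typeToNumBits(dtype) from dimensions_supported_by_TMA.
constexpr std::size_t tma_alignment_elems(std::size_t bits_per_elem) {
  constexpr std::size_t tma_bytes = 16;
  return (tma_bytes * 8) / bits_per_elem;
}

// A row passes the check when its column count is a multiple of that alignment.
constexpr bool cols_supported_by_tma(std::size_t cols, std::size_t bits_per_elem) {
  return cols % tma_alignment_elems(bits_per_elem) == 0;
}
```

So an 8-bit (FP8) tensor needs cols divisible by 16, a 16-bit (BF16) tensor divisible by 8, and a 32-bit tensor divisible by 4.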
Review comment: If we support activation-cast kernels for NVFP4, we should test them somewhere. Maybe not in this file, but somewhere.