Commit 0e80c84

[Common] Split cast/gated kernels by scaling mode (#2248)
* Separated gated and dequantize kernels. Signed-off-by: Oleg Goncharov <[email protected]>
* Separated quantize, dequantize and gated functions. Signed-off-by: Oleg Goncharov <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fixed lint issues. Signed-off-by: Oleg Goncharov <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fixed persistent lint issues. Signed-off-by: Oleg Goncharov <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Added missing compute capability 10.0 check for Quantize FP8 TMA kernels. Signed-off-by: Oleg Goncharov <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fixed the issue which was added again by autofix. Signed-off-by: Oleg Goncharov <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Changed files description. Completely removed non-identity activations from the NVFP4 transpose test suite. Signed-off-by: Oleg Goncharov <[email protected]>
* Removed unsupported template arguments in NVFP4 quantize. Signed-off-by: Oleg Goncharov <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fixed undefined symbol error. Signed-off-by: Oleg Goncharov <[email protected]>
* Fixed condition. Signed-off-by: Oleg Goncharov <[email protected]>
* Fixed CUDA version check. Signed-off-by: Oleg Goncharov <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Changed arch conditions order. Signed-off-by: Oleg Goncharov <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fix. Signed-off-by: Oleg Goncharov <[email protected]>
* Clean up. Signed-off-by: Oleg Goncharov <[email protected]>
* Small fix. Signed-off-by: Oleg Goncharov <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Small fix. Signed-off-by: Oleg Goncharov <[email protected]>
* Fixes per the PR review. Signed-off-by: Oleg Goncharov <[email protected]>
* Fix. Signed-off-by: Oleg Goncharov <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Split quantize helper into two (FWD and BWD) functions. Signed-off-by: Oleg Goncharov <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Moved activation functions from cast.cu. Removed cast.cu from the fast-math compilation list. Signed-off-by: Oleg Goncharov <[email protected]>
* Enabled fast math for activations by default. Signed-off-by: Oleg Goncharov <[email protected]>
* Disabled fast math for activations by default. Signed-off-by: Oleg Goncharov <[email protected]>

---------
Signed-off-by: Oleg Goncharov <[email protected]>
Signed-off-by: Oleg Goncharov <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 490a5f4 commit 0e80c84

24 files changed, +5073 -3178 lines changed

transformer_engine/common/CMakeLists.txt

Lines changed: 2 additions & 3 deletions
@@ -168,7 +168,7 @@ list(APPEND transformer_engine_cuda_sources
 
 list(APPEND transformer_engine_cuda_arch_specific_sources
      gemm/cutlass_grouped_gemm.cu
-     util/cast.cu
+     cast/cast.cu
      activation/gelu.cu
      activation/relu.cu
      activation/swiglu.cu
@@ -336,8 +336,7 @@ option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --u
 if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
   list(APPEND nvte_sources_with_fast_math activation/gelu.cu
                                            activation/relu.cu
-                                           activation/swiglu.cu
-                                           util/cast.cu)
+                                           activation/swiglu.cu)
 endif()
 
 foreach(cuda_source IN LISTS nvte_sources_with_fast_math)

transformer_engine/common/activation/activation_template.h

Lines changed: 7 additions & 20 deletions
@@ -14,26 +14,17 @@
 #include <cuda_runtime.h>
 #include <transformer_engine/activation.h>
 
+#include "../cast/dispatch/gated.cuh"
+#include "../cast/dispatch/quantize.cuh"
 #include "../common.h"
-#include "../util/cast_gated_kernels.cuh"
-#include "../util/cast_kernels.cuh"
-#include "../util/math.h"
-#include "../util/vectorized_pointwise.h"
 
 namespace transformer_engine {
 
 template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
 void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   using namespace detail;
-  constexpr bool IS_DBIAS = false;
-  constexpr bool IS_DACT = false;
   constexpr bool IS_ACT = true;
-  constexpr NVTETensor dbias = nullptr;
-  constexpr NVTETensor workspace = nullptr;
-  constexpr const NVTETensor grad = nullptr;
-
-  quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
-                                                        nullptr, stream);
+  dispatch::quantize_fwd_helper<IS_ACT, Empty, OP>(input, output, nullptr, stream);
 }
 
 template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
@@ -42,29 +33,25 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
   using namespace detail;
   constexpr bool IS_DBIAS = false;
   constexpr bool IS_DACT = true;
-  constexpr bool IS_ACT = false;
   constexpr NVTETensor dbias = nullptr;
   constexpr NVTETensor workspace = nullptr;
 
-  quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
-                                                        nullptr, stream);
+  dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, OP>(grad, input, output, dbias, workspace,
+                                                              nullptr, stream);
 }
 
 template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
 void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) {
   using namespace detail;
-  constexpr bool IS_DGATED = false;
-  constexpr NVTETensor grad = nullptr;
-  quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, p, stream);
+  dispatch::quantize_gated_fwd_helper<Param, ActOP>(input, output, p, stream);
 }
 
 template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
           ComputeType (*DActOP)(ComputeType, const Param &)>
 void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p,
                    cudaStream_t stream) {
   using namespace detail;
-  constexpr bool IS_DGATED = true;
-  quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, p, stream);
+  dispatch::quantize_gated_bwd_helper<Param, ActOP, DActOP>(grad, input, output, p, stream);
 }
 
 } // namespace transformer_engine
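
For orientation, a minimal sketch (not part of the commit) of how these slimmed-down wrappers are instantiated by an activation entry point. The function names and include paths below are illustrative assumptions; the call pattern mirrors the dact_fn/dgelu usage visible in the gelu.cu hunks that follow.

// Illustrative only: how an entry point would route through the refactored wrappers.
// The real entry points live in activation/gelu.cu, relu.cu and swiglu.cu.
#include "activation_template.h"
#include "../util/math.h"

void example_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  using namespace transformer_engine;
  // Forward activation (optionally fused with quantization), via quantize_fwd_helper.
  act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
}

void example_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                   cudaStream_t stream) {
  using namespace transformer_engine;
  // Backward activation without dbias, via quantize_bwd_helper.
  dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
}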

transformer_engine/common/activation/gelu.cu

Lines changed: 26 additions & 0 deletions
@@ -20,6 +20,19 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
   dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
 }
 
+void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
+                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
+                               cudaStream_t stream) {
+  NVTE_API_CALL(nvte_quantize_dbias_dgelu);
+  using namespace transformer_engine;
+
+  constexpr bool IS_DBIAS = true;
+  constexpr bool IS_DACT = true;
+
+  dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
+      input, activation_input, output, dbias, workspace, nullptr, stream);
+}
+
 void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_geglu);
   using namespace transformer_engine;
@@ -48,6 +61,19 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
   dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
 }
 
+void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
+                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
+                                cudaStream_t stream) {
+  NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
+  using namespace transformer_engine;
+
+  constexpr bool IS_DBIAS = true;
+  constexpr bool IS_DACT = true;
+
+  dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
+      input, activation_input, output, dbias, workspace, nullptr, stream);
+}
+
 void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_qgeglu);
   using namespace transformer_engine;
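
A hedged usage sketch for the fused entry point added above: the caller passes the incoming gradient, the original GELU input, the quantized output, a dbias tensor and a workspace. The tensor handles are assumed to be created elsewhere; only the argument order comes from this diff, and the public declaration is assumed to live in the cast header.

// Hedged sketch: fused dGELU + dbias reduction + quantization in one call.
#include <transformer_engine/cast.h>

void backward_gelu_fused_example(const NVTETensor grad_output,   // incoming gradient
                                 const NVTETensor gelu_input,    // forward-pass input to GELU
                                 NVTETensor quantized_dgrad,     // quantized dGELU result
                                 NVTETensor dbias,               // reduced bias gradient
                                 NVTETensor workspace, cudaStream_t stream) {
  nvte_quantize_dbias_dgelu(grad_output, gelu_input, quantized_dgrad, dbias, workspace, stream);
}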

transformer_engine/common/activation/relu.cu

Lines changed: 26 additions & 0 deletions
@@ -20,6 +20,19 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
   dact_fn<fp32, Empty, drelu<fp32, fp32>>(grad, input, output, stream);
 }
 
+void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
+                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
+                               cudaStream_t stream) {
+  NVTE_API_CALL(nvte_quantize_dbias_drelu);
+  using namespace transformer_engine;
+
+  constexpr bool IS_DBIAS = true;
+  constexpr bool IS_DACT = true;
+
+  dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
+      input, activation_input, output, dbias, workspace, nullptr, stream);
+}
+
 void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_reglu);
   using namespace transformer_engine;
@@ -48,6 +61,19 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
   dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(grad, input, output, stream);
 }
 
+void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
+                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
+                                cudaStream_t stream) {
+  NVTE_API_CALL(nvte_quantize_dbias_dsrelu);
+  using namespace transformer_engine;
+
+  constexpr bool IS_DBIAS = true;
+  constexpr bool IS_DACT = true;
+
+  dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
+      input, activation_input, output, dbias, workspace, nullptr, stream);
+}
+
 void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_sreglu);
   using namespace transformer_engine;

transformer_engine/common/activation/swiglu.cu

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,19 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
2020
dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream);
2121
}
2222

23+
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
24+
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
25+
cudaStream_t stream) {
26+
NVTE_API_CALL(nvte_quantize_dbias_dsilu);
27+
using namespace transformer_engine;
28+
29+
constexpr bool IS_DBIAS = true;
30+
constexpr bool IS_DACT = true;
31+
32+
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
33+
input, activation_input, output, dbias, workspace, nullptr, stream);
34+
}
35+
2336
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
2437
NVTE_API_CALL(nvte_swiglu);
2538
using namespace transformer_engine;
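
For the gated path, a minimal sketch (illustrative names and include paths, not taken from the commit) of how a forward gated entry point routes through the new gated dispatch helper, following the gated_act_fn template shown in activation_template.h above.

// Illustrative only: forward gated activation through quantize_gated_fwd_helper.
#include "activation_template.h"
#include "../util/math.h"

void example_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  using namespace transformer_engine;
  Empty args{};  // SiLU takes no extra activation parameters
  // The input carries both gated halves along its last dimension;
  // gated_act_fn forwards to dispatch::quantize_gated_fwd_helper.
  gated_act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, args, stream);
}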
transformer_engine/common/cast/cast.cu

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include <cuda.h>
+#include <cudaTypedefs.h>
+#include <cuda_runtime.h>
+#include <transformer_engine/cast.h>
+#include <transformer_engine/multi_stream.h>
+
+#include "../common.h"
+#include "../transpose/cast_transpose.h"
+#include "../util/multi_stream.h"
+#include "../utils.cuh"
+#include "dispatch/dequantize.cuh"
+#include "dispatch/quantize.cuh"
+#include "transformer_engine/transpose.h"
+
+void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_quantize);
+  using namespace transformer_engine;
+
+  constexpr bool IS_ACT = false;
+  dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
+}
+
+void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
+                        cudaStream_t stream) {
+  NVTE_API_CALL(nvte_quantize_noop);
+  using namespace transformer_engine;
+
+  // Create config with noop tensor
+  QuantizationConfig quant_config;
+  quant_config.noop_tensor = noop;
+
+  nvte_quantize_v2(input, output, reinterpret_cast<NVTEQuantizationConfig>(&quant_config), stream);
+}
+
+void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
+                      const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_quantize_v2);
+  using namespace transformer_engine;
+
+  constexpr bool IS_ACT = false;
+  dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, quant_config, stream);
+}
+
+void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
+                         NVTETensor workspace, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_quantize_dbias);
+  using namespace transformer_engine;
+
+  constexpr bool IS_DBIAS = true;
+  constexpr bool IS_DACT = false;
+  constexpr const NVTETensor activation_input = nullptr;
+
+  dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, nullptr>(
+      input, activation_input, output, dbias, workspace, nullptr, stream);
+}
+
+void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_dequantize);
+  using namespace transformer_engine;
+  dispatch::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output),
+                              stream);
+}
+
+void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
+                                const NVTEQuantizationConfig quant_configs,
+                                const size_t num_tensors, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_tensor_quantize);
+  using namespace transformer_engine;
+
+  constexpr bool IS_ACT = false;
+
+  const size_t num_streams = nvte_get_num_compute_streams();
+
+  int num_stream_used = std::min(num_streams, num_tensors);
+  // wait for current stream to finish
+  NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream));
+  for (int s = 0; s < num_stream_used; s++) {
+    NVTE_CHECK_CUDA(
+        cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0)));
+  }
+
+  for (int i = 0; i < num_tensors; i++) {
+    dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(
+        inputs[i], outputs[i], quant_configs, detail::get_compute_stream(i % num_streams));
+  }
+
+  // record events on compute streams
+  for (int s = 0; s < num_stream_used; s++) {
+    NVTE_CHECK_CUDA(
+        cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s)));
+  }
+  // wait for all compute streams to finish
+  for (int s = 0; s < num_stream_used; s++) {
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
+  }
+}
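
The multi-tensor path above fans per-tensor quantize calls out across Transformer Engine's internal compute streams and orders the caller's stream after all of them via events. A hedged usage sketch; the tensor handles are assumed to be built elsewhere, and passing a null quantization config mirrors what nvte_quantize itself does in this file.

// Hedged sketch: quantize a batch of tensors with the multi-stream entry point.
#include <transformer_engine/cast.h>

#include <vector>

void quantize_batch_example(const std::vector<NVTETensor> &inputs,
                            std::vector<NVTETensor> &outputs, cudaStream_t stream) {
  // One call covers the whole batch; later work enqueued on `stream` is
  // ordered after all of the internal compute streams have finished.
  nvte_multi_tensor_quantize(inputs.data(), outputs.data(),
                             /*quant_configs=*/nullptr, inputs.size(), stream);
}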
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+/*! \file common.cuh
+ *  \brief Common functions in quantize.
+ */
+
+#ifndef TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
+#define TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
+
+#include <cuda.h>
+#include <cudaTypedefs.h>
+#include <cuda_runtime.h>
+#include <transformer_engine/transformer_engine.h>
+
+#include "../../common.h"
+#include "../../utils.cuh"
+
+namespace transformer_engine {
+namespace dispatch {
+namespace common {
+inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) {
+  const size_t N = product(t->data.shape);
+  const bool isFullTile = (N % elems_per_block == 0);
+  return isFullTile;
+}
+
+inline bool dimensions_supported_by_TMA(const Tensor *const t) {
+  const size_t cols = t->flat_last_dim();
+  constexpr size_t TMA_bytes = 16;
+  const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
+  return cols % alignment_requirement == 0;
+}
+
+namespace kernel {
+
+constexpr size_t THREADS_PER_BLOCK = 256;
+template <int nvec, typename OType>
+__global__ void __launch_bounds__(THREADS_PER_BLOCK)
+    reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial,
+                        const size_t rows, const size_t cols) {
+  using ComputeVec = Vec<float, nvec>;
+  using OutputVec = Vec<OType, nvec>;
+
+  const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (thread_id * nvec >= cols) {
+    return;
+  }
+
+  const float *const thread_in_base = dbias_partial + thread_id * nvec;
+  OType *const thread_out_base = dbias_output + thread_id * nvec;
+
+  ComputeVec ldg_vec;
+  ComputeVec acc_vec;
+  acc_vec.clear();
+  for (int i = 0; i < rows; ++i) {
+    ldg_vec.load_from(thread_in_base + i * cols);
+#pragma unroll
+    for (int e = 0; e < nvec; ++e) {
+      acc_vec.data.elt[e] += ldg_vec.data.elt[e];
+    }
+  }
+
+  OutputVec stg_vec;
+#pragma unroll
+  for (int e = 0; e < nvec; ++e) {
+    stg_vec.data.elt[e] = static_cast<OType>(acc_vec.data.elt[e]);
+  }
+  stg_vec.store_to(thread_out_base);
+}
+}  // namespace kernel
+
+template <typename IType>
+void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
+                  cudaStream_t stream) {
+  using namespace kernel;
+  constexpr size_t reduce_dbias_store_bytes = 8;  // stg.64
+  constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType);
+
+  NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape.");
+  const size_t reduce_dbias_num_blocks = DIVUP(cols, THREADS_PER_BLOCK * reduce_dbias_nvec);
+
+  reduce_dbias_kernel<reduce_dbias_nvec, IType>
+      <<<reduce_dbias_num_blocks, THREADS_PER_BLOCK, 0, stream>>>(
+          reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, rows, cols);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
+
+}  // namespace common
+}  // namespace dispatch
+}  // namespace transformer_engine
+
+#endif  // TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
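
The reduce_dbias helper above collapses a [rows, cols] float workspace of per-tile partial sums into the final dbias vector of length cols. A hedged sketch of how a caller might drive it; the function name, the bf16 choice and the include path are illustrative assumptions, not taken from this file.

// Hedged sketch: finish a dbias reduction once the quantize kernels have written
// per-tile partial sums (float) into `partial_sums` with shape [num_row_tiles, cols].
#include "core/common.cuh"  // path assumed relative to the new cast sources

void finish_dbias_bf16_example(transformer_engine::Tensor *dbias_bf16,
                               const float *partial_sums, size_t num_row_tiles,
                               size_t cols, cudaStream_t stream) {
  using namespace transformer_engine;
  // Sums over the tile dimension and casts the result to the dbias dtype (bf16 here).
  dispatch::common::reduce_dbias<bf16>(partial_sums, dbias_bf16, num_row_tiles, cols, stream);
}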
