Merged

34 commits
36533fe
Separated gated and dequantize kernels
Oleg-Goncharov Oct 6, 2025
d7d8e43
Separated quantize, dequantize and gated functions
Oleg-Goncharov Oct 7, 2025
30484be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2025
1bac1f8
Fixed lint issues
Oleg-Goncharov Oct 8, 2025
40cc697
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2025
bcdbee0
Fixed persistent lint issues
Oleg-Goncharov Oct 8, 2025
67e32c3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2025
a72d53a
Added missing compute capability 10.0 check for Quantize FP8 TMA kernels
Oleg-Goncharov Oct 9, 2025
3e5edd2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2025
bcbdf8a
Fixed the issue which was added again by autofix
Oleg-Goncharov Oct 9, 2025
3dea6bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2025
4a05df0
Changed files description. Completely removed non-identity activation…
Oleg-Goncharov Oct 9, 2025
2433529
Removed unsupported template arguments in NVFP4 quantize
Oleg-Goncharov Oct 14, 2025
3526e54
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 14, 2025
6c0ed53
Fixed undefined symbol error
Oleg-Goncharov Oct 14, 2025
aa5ec2e
Fixed condition
Oleg-Goncharov Oct 14, 2025
bcd55f6
Fixed CUDA version check
Oleg-Goncharov Oct 16, 2025
7158b9d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 16, 2025
0c4314e
Changed arch conditions order
Oleg-Goncharov Oct 16, 2025
41416b3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 16, 2025
91ea154
Fix
Oleg-Goncharov Oct 16, 2025
e14ae55
Clean up
Oleg-Goncharov Oct 24, 2025
f7225e9
Small fix
Oleg-Goncharov Oct 24, 2025
37747dd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2025
9afdba1
Small fix
Oleg-Goncharov Oct 24, 2025
b764dea
Fixes per the PR review
Oleg-Goncharov Oct 28, 2025
703556c
Fix
Oleg-Goncharov Oct 28, 2025
ec58510
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
b84301b
Split quantize helper into two (FWD and BWD) functions
Oleg-Goncharov Oct 29, 2025
ebcfafe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
80df2fd
Merge branch 'main' into pr_cast_kernels_cleanup
Oleg-Goncharov Oct 30, 2025
deb012b
Moved activation functions from cast.cu. Removed cast.cu from the fas…
Oleg-Goncharov Oct 30, 2025
6142ff7
Enabled fast math for activations by default
Oleg-Goncharov Oct 30, 2025
25e9b48
Disabled fast math for activations by default
Oleg-Goncharov Oct 30, 2025
4 changes: 2 additions & 2 deletions transformer_engine/common/CMakeLists.txt
@@ -168,7 +168,7 @@ list(APPEND transformer_engine_cuda_sources

list(APPEND transformer_engine_cuda_arch_specific_sources
gemm/cutlass_grouped_gemm.cu
util/cast.cu
cast/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
@@ -337,7 +337,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
list(APPEND nvte_sources_with_fast_math activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
util/cast.cu)
cast/cast.cu)
endif()

foreach(cuda_source IN LISTS nvte_sources_with_fast_math)
27 changes: 7 additions & 20 deletions transformer_engine/common/activation/activation_template.h
@@ -14,26 +14,17 @@
#include <cuda_runtime.h>
#include <transformer_engine/activation.h>

#include "../cast/dispatch/gated.cuh"
#include "../cast/dispatch/quantize.cuh"
#include "../common.h"
#include "../util/cast_gated_kernels.cuh"
#include "../util/cast_kernels.cuh"
#include "../util/math.h"
#include "../util/vectorized_pointwise.h"

namespace transformer_engine {

template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = true;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;

quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
nullptr, stream);
dispatch::quantize_fwd_helper<IS_ACT, Empty, OP>(input, output, nullptr, stream);
}

template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
@@ -42,29 +42,25 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
using namespace detail;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;

quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
nullptr, stream);
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, OP>(grad, input, output, dbias, workspace,
nullptr, stream);
}

template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = false;
constexpr NVTETensor grad = nullptr;
quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, p, stream);
dispatch::quantize_gated_fwd_helper<Param, ActOP>(input, output, p, stream);
}

template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
ComputeType (*DActOP)(ComputeType, const Param &)>
void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p,
cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = true;
quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, p, stream);
dispatch::quantize_gated_bwd_helper<Param, ActOP, DActOP>(grad, input, output, p, stream);
}

} // namespace transformer_engine
transformer_engine/common/cast/cast.cu
@@ -10,36 +10,22 @@
#include <transformer_engine/cast.h>
#include <transformer_engine/multi_stream.h>

#include <cfloat>
#include <limits>
#include <mutex>
#include <string>

#include "../common.h"
#include "../transpose/cast_transpose.h"
#include "../util/multi_stream.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "cast_kernels.cuh"
#include "dequantize_kernels.cuh"
#include "math.h"
#include "ptx.cuh"
#include "dispatch/dequantize.cuh"
#include "dispatch/gated.cuh"
#include "dispatch/quantize.cuh"
#include "transformer_engine/activation.h"
#include "transformer_engine/transpose.h"

void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize);
using namespace transformer_engine;

constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;

detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, output, dbias,
workspace, nullptr, stream);
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
}

void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
@@ -59,15 +45,8 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
NVTE_API_CALL(nvte_quantize_v2);
using namespace transformer_engine;

constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;

detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
input, grad, output, dbias, workspace, quant_config, stream);
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, quant_config, stream);
}

void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
@@ -77,11 +56,10 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d

constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr const NVTETensor activation_input = nullptr;

detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
activation_input, input, output, dbias, workspace, nullptr, stream);
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, nullptr>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
Member:
Could we move those functions involving activations to the activation-specific files? That way we could make sure that we use fast math only for the activations (and then maybe actually turn it on by default?) and not for the entire cast.cu file.

@@ -92,10 +70,9 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati

constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;

detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dgelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
@@ -106,10 +83,9 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati

constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;

detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsilu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
@@ -120,10 +96,9 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati

constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;

detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, drelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
@@ -134,10 +109,9 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat

constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;

detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dqgelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
@@ -148,16 +122,16 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat

constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;

detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsrelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dequantize);
using namespace transformer_engine;
detail::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
dispatch::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output),
stream);
}

void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
@@ -166,12 +140,7 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
NVTE_API_CALL(nvte_multi_tensor_quantize);
using namespace transformer_engine;

constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;

const size_t num_streams = nvte_get_num_compute_streams();

@@ -184,9 +153,8 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
}

for (int i = 0; i < num_tensors; i++) {
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
inputs[i], grad, outputs[i], dbias, workspace, nullptr,
detail::get_compute_stream(i % num_streams));
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(
inputs[i], outputs[i], nullptr, detail::get_compute_stream(i % num_streams));
Contributor:

logic: In multi-tensor quantize, the quant_configs parameter is declared on line 138 but never used. Should each call to quantize_fwd_helper pass quant_configs (or an indexed config)?

Member:

Agreed, we should pass the quant_config (quant_configs is actually a misleading name, since it suggests there are multiple of them) to those quantize functions.

}

// record events on compute streams
97 changes: 97 additions & 0 deletions transformer_engine/common/cast/core/common.cuh
@@ -0,0 +1,97 @@
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

/*! \file common.cuh
* \brief Common functions in quantize.
Contributor:
style: file comment says "in quantize" but this lives in cast/core/ – update to "Common functions in cast." or similar

*/

#ifndef TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
Contributor:
syntax: header guard mismatch – file is cast/core/common.cuh but guard is QUANTIZE_CORE_COMMON_CUH_
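A guard that mirrors the cast/core/ path could look like the sketch below (the exact macro name is a suggestion, not the author's choice):

#ifndef TRANSFORMER_ENGINE_CAST_CORE_COMMON_CUH_
#define TRANSFORMER_ENGINE_CAST_CORE_COMMON_CUH_
// ... header contents unchanged ...
#endif  // TRANSFORMER_ENGINE_CAST_CORE_COMMON_CUH_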

#define TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_

#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>

#include "../../common.h"
#include "../../utils.cuh"

namespace transformer_engine {
namespace dispatch {
namespace common {
inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) {
const size_t N = product(t->data.shape);
const bool isFullTile = (N % elems_per_block == 0);
return isFullTile;
}

inline bool dimensions_supported_by_TMA(const Tensor *const t) {
const size_t cols = t->flat_last_dim();
constexpr size_t TMA_bytes = 16;
const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
return cols % alignment_requirement == 0;
}

namespace kernel {

constexpr size_t THREADS_PER_BLOCK = 256;
template <int nvec, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial,
const size_t rows, const size_t cols) {
using ComputeVec = Vec<float, nvec>;
using OutputVec = Vec<OType, nvec>;

const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

if (thread_id * nvec >= cols) {
return;
}

const float *const thread_in_base = dbias_partial + thread_id * nvec;
OType *const thread_out_base = dbias_output + thread_id * nvec;

ComputeVec ldg_vec;
ComputeVec acc_vec;
acc_vec.clear();
for (int i = 0; i < rows; ++i) {
ldg_vec.load_from(thread_in_base + i * cols);
#pragma unroll
for (int e = 0; e < nvec; ++e) {
acc_vec.data.elt[e] += ldg_vec.data.elt[e];
}
}

OutputVec stg_vec;
#pragma unroll
for (int e = 0; e < nvec; ++e) {
stg_vec.data.elt[e] = static_cast<OType>(acc_vec.data.elt[e]);
}
stg_vec.store_to(thread_out_base);
}
} // namespace kernel

template <typename IType>
void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
Comment on lines +77 to +78
Contributor:
style: template parameter IType name suggests input type, but it is used as OType (output) inside reduce_dbias_kernel; consider renaming to OType for clarity

cudaStream_t stream) {
using namespace kernel;
constexpr size_t reduce_dbias_store_bytes = 8; // stg.64
constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType);

NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape.");
const size_t reduce_dbias_num_blocks = DIVUP(cols, THREADS_PER_BLOCK * reduce_dbias_nvec);

reduce_dbias_kernel<reduce_dbias_nvec, IType>
<<<reduce_dbias_num_blocks, THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, rows, cols);
NVTE_CHECK_CUDA(cudaGetLastError());
}

} // namespace common
} // namespace dispatch
} // namespace transformer_engine

#endif // TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
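Following the reviewer's note on reduce_dbias above, a name-only sketch of the suggested rename (behavior unchanged; whether the kernel's own template parameter should be renamed as well is left open):

template <typename OType>
void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
                  cudaStream_t stream) {
  using namespace kernel;
  constexpr size_t reduce_dbias_store_bytes = 8;  // stg.64
  constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(OType);

  NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape.");
  const size_t reduce_dbias_num_blocks = DIVUP(cols, THREADS_PER_BLOCK * reduce_dbias_nvec);

  reduce_dbias_kernel<reduce_dbias_nvec, OType>
      <<<reduce_dbias_num_blocks, THREADS_PER_BLOCK, 0, stream>>>(
          reinterpret_cast<OType *>(dbias->data.dptr), workspace_ptr, rows, cols);
  NVTE_CHECK_CUDA(cudaGetLastError());
}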