Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions tests/cpp/operator/test_cast_nvfp4_transpose.cu
Original file line number Diff line number Diff line change
Expand Up @@ -663,12 +663,7 @@ std::vector<std::vector<size_t>> tensor_dims = {

// Only the Identity activation is tested here
std::vector<ActivationType> Activation_types = {
ActivationType::Identity,
ActivationType::GeLU,
ActivationType::SiLU,
ActivationType::ReLU,
ActivationType::QGeLU,
ActivationType::SReLU,
Comment on lines -667 to -671
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we support activation-cast kernels for NVFP4, we should test them somewhere. Maybe not in this file, but somewhere.

ActivationType::Identity
};

} // namespace
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ list(APPEND transformer_engine_SOURCES
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
util/cast.cu
cast/cast.cu
util/padding.cu
util/cuda_driver.cpp
util/cuda_nvml.cpp
Expand Down Expand Up @@ -267,7 +267,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
set_source_files_properties(activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
util/cast.cu
cast/cast.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
endif()
Expand Down
21 changes: 8 additions & 13 deletions transformer_engine/common/activation/activation_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@
#include <cuda_runtime.h>
#include <transformer_engine/activation.h>

#include "../cast/dispatch/gated.cuh"
#include "../cast/dispatch/quantize.cuh"
#include "../common.h"
#include "../util/cast_gated_kernels.cuh"
#include "../util/cast_kernels.cuh"
#include "../util/math.h"
#include "../util/vectorized_pointwise.h"

namespace transformer_engine {

Expand All @@ -32,8 +30,8 @@ void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;

quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
nullptr, stream);
dispatch::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias,
workspace, nullptr, stream);
}

template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
Expand All @@ -46,25 +44,22 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;

quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
nullptr, stream);
dispatch::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias,
workspace, nullptr, stream);
}

template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = false;
constexpr NVTETensor grad = nullptr;
quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, p, stream);
dispatch::quantize_gated_helper<Param, ActOP>(input, output, p, stream);
}

template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
ComputeType (*DActOP)(ComputeType, const Param &)>
void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p,
cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = true;
quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, p, stream);
dispatch::quantize_dgated_helper<Param, ActOP, DActOP>(grad, input, output, p, stream);
}

} // namespace transformer_engine
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,13 @@
#include <transformer_engine/cast.h>
#include <transformer_engine/multi_stream.h>

#include <cfloat>
#include <limits>
#include <mutex>
#include <string>

#include "../common.h"
#include "../transpose/cast_transpose.h"
#include "../util/multi_stream.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "cast_kernels.cuh"
#include "dequantize_kernels.cuh"
#include "math.h"
#include "ptx.cuh"
#include "dispatch/dequantize.cuh"
#include "dispatch/gated.cuh"
#include "dispatch/quantize.cuh"
#include "transformer_engine/activation.h"
#include "transformer_engine/transpose.h"

Expand All @@ -38,8 +31,8 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;

detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, output, dbias,
workspace, nullptr, stream);
dispatch::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, output, dbias,
workspace, nullptr, stream);
}

void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
Expand All @@ -66,7 +59,7 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;

detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
dispatch::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
input, grad, output, dbias, workspace, quant_config, stream);
}

Expand All @@ -80,7 +73,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
constexpr bool IS_ACT = false;
constexpr const NVTETensor activation_input = nullptr;

detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
dispatch::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}

Expand All @@ -94,7 +87,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;

detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dgelu<fp32, fp32>>(
dispatch::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dgelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}

Expand All @@ -108,7 +101,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;

detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsilu<fp32, fp32>>(
dispatch::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsilu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}

Expand All @@ -122,7 +115,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;

detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, drelu<fp32, fp32>>(
dispatch::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, drelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}

Expand All @@ -136,7 +129,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;

detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dqgelu<fp32, fp32>>(
dispatch::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dqgelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}

Expand All @@ -150,14 +143,15 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;

detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsrelu<fp32, fp32>>(
dispatch::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsrelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}

void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dequantize);
using namespace transformer_engine;
detail::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
dispatch::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output),
stream);
}

void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
Expand All @@ -184,7 +178,7 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
}

for (int i = 0; i < num_tensors; i++) {
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
dispatch::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
inputs[i], grad, outputs[i], dbias, workspace, nullptr,
detail::get_compute_stream(i % num_streams));
}
Expand Down
97 changes: 97 additions & 0 deletions transformer_engine/common/cast/core/common.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

/*! \file common.cuh
* \brief Common functions in quantize.
*/

#ifndef TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
#define TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_

#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>

#include "../../common.h"
#include "../../utils.cuh"

namespace transformer_engine {
namespace dispatch {
namespace common {
/*! \brief Tells whether a tensor's total element count splits evenly into blocks.
 *
 *  \param[in] t                Tensor whose data shape is inspected.
 *  \param[in] elems_per_block  Number of elements handled per block.
 *  \return true when the product of t->data.shape is an exact multiple of
 *          elems_per_block, false otherwise.
 */
inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) {
  const size_t total_elems = product(t->data.shape);
  return (total_elems % elems_per_block) == 0;
}

/*! \brief Checks the tensor's flattened last dimension against a 16-byte granularity.
 *
 *  The granularity is expressed in elements: the number of elements of the
 *  tensor's dtype that fit into 16 bytes (128 bits).
 *
 *  \param[in] t  Tensor to inspect.
 *  \return true when t->flat_last_dim() is a multiple of that element count.
 */
inline bool dimensions_supported_by_TMA(const Tensor *const t) {
  const size_t last_dim = t->flat_last_dim();
  constexpr size_t kTmaBits = 16 * 8;  // 16 bytes expressed in bits
  const size_t elems_per_tma_unit = kTmaBits / typeToNumBits(t->dtype());
  return (last_dim % elems_per_tma_unit) == 0;
}

namespace kernel {

// Block size used both as the launch bound of reduce_dbias_kernel and by its host launcher.
constexpr size_t THREADS_PER_BLOCK = 256;

/*! \brief Sums a row-major (rows x cols) fp32 buffer over its rows and writes the
 *         per-column totals as OType.
 *
 *  Launch layout: 1D grid of 1D blocks. Each thread owns nvec consecutive columns
 *  starting at column (global thread index * nvec); threads whose first column is
 *  at or beyond cols exit immediately. No shared memory is used.
 *
 *  There is no tail handling, so a thread that starts inside the buffer always
 *  processes a full vector of nvec columns. Callers are therefore expected to pass
 *  cols as a multiple of nvec.
 *  NOTE(review): Vec::load_from / Vec::store_to are presumed to move nvec elements
 *  per call and may carry their own alignment requirements -- confirm in utils.cuh.
 *
 *  \tparam nvec   Number of columns handled per thread.
 *  \tparam OType  Element type of the output buffer.
 *  \param[out] dbias_output   Output buffer, indexed by column.
 *  \param[in]  dbias_partial  Input partial sums, element (r, c) at r * cols + c.
 *  \param[in]  rows           Number of rows to accumulate.
 *  \param[in]  cols           Number of columns.
 */
template <int nvec, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial,
                        const size_t rows, const size_t cols) {
  using ComputeVec = Vec<float, nvec>;
  using OutputVec = Vec<OType, nvec>;

  const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

  // Grid tail: this thread's first column lies outside the buffer.
  if (thread_id * nvec >= cols) {
    return;
  }

  const float *const thread_in_base = dbias_partial + thread_id * nvec;
  OType *const thread_out_base = dbias_output + thread_id * nvec;

  ComputeVec ldg_vec;
  ComputeVec acc_vec;
  acc_vec.clear();
  // The row index is size_t so it matches the type of `rows`: a signed int index
  // would be compared against an unsigned bound and overflow for rows > INT_MAX.
  for (size_t i = 0; i < rows; ++i) {
    ldg_vec.load_from(thread_in_base + i * cols);
#pragma unroll
    for (int e = 0; e < nvec; ++e) {
      acc_vec.data.elt[e] += ldg_vec.data.elt[e];
    }
  }

  // Accumulation happens in fp32; convert to the output type only once at the end.
  OutputVec stg_vec;
#pragma unroll
  for (int e = 0; e < nvec; ++e) {
    stg_vec.data.elt[e] = static_cast<OType>(acc_vec.data.elt[e]);
  }
  stg_vec.store_to(thread_out_base);
}
} // namespace kernel

/*! \brief Launches reduce_dbias_kernel to collapse fp32 partial sums into dbias.
 *
 *  The vector width is the number of IType elements that fit into one 8-byte
 *  store; cols must be a multiple of that width, otherwise the call fails with
 *  "Unsupported shape.". The launch is enqueued on the given stream and only
 *  the launch status is checked (no synchronization is performed here).
 *
 *  \tparam IType  Element type that dbias->data.dptr is reinterpreted as.
 *  \param[in]  workspace_ptr  fp32 partial sums, passed to the kernel as its input.
 *  \param[out] dbias          Tensor whose data pointer receives the result.
 *  \param[in]  rows           Number of rows of partial sums.
 *  \param[in]  cols           Number of columns.
 *  \param[in]  stream         CUDA stream to launch on.
 */
template <typename IType>
void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
                  cudaStream_t stream) {
  using namespace kernel;
  constexpr size_t kStoreBytes = 8;  // stg.64
  constexpr size_t kNvec = kStoreBytes / sizeof(IType);

  NVTE_CHECK(cols % kNvec == 0, "Unsupported shape.");

  // One block covers THREADS_PER_BLOCK threads, each handling kNvec columns.
  constexpr size_t kColsPerBlock = THREADS_PER_BLOCK * kNvec;
  const size_t grid_size = DIVUP(cols, kColsPerBlock);
  IType *const dbias_out = reinterpret_cast<IType *>(dbias->data.dptr);

  reduce_dbias_kernel<kNvec, IType>
      <<<grid_size, THREADS_PER_BLOCK, 0, stream>>>(dbias_out, workspace_ptr, rows, cols);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

} // namespace common
} // namespace dispatch
} // namespace transformer_engine

#endif // TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
53 changes: 53 additions & 0 deletions transformer_engine/common/cast/dispatch/dequantize.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

/*! \file dequantize.cuh
* \brief Dequantize dispatcher.
*/

#ifndef TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_

#include <transformer_engine/transformer_engine.h>

#include "../../common.h"
#include "../fp8/dequantize_fp8.cuh"
#include "../mxfp8/dequantize_mxfp8.cuh"
#include "../nvfp4/dequantize_nvfp4.cuh"

namespace transformer_engine {
namespace dispatch {

/*! \brief Routes a dequantize request to the implementation matching the input's scaling mode.
 *
 *  Both tensors are validated first. Delayed tensor scaling goes to fp8::dequantize,
 *  NVFP4 1D scaling to nvfp4::dequantize, and MXFP8 1D scaling to mxfp8::dequantize
 *  when is_supported_by_CC_100() holds. An unsupported MXFP8 device or any other
 *  scaling mode is reported through NVTE_ERROR.
 *
 *  \param[in]  input   Quantized tensor to convert.
 *  \param[out] output  Destination tensor.
 *  \param[in]  stream  CUDA stream forwarded to the selected implementation.
 */
inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) {
  CheckInputTensor(input, "cast_input");
  CheckOutputTensor(*output, "cast_output");

  switch (input.scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING:
      fp8::dequantize(input, output, stream);
      break;
    case NVTE_NVFP4_1D_SCALING:
      nvfp4::dequantize(input, output, stream);
      break;
    case NVTE_MXFP8_1D_SCALING:
      // The MXFP8 path is gated on the device capability check.
      if (!is_supported_by_CC_100()) {
        NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
      } else {
        mxfp8::dequantize(input, output, stream);
      }
      break;
    default:
      NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
  }
}

} // namespace dispatch
} // namespace transformer_engine

#endif // TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_
Loading
Loading