6 changes: 6 additions & 0 deletions transformer_engine/common/cast/core/common.cuh
@@ -35,6 +35,12 @@ inline bool dimensions_supported_by_TMA(const Tensor *const t) {
  return cols % alignment_requirement == 0;
}

// Round a shared-memory pointer up to the next TMA_SHMEM_ALIGNMENT boundary
// (a power of two); TMA requires its shared-memory operands to be aligned.
__device__ __forceinline__ unsigned char *align_smem_ptr_per_TMA_requirements(unsigned char *p) {
  size_t addr = reinterpret_cast<size_t>(p);
  addr = (addr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1);
  return reinterpret_cast<unsigned char *>(addr);
}

namespace kernel {

constexpr size_t THREADS_PER_BLOCK = 256;
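For reference, a minimal usage sketch of the new helper, assuming the usual Transformer Engine setup where TMA_SHMEM_ALIGNMENT is 128 bytes and buffers are carved out of dynamic shared memory (the kernel below is illustrative and not part of this PR):

// Illustrative only: sub-buffers carved out of a dynamic shared-memory pool
// are not guaranteed to land on the 128-byte boundary that TMA
// (cp.async.bulk.tensor) requires, so each TMA-destined buffer is rounded up.
__global__ void example_kernel() {
  extern __shared__ unsigned char dshmem[];
  unsigned char *buff = align_smem_ptr_per_TMA_requirements(dshmem);
  // `buff` now satisfies TMA_SHMEM_ALIGNMENT and is safe to use as the
  // shared-memory operand of a TMA copy.
}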
3 changes: 2 additions & 1 deletion transformer_engine/common/cast/dispatch/quantize.cuh
@@ -100,7 +100,8 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
  int32_t cols = input_tensor->flat_last_dim();
  auto dtype = input_tensor->dtype();
  bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
                              (cols % 32 == 0) && output_tensor->has_data() &&
                              is_supported_by_CC_100();

  // Launch NVFP4 quantize kernel
  if (use_optimized_kernel) {
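The added is_supported_by_CC_100() guard restricts the optimized kernel to compute capability 10.0 (Blackwell) devices. As a hedged sketch of what such a gate typically looks like with the CUDA runtime API (the helper's actual name, location, and caching behavior in Transformer Engine may differ):

#include <cuda_runtime.h>

// Sketch only: query the current device's compute capability and require 10.0.
inline bool is_supported_by_CC_100_sketch() {
  int device = 0, major = 0, minor = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
  return major == 10 && minor == 0;
}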
@@ -21,6 +21,7 @@
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "core_nvfp4.cuh"
#include "specialized/quantize_transpose_nvfp4_persistent_1D.cuh"

namespace transformer_engine {
namespace dispatch {
@@ -1159,13 +1160,19 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
#if FP4_TYPE_SUPPORTED
  using namespace quantize_transpose_kernel;
  using namespace ptx;

  bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;

  // If the transposed output is allocated, return the transposed data; otherwise it is not
  // necessary to compute it.
  // TODO(Frank): Is there a better way to do this?
  bool return_transpose = output->has_columnwise_data();

  // Fast path: a persistent 1D kernel specialized for BF16 inputs, taken when row-wise (1D)
  // quantization is requested together with the transposed output.
  if (!use_2d_quantization && (input.dtype() == DType::kBFloat16) && return_transpose) {
    quantize_transpose_persistent_1D(input, noop, output, quant_config, stream);
    return;
  }

  constexpr bool COMPUTE_ACTIVATIONS = false;
  using ParamOP = Empty;
  constexpr float (*OP)(float, const ParamOP &) = nullptr;
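Restating the new dispatch rule as a standalone predicate for clarity (a hypothetical helper, not code from this PR): the persistent 1D kernel handles only row-wise (1D) NVFP4 quantization of BF16 inputs when the transposed output is also requested; every other case falls through to the existing generic path.

// Hypothetical restatement of the fast-path condition added above.
inline bool takes_persistent_1D_path(bool use_2d_quantization, DType input_dtype,
                                     bool return_transpose) {
  return !use_2d_quantization && (input_dtype == DType::kBFloat16) && return_transpose;
}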