Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 30 additions & 23 deletions tests/cpp/operator/test_cast_nvfp4_transpose.cu
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ void compare_nvfp4_tensors(const std::string& name,
const fp4e2m1 *test_data, const fp4e2m1 *ref_data,
const int rows, const int cols,
double atol = 1e-5, double rtol = 1e-8) {
constexpr bool print_detailed_summary = false;
std::vector<std::string> mismatch_messages;
size_t total_mismatches = 0;

Expand Down Expand Up @@ -381,36 +382,42 @@ void compare_nvfp4_tensors(const std::string& name,
std::to_string(t) + " vs " + std::to_string(r) +
" (abs_diff: " + std::to_string(fabs(t - r)) +
", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")";
mismatch_messages.push_back(msg);

// Optional: limit number of detailed messages to avoid overwhelming output
if (mismatch_messages.size() <= 100) {
std::cout << "Error in tensor " << name << ": " << msg << std::endl;
if constexpr (print_detailed_summary) {
mismatch_messages.push_back(msg);

// Optional: limit number of detailed messages to avoid overwhelming output
if (mismatch_messages.size() <= 100) {
std::cout << "Error in tensor " << name << ": " << msg << std::endl;
}
} else {
GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name;
}
}
}
}
}

// Always report summary - either success or failure
std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl;
std::cout << "Total elements checked: " << (rows * cols) << std::endl;

if (total_mismatches > 0) {
std::cout << "STATUS: FAILED for output" << std::endl;
std::cout << "Total mismatches found: " << total_mismatches << std::endl;
std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl;
if (mismatch_messages.size() > 100) {
std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl;
if constexpr (print_detailed_summary) {
// Always report summary - either success or failure
std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl;
std::cout << "Total elements checked: " << (rows * cols) << std::endl;

if (total_mismatches > 0) {
std::cout << "STATUS: FAILED for output" << std::endl;
std::cout << "Total mismatches found: " << total_mismatches << std::endl;
std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl;
if (mismatch_messages.size() > 100) {
std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl;
}
std::cout << "============================" << std::endl;

GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name;
} else {
std::cout << "STATUS: PASSED for output" << std::endl;
std::cout << "All elements match within tolerance!" << std::endl;
std::cout << "Tensor " << name << " is IDENTICAL to reference" << std::endl;
std::cout << "============================" << std::endl;
}
std::cout << "============================" << std::endl;

GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name;
} else {
std::cout << "STATUS: PASSED for output" << std::endl;
std::cout << "All elements match within tolerance!" << std::endl;
std::cout << "Tensor " << name << " is IDENTICAL to reference" << std::endl;
std::cout << "============================" << std::endl;
}
}

Expand Down
6 changes: 6 additions & 0 deletions transformer_engine/common/cast/core/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ inline bool dimensions_supported_by_TMA(const Tensor *const t) {
return cols % alignment_requirement == 0;
}

__device__ __forceinline__ unsigned char *align_smem_ptr_per_TMA_requirements(unsigned char *p) {
size_t addr = reinterpret_cast<size_t>(p);
addr = (addr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1);
return reinterpret_cast<unsigned char *>(addr);
}

namespace kernel {

constexpr size_t THREADS_PER_BLOCK = 256;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "core_nvfp4.cuh"
#include "specialized/quantize_transpose_nvfp4_persistent_1D.cuh"

namespace transformer_engine {
namespace dispatch {
Expand Down Expand Up @@ -1159,6 +1160,12 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
#if FP4_TYPE_SUPPORTED
using namespace quantize_transpose_kernel;
using namespace ptx;

if (!use_2d_quantization && input.dtype() == DType::kBFloat16) {
quantize_transpose_persistent_1D(input, noop, output, quant_config, stream);
return;
}

bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;

// If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to
Expand Down
Loading
Loading