Skip to content

Commit 4ac9df6

Browse files
committed
Numerical test passed.
Signed-off-by: Zhongbo Zhu <[email protected]>
1 parent c5bef0b commit 4ac9df6

File tree

2 files changed

+89
-45
lines changed

2 files changed

+89
-45
lines changed

transformer_engine/common/swizzle/swizzle.cu

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -332,11 +332,9 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
332332
} // namespace
333333

334334
void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
335-
NVTE_CHECK(input->scaling_mode == NVTE_MXFP8_1D_SCALING ||
336-
input->scaling_mode == NVTE_BLOCK_SCALING_1D ||
337-
input->scaling_mode == NVTE_BLOCK_SCALING_2D ||
338-
input->scaling_mode == NVTE_NVFP4_1D_SCALING,
339-
"Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ").");
335+
NVTE_CHECK(
336+
input->scaling_mode == NVTE_MXFP8_1D_SCALING || input->scaling_mode == NVTE_NVFP4_1D_SCALING,
337+
"Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ").");
340338
NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()),
341339
"Input tensor has invalid dtype (", to_string(input->dtype()), ").");
342340

@@ -583,16 +581,19 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
583581
NVTE_CHECK_CUDA(cudaGetLastError());
584582
}
585583

586-
// TODO(nvfp4): Add NVFP4 support.
587584
void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
588585
std::vector<Tensor*>& output, cudaStream_t stream) {
589586
auto num_tensors = input.size();
590587
bool all_has_data = true;
591588
bool all_has_columnwise_data = true;
589+
bool all_nvfp4 = true;
592590
for (size_t i = 0; i < num_tensors; i++) {
593-
if (!is_fp8_dtype(input[i]->dtype()) || !is_mxfp_scaling(input[i]->scaling_mode)) {
594-
NVTE_ERROR("Not implemented caling mode " + to_string(input[i]->scaling_mode) + ".");
595-
}
591+
auto scaling_mode = input[i]->scaling_mode;
592+
auto is_fp8 = is_fp8_dtype(input[i]->dtype());
593+
auto is_fp4 = is_fp4_dtype(input[i]->dtype());
594+
NVTE_CHECK(
595+
(is_fp8 && is_mxfp8_scaling(scaling_mode)) || (is_fp4 && is_nvfp4_scaling(scaling_mode)),
596+
"Not implemented scaling mode " + to_string(scaling_mode) + ".");
596597
// We don't allow empty tensors. They should be filtered out before calling this function.
597598
if (input[i]->data.numel() == 0) {
598599
NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty.");
@@ -601,13 +602,17 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
601602
CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]");
602603
all_has_data &= input[i]->has_data();
603604
all_has_columnwise_data &= input[i]->has_columnwise_data();
605+
all_nvfp4 &= is_nvfp4_scaling(scaling_mode);
604606
}
605607
NVTE_CHECK(all_has_data || all_has_columnwise_data,
606608
"All tensors should have data or columnwise data.");
607609

610+
const bool rowwise_swizzle = all_has_data || all_nvfp4;
611+
const bool columnwise_swizzle = all_has_columnwise_data && !all_nvfp4;
612+
608613
constexpr int SF_TILE_DIM_M = 128;
609614
constexpr int SF_TILE_DIM_K = 4;
610-
if (all_has_data) {
615+
if (rowwise_swizzle) {
611616
MultiSwizzleArgs kernel_args;
612617
kernel_args.num_tensors = 0;
613618
kernel_args.block_range[0] = 0;
@@ -623,29 +628,56 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
623628
kernel_args.num_tensors = 0;
624629
vec_load_size = 4;
625630
}
626-
const int m = input[i]->scale_inv.shape[0];
627-
const int k = input[i]->scale_inv.shape[1];
631+
632+
int m, k;
633+
634+
if (all_has_data) {
635+
m = input[i]->scale_inv.shape[0];
636+
k = input[i]->scale_inv.shape[1];
637+
} else {
638+
NVTE_CHECK(all_nvfp4, "When doing rowwise swizzle with rowwise data, it has to be NVFP4");
639+
m = input[i]->columnwise_scale_inv.shape[0];
640+
k = input[i]->columnwise_scale_inv.shape[1];
641+
}
628642

629643
NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
630644
NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
631645
NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
632-
NVTE_CHECK(
633-
m * k == std::accumulate(output[i]->scale_inv.shape.begin(),
634-
output[i]->scale_inv.shape.end(), 1, std::multiplies<int>()),
635-
"Input.scale_inv size is not equal to Output.scale_inv size!");
646+
647+
if (output[i]->has_data()) {
648+
NVTE_CHECK(
649+
m * k == std::accumulate(output[i]->scale_inv.shape.begin(),
650+
output[i]->scale_inv.shape.end(), 1, std::multiplies<int>()),
651+
"Input.scale_inv size is not equal to Output.scale_inv size!");
652+
}
653+
if (output[i]->has_columnwise_data()) {
654+
NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(),
655+
output[i]->columnwise_scale_inv.shape.end(), 1,
656+
std::multiplies<int>()),
657+
"Input.columnwise_scale_inv size is not equal to "
658+
"Output.columnwise_scale_inv size!");
659+
}
636660

637661
int num_tiles_k = k / SF_TILE_DIM_K;
638662
int vec_load_size_i = (num_tiles_k - 1) % 4 + 1;
639663
// We use the minimum vec_load_size across all tensors.
640664
vec_load_size = std::min(vec_load_size, vec_load_size_i);
641665

642666
const int pos = kernel_args.num_tensors;
643-
kernel_args.input_list[pos] = const_cast<void*>(input[i]->scale_inv.dptr);
644-
kernel_args.output_list[pos] = output[i]->scale_inv.dptr;
645667
kernel_args.m_list[pos] = m;
646668
kernel_args.k_list[pos] = k;
647-
kernel_args.original_m_list[pos] = input[i]->flat_first_dim();
648-
kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / MXFP8_BLOCK_SIZE;
669+
if (!all_nvfp4 || all_has_data) {
670+
int block_scale_size = all_nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE;
671+
kernel_args.input_list[pos] = const_cast<void*>(input[i]->scale_inv.dptr);
672+
kernel_args.output_list[pos] = output[i]->scale_inv.dptr;
673+
kernel_args.original_m_list[pos] = input[i]->flat_first_dim();
674+
kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / block_scale_size;
675+
} else {
676+
kernel_args.input_list[pos] = const_cast<void*>(input[i]->columnwise_scale_inv.dptr);
677+
kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr;
678+
kernel_args.original_m_list[pos] = input[i]->flat_last_dim();
679+
kernel_args.original_k_list[pos] = input[i]->flat_first_dim() / NVFP4_BLOCK_SIZE;
680+
}
649681
kernel_args.num_tensors++;
650682
}
651683
// Launch the remaining tensors
@@ -655,7 +687,10 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
655687
kernel_args, vec_load_size, true, stream);
656688
}
657689

658-
if (all_has_columnwise_data) {
690+
if (columnwise_swizzle) {
691+
// NVFP4 shouldn't end up here because it only needs rowwise swizzle
692+
NVTE_CHECK(!all_nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle");
693+
659694
MultiSwizzleArgs kernel_args;
660695
kernel_args.num_tensors = 0;
661696
kernel_args.block_range[0] = 0;

transformer_engine/pytorch/csrc/util.cpp

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,14 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
9999

100100
if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) {
101101
NVTE_ERROR("Invalid scaling mode for swizzle.");
102-
} else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING) {
102+
} else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING &&
103+
tensors.front().scaling_mode() != NVTE_NVFP4_1D_SCALING) {
103104
return std::nullopt;
104105
}
105106

107+
const auto scaling_mode = tensors.front().scaling_mode();
108+
const auto nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING;
109+
106110
std::vector<transformer_engine::TensorWrapper> wrappers;
107111
std::vector<NVTETensor> input_tensors, output_tensors;
108112

@@ -130,39 +134,44 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
130134
// Allocate full buffer
131135
auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8));
132136

137+
const auto input_dtype =
138+
(nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3;
139+
const auto scale_inv_dtype =
140+
(nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0;
141+
133142
for (size_t i = 0; i < tensors.size(); ++i) {
134143
auto& tensor = tensors[i];
135144
void* scale_inv_dptr = scale_inv_dptrs[i];
136145
void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]);
137-
auto input_shape = nvte_shape_to_vector(tensor.shape());
138-
146+
// auto input_shape = nvte_shape_to_vector(tensor.shape());
147+
NVTEShape nvte_input_shape;
148+
if (rowwise) {
149+
nvte_input_shape = tensor.shape();
150+
} else {
151+
nvte_input_shape = tensor.get_columnwise_data().shape;
152+
}
153+
auto input_shape = nvte_shape_to_vector(nvte_input_shape);
139154
// Reconstruct input only to avoid swizzling both directions if not needed.
140155
// Use any 8 bit type, it's irrelevant.
141-
transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING);
142-
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
156+
transformer_engine::TensorWrapper input_cu(scaling_mode);
157+
transformer_engine::TensorWrapper output_cu(scaling_mode);
143158
if (rowwise) {
144-
input_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
145-
input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
146-
scale_inv_shapes[i]);
147-
output_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3,
148-
input_shape);
149-
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr,
150-
transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
159+
input_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape);
160+
input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
161+
output_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape);
162+
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
163+
scale_inv_shapes[i]);
151164
// Set the swizzled scaling factor to the original tensor.
152-
tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
153-
scale_inv_shapes[i]);
165+
tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
154166
} else {
155-
input_cu.set_columnwise_data(tensor.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
156-
input_shape);
157-
input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
158-
scale_inv_shapes[i]);
159-
output_cu.set_columnwise_data(tensor.columnwise_dptr(),
160-
transformer_engine::DType::kFloat8E4M3, input_shape);
161-
output_cu.set_columnwise_scale_inv(
162-
swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
167+
input_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape);
168+
input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
169+
output_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape);
170+
output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
171+
scale_inv_shapes[i]);
163172
// Set the swizzled scaling factor to the original tensor.
164-
tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr,
165-
transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
173+
tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
174+
scale_inv_shapes[i]);
166175
}
167176

168177
input_tensors.emplace_back(input_cu.data());

0 commit comments

Comments
 (0)