@@ -332,11 +332,9 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
332
332
} // namespace
333
333
334
334
void swizzle_scaling_factors (const Tensor* input, Tensor* output, cudaStream_t stream) {
335
- NVTE_CHECK (input->scaling_mode == NVTE_MXFP8_1D_SCALING ||
336
- input->scaling_mode == NVTE_BLOCK_SCALING_1D ||
337
- input->scaling_mode == NVTE_BLOCK_SCALING_2D ||
338
- input->scaling_mode == NVTE_NVFP4_1D_SCALING,
339
- " Input tensor has invalid scaling mode (" , to_string (input->scaling_mode ), " )." );
335
+ NVTE_CHECK (
336
+ input->scaling_mode == NVTE_MXFP8_1D_SCALING || input->scaling_mode == NVTE_NVFP4_1D_SCALING,
337
+ " Input tensor has invalid scaling mode (" , to_string (input->scaling_mode ), " )." );
340
338
NVTE_CHECK (is_fp8_dtype (input->dtype ()) || is_fp4_dtype (input->dtype ()),
341
339
" Input tensor has invalid dtype (" , to_string (input->dtype ()), " )." );
342
340
@@ -583,16 +581,19 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
583
581
NVTE_CHECK_CUDA (cudaGetLastError ());
584
582
}
585
583
586
- // TODO(nvfp4): Add NVFP4 support.
587
584
void multi_tensor_swizzle_scaling_factors (const std::vector<Tensor*>& input,
588
585
std::vector<Tensor*>& output, cudaStream_t stream) {
589
586
auto num_tensors = input.size ();
590
587
bool all_has_data = true ;
591
588
bool all_has_columnwise_data = true ;
589
+ bool all_nvfp4 = true ;
592
590
for (size_t i = 0 ; i < num_tensors; i++) {
593
- if (!is_fp8_dtype (input[i]->dtype ()) || !is_mxfp_scaling (input[i]->scaling_mode )) {
594
- NVTE_ERROR (" Not implemented caling mode " + to_string (input[i]->scaling_mode ) + " ." );
595
- }
591
+ auto scaling_mode = input[i]->scaling_mode ;
592
+ auto is_fp8 = is_fp8_dtype (input[i]->dtype ());
593
+ auto is_fp4 = is_fp4_dtype (input[i]->dtype ());
594
+ NVTE_CHECK (
595
+ (is_fp8 && is_mxfp8_scaling (scaling_mode)) || (is_fp4 && is_nvfp4_scaling (scaling_mode)),
596
+ " Not implemented scaling mode " + to_string (scaling_mode) + " ." );
596
597
// We don't allow empty tensors. They should be filtered out before calling this function.
597
598
if (input[i]->data .numel () == 0 ) {
598
599
NVTE_ERROR (" Tensor input[" + std::to_string (i) + " ] is empty." );
@@ -601,13 +602,17 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
601
602
CheckInputTensor (*output[i], " scaling_factor_output[" + std::to_string (i) + " ]" );
602
603
all_has_data &= input[i]->has_data ();
603
604
all_has_columnwise_data &= input[i]->has_columnwise_data ();
605
+ all_nvfp4 &= is_nvfp4_scaling (scaling_mode);
604
606
}
605
607
NVTE_CHECK (all_has_data || all_has_columnwise_data,
606
608
" All tensors should have data or columnwise data." );
607
609
610
+ const bool rowwise_swizzle = all_has_data || all_nvfp4;
611
+ const bool columnwise_swizzle = all_has_columnwise_data && !all_nvfp4;
612
+
608
613
constexpr int SF_TILE_DIM_M = 128 ;
609
614
constexpr int SF_TILE_DIM_K = 4 ;
610
- if (all_has_data ) {
615
+ if (rowwise_swizzle ) {
611
616
MultiSwizzleArgs kernel_args;
612
617
kernel_args.num_tensors = 0 ;
613
618
kernel_args.block_range [0 ] = 0 ;
@@ -623,29 +628,56 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
623
628
kernel_args.num_tensors = 0 ;
624
629
vec_load_size = 4 ;
625
630
}
626
- const int m = input[i]->scale_inv .shape [0 ];
627
- const int k = input[i]->scale_inv .shape [1 ];
631
+
632
+ int m, k;
633
+
634
+ if (all_has_data) {
635
+ m = input[i]->scale_inv .shape [0 ];
636
+ k = input[i]->scale_inv .shape [1 ];
637
+ } else {
638
+ NVTE_CHECK (all_nvfp4, " When doing rowwise swizzle with rowwise data, it has to be NVFP4" );
639
+ m = input[i]->columnwise_scale_inv .shape [0 ];
640
+ k = input[i]->columnwise_scale_inv .shape [1 ];
641
+ }
628
642
629
643
NVTE_CHECK (m % SF_TILE_DIM_M == 0 , " Input should be padded in M/N dimension!" );
630
644
NVTE_CHECK (k % SF_TILE_DIM_K == 0 , " Input should be padded in K dimension!" );
631
645
NVTE_CHECK (k > 0 , " Input scale inverse should be 2D!" );
632
- NVTE_CHECK (
633
- m * k == std::accumulate (output[i]->scale_inv .shape .begin (),
634
- output[i]->scale_inv .shape .end (), 1 , std::multiplies<int >()),
635
- " Input.scale_inv size is not equal to Output.scale_inv size!" );
646
+
647
+ if (output[i]->has_data ()) {
648
+ NVTE_CHECK (
649
+ m * k == std::accumulate (output[i]->scale_inv .shape .begin (),
650
+ output[i]->scale_inv .shape .end (), 1 , std::multiplies<int >()),
651
+ " Input.scale_inv size is not equal to Output.scale_inv size!" );
652
+ }
653
+ if (output[i]->has_columnwise_data ()) {
654
+ NVTE_CHECK (m * k == std::accumulate (output[i]->columnwise_scale_inv .shape .begin (),
655
+ output[i]->columnwise_scale_inv .shape .end (), 1 ,
656
+ std::multiplies<int >()),
657
+ " Input.columnwise_scale_inv size is not equal to "
658
+ " Output.columnwise_scale_inv size!" );
659
+ }
636
660
637
661
int num_tiles_k = k / SF_TILE_DIM_K;
638
662
int vec_load_size_i = (num_tiles_k - 1 ) % 4 + 1 ;
639
663
// We use the minimum vec_load_size across all tensors.
640
664
vec_load_size = std::min (vec_load_size, vec_load_size_i);
641
665
642
666
const int pos = kernel_args.num_tensors ;
643
- kernel_args.input_list [pos] = const_cast <void *>(input[i]->scale_inv .dptr );
644
- kernel_args.output_list [pos] = output[i]->scale_inv .dptr ;
645
667
kernel_args.m_list [pos] = m;
646
668
kernel_args.k_list [pos] = k;
647
- kernel_args.original_m_list [pos] = input[i]->flat_first_dim ();
648
- kernel_args.original_k_list [pos] = input[i]->flat_last_dim () / MXFP8_BLOCK_SIZE;
669
+ if (!all_nvfp4 || all_has_data) {
670
+ int block_scale_size = all_nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE;
671
+ kernel_args.input_list [pos] = const_cast <void *>(input[i]->scale_inv .dptr );
672
+ kernel_args.output_list [pos] = output[i]->scale_inv .dptr ;
673
+ kernel_args.original_m_list [pos] = input[i]->flat_first_dim ();
674
+ kernel_args.original_k_list [pos] = input[i]->flat_last_dim () / block_scale_size;
675
+ } else {
676
+ kernel_args.input_list [pos] = const_cast <void *>(input[i]->columnwise_scale_inv .dptr );
677
+ kernel_args.output_list [pos] = output[i]->columnwise_scale_inv .dptr ;
678
+ kernel_args.original_m_list [pos] = input[i]->flat_last_dim ();
679
+ kernel_args.original_k_list [pos] = input[i]->flat_first_dim () / NVFP4_BLOCK_SIZE;
680
+ }
649
681
kernel_args.num_tensors ++;
650
682
}
651
683
// Launch the remaining tensors
@@ -655,7 +687,10 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
655
687
kernel_args, vec_load_size, true , stream);
656
688
}
657
689
658
- if (all_has_columnwise_data) {
690
+ if (columnwise_swizzle) {
691
+ // NVFP4 shouldn't end up here because it only needs rowwise swizzle
692
+ NVTE_CHECK (!all_nvfp4, " NVFP4 shouldn't end up here because it only needs rowwise swizzle" );
693
+
659
694
MultiSwizzleArgs kernel_args;
660
695
kernel_args.num_tensors = 0 ;
661
696
kernel_args.block_range [0 ] = 0 ;
0 commit comments