@@ -37,8 +37,8 @@ namespace quantization_and_transposition_SF {
3737#if FP4_TYPE_SUPPORTED
3838// Used in transpose variant
3939// Compute per-block E4M3 encoding/decoding scaling factor
40- __device__ __forceinline__ nvfp4_scale_t
41- compute_decoding_scaling_factor ( const float block_amax, const float S_enc) {
40+ __device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor ( const float block_amax,
41+ const float S_enc) {
4242 // constexpr float rcp_6f = 1.0f / 6.0f;
4343 // const float S_dec_b = block_amax * rcp_6f;
4444 // const nvfp4_scale_t S_dec_b_fp8 = static_cast<nvfp4_scale_t>(S_dec_b * S_enc);
@@ -51,24 +51,24 @@ compute_decoding_scaling_factor(const float block_amax, const float S_enc) {
5151 return static_cast <nvfp4_scale_t >(fminf (S_dec_b, TypeExtrema<float >::max));
5252}
5353#else
54- NVTE_ERROR (" FP4 support requires CUDA 12.8+, but compile-time CUDA version is " , CUDA_VERSION);
54+ NVTE_ERROR (" FP4 support requires CUDA 12.8+, but compile-time CUDA version is " , CUDA_VERSION);
5555#endif // FP4_TYPE_SUPPORTED
5656} // namespace quantization_and_transposition_SF
5757
5858namespace quantization_SF {
5959#if FP4_TYPE_SUPPORTED
// Used in non-transpose variant.
// Computes the per-block FP8 (E4M3) decoding scaling factor for NVFP4 quantization.
//
// Parameters:
//   block_amax - absolute maximum value observed within the quantization block.
//   S_enc      - global encode scale factor (see compute_global_encode_scaling_factor_FP4).
//
// Returns: block_amax / 6.0 * S_enc, cast to fp8e4m3. The 1/6 factor maps the block
// amax onto the FP4 E2M1 representable maximum (6.0) — TODO confirm against the
// TypeExtrema<fp4e2m1>::max definition used elsewhere in this file.
__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax,
                                                                   const float S_enc) {
  // Multiply by the reciprocal instead of dividing: cheaper on device and
  // constexpr-folded at compile time.
  constexpr float rcp_6f = 1.0f / 6.0f;
  // Equivalent two-step form kept for reference:
  // const float S_dec_b = block_amax * rcp_6f;
  // const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
  // return S_dec_b_fp8;
  return static_cast<fp8e4m3>(block_amax * rcp_6f * S_enc);
}
7070#else
71- NVTE_ERROR (" FP4 support requires CUDA 12.8+, but compile-time CUDA version is " , CUDA_VERSION);
71+ NVTE_ERROR (" FP4 support requires CUDA 12.8+, but compile-time CUDA version is " , CUDA_VERSION);
7272#endif // FP4_TYPE_SUPPORTED
7373} // namespace quantization_SF
7474
@@ -82,8 +82,7 @@ using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::Phi
8282using namespace ptx ;
8383
8484// Compute the global encode scale factor for a given global amax
85- __device__ __forceinline__ float
86- compute_global_encode_scaling_factor_FP4 (const float global_amax) {
85+ __device__ __forceinline__ float compute_global_encode_scaling_factor_FP4 (const float global_amax) {
8786 using namespace detail ;
8887 constexpr float fp8_max = TypeExtrema<fp8e4m3>::max; // 448.0f;
8988 constexpr float fp4_max = TypeExtrema<fp4e2m1>::max; // 6.0f;
@@ -97,8 +96,7 @@ compute_global_encode_scaling_factor_FP4(const float global_amax) {
9796 return global_encode_scale;
9897}
9998
100- __device__ __forceinline__ uint32_t
101- get_rbits (RNG &rng, uint4 &random_uint4, int &rnd_idx) {
99+ __device__ __forceinline__ uint32_t get_rbits (RNG &rng, uint4 &random_uint4, int &rnd_idx) {
102100 if (rnd_idx == 4 ) {
103101 rnd_idx = 0 ;
104102 curanddx::uniform_bits dist;
@@ -110,9 +108,8 @@ get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) {
110108 return rbits;
111109}
112110
113- __device__ __forceinline__ fp4e2m1x4
114- mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding (const uint64_t in_4x, const float2 scale,
115- const uint32_t rbits) {
111+ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding (
112+ const uint64_t in_4x, const float2 scale, const uint32_t rbits) {
116113 uint16_t out_4x = 0 ;
117114 constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
118115 if constexpr (has_rs) {
@@ -144,14 +141,16 @@ mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(const uint64_t in_4x, const floa
144141 : " =h" (out_4x)
145142 : " l" (in_4x), " l" (reinterpret_cast <const uint64_t &>(scale)), " r" (rbits));
146143 } else {
147- NVTE_DEVICE_ERROR (" FP4 cvt PTX instructions are architecture-specific. "
148- " Try recompiling with sm_XXXa instead of sm_XXX." );
144+ NVTE_DEVICE_ERROR (
145+ " FP4 cvt PTX instructions are architecture-specific. "
146+ " Try recompiling with sm_XXXa instead of sm_XXX." );
149147 }
150148 return *reinterpret_cast <fp4e2m1x4 *>(&out_4x);
151149}
152150
153- __device__ __forceinline__ fp4e2m1x4
154- mul_cvt_bf16_to_fp4_4x_with_rn (const uint64_t in_4x, const float2 scale, const uint32_t rbits) {
151+ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn (const uint64_t in_4x,
152+ const float2 scale,
153+ const uint32_t rbits) {
155154 constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
156155 uint32_t out_4x = 0 ; // Only need 16 bit. Using 32 bit container for packing.
157156 if constexpr (is_blackwell) {
@@ -188,25 +187,26 @@ mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, const float2 scale, const u
188187 : " =r" (out_4x)
189188 : " l" (in_4x), " l" (reinterpret_cast <const uint64_t &>(scale)));
190189 } else {
191- NVTE_DEVICE_ERROR (" FP4 cvt PTX instructions are architecture-specific. "
192- " Try recompiling with sm_XXXa instead of sm_XXX." );
190+ NVTE_DEVICE_ERROR (
191+ " FP4 cvt PTX instructions are architecture-specific. "
192+ " Try recompiling with sm_XXXa instead of sm_XXX." );
193193 }
194194 return reinterpret_cast <fp4e2m1x4 *>(&out_4x)[0 ];
195195}
196196
// Dispatch wrapper: scales and converts 4 packed bf16 values (in one uint64_t)
// to 4 packed FP4 E2M1 values, choosing the rounding mode at compile time.
//
// Template parameter:
//   USE_STOCHASTIC_ROUNDING - selects the stochastic-rounding conversion path
//                             when true, round-to-nearest otherwise.
// Parameters:
//   in_4x - four bf16 values packed into a 64-bit word.
//   scale - per-block scale pair applied during conversion.
//   rbits - random bits consumed only by the stochastic-rounding path; ignored
//           (but still accepted, keeping both call signatures identical) by the
//           round-to-nearest path.
//
// Returns: four FP4 E2M1 values packed as fp4e2m1x4.
template <bool USE_STOCHASTIC_ROUNDING>
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x,
                                                            const float2 scale,
                                                            const uint32_t rbits) {
  // if constexpr ensures only the selected conversion path is instantiated,
  // so the RN path never pays for the unused rbits argument.
  if constexpr (USE_STOCHASTIC_ROUNDING) {
    return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits);
  } else {
    return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits);
  }
}
206207
207- __device__ __forceinline__ fp4e2m1x4
208- mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding (const float2 in01, const float2 in23,
209- const float2 scale, const uint32_t rbits) {
208+ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding (
209+ const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) {
210210 uint16_t out_4x = 0 ;
211211 constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
212212 if constexpr (has_rs) {
@@ -233,15 +233,17 @@ mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(const float2 in01, const float2
233233 " l" (reinterpret_cast <const uint64_t &>(in23)),
234234 " l" (reinterpret_cast <const uint64_t &>(scale)), " r" (rbits));
235235 } else {
236- NVTE_DEVICE_ERROR (" FP4 cvt PTX instructions are architecture-specific. "
237- " Try recompiling with sm_XXXa instead of sm_XXX." );
236+ NVTE_DEVICE_ERROR (
237+ " FP4 cvt PTX instructions are architecture-specific. "
238+ " Try recompiling with sm_XXXa instead of sm_XXX." );
238239 }
239240 return *reinterpret_cast <fp4e2m1x4 *>(&out_4x);
240241}
241242
242- __device__ __forceinline__ fp4e2m1x4
243- mul_cvt_fp32_to_fp4_4x_with_rn (const float2 in01, const float2 in23, const float2 scale,
244- const uint32_t rbits) {
243+ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn (const float2 in01,
244+ const float2 in23,
245+ const float2 scale,
246+ const uint32_t rbits) {
245247 constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
246248 uint32_t out_4x = 0 ; // Only need 16 bit. Using 32 bit container for packing.
247249 if constexpr (is_blackwell) {
@@ -273,16 +275,17 @@ mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01, const float2 in23, const float
273275 " l" (reinterpret_cast <const uint64_t &>(in23)),
274276 " l" (reinterpret_cast <const uint64_t &>(scale)));
275277 } else {
276- NVTE_DEVICE_ERROR (" FP4 cvt PTX instructions are architecture-specific. "
277- " Try recompiling with sm_XXXa instead of sm_XXX." );
278+ NVTE_DEVICE_ERROR (
279+ " FP4 cvt PTX instructions are architecture-specific. "
280+ " Try recompiling with sm_XXXa instead of sm_XXX." );
278281 }
279282 return reinterpret_cast <fp4e2m1x4 *>(&out_4x)[0 ];
280283}
281284
282285template <bool USE_STOCHASTIC_ROUNDING>
283- __device__ __forceinline__ fp4e2m1x4
284- mul_cvt_fp32_to_fp4_4x ( const float2 in01, const float2 in23, const float2 scale,
285- const uint32_t rbits) {
286+ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x ( const float2 in01, const float2 in23,
287+ const float2 scale,
288+ const uint32_t rbits) {
286289 if constexpr (USE_STOCHASTIC_ROUNDING) {
287290 return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding (in01, in23, scale, rbits);
288291 } else {
0 commit comments