Commit 37747dd

pre-commit-ci[bot] authored and Oleg-Goncharov committed
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent f7225e9 commit 37747dd

transformer_engine/common/cast/nvfp4/core_nvfp4.cuh

Lines changed: 37 additions & 34 deletions
@@ -37,8 +37,8 @@ namespace quantization_and_transposition_SF {
 #if FP4_TYPE_SUPPORTED
 // Used in transpose variant
 // Compute per-block E4M3 encoding/decoding scaling factor
-__device__ __forceinline__ nvfp4_scale_t
-compute_decoding_scaling_factor(const float block_amax, const float S_enc) {
+__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax,
+                                                                          const float S_enc) {
   // constexpr float rcp_6f = 1.0f / 6.0f;
   // const float S_dec_b = block_amax * rcp_6f;
   // const nvfp4_scale_t S_dec_b_fp8 = static_cast<nvfp4_scale_t>(S_dec_b * S_enc);
@@ -51,24 +51,24 @@ compute_decoding_scaling_factor(const float block_amax, const float S_enc) {
   return static_cast<nvfp4_scale_t>(fminf(S_dec_b, TypeExtrema<float>::max));
 }
 #else
-NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
+NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
 #endif // FP4_TYPE_SUPPORTED
 } // namespace quantization_and_transposition_SF

 namespace quantization_SF {
 #if FP4_TYPE_SUPPORTED
 // Used in non-transpose variant
 // Compute per-block E4M3 encoding/decoding scaling factor
-__device__ __forceinline__ fp8e4m3
-compute_decoding_scaling_factor(const float block_amax, const float S_enc) {
+__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax,
+                                                                    const float S_enc) {
   constexpr float rcp_6f = 1.0f / 6.0f;
   // const float S_dec_b = block_amax * rcp_6f;
   // const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
   // return S_dec_b_fp8;
   return static_cast<fp8e4m3>(block_amax * rcp_6f * S_enc);
 }
 #else
-NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
+NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
 #endif // FP4_TYPE_SUPPORTED
 } // namespace quantization_SF
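
Both compute_decoding_scaling_factor variants above implement the same arithmetic: the per-block decode scale is block_amax divided by the FP4 E2M1 maximum of 6.0f, rescaled by the encode factor S_enc and cast to E4M3 (the transpose variant additionally clamps to the float maximum before the cast). A minimal float-only sketch of that computation, with the E4M3 cast left out:

// Sketch only: mirrors the arithmetic visible in this diff, not the library code
// (the real functions return nvfp4_scale_t / fp8e4m3 rather than float).
inline float decoding_scale_sketch(float block_amax, float S_enc) {
  constexpr float rcp_6f = 1.0f / 6.0f;       // 1 / fp4_max, where fp4_max (E2M1) is 6.0f
  const float S_dec_b = block_amax * rcp_6f;  // per-block decode scale
  return S_dec_b * S_enc;                     // rescale into the E4M3-encodable range
}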

@@ -82,8 +82,7 @@ using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::Phi
 using namespace ptx;

 // Compute the global encode scale factor for a given global amax
-__device__ __forceinline__ float
-compute_global_encode_scaling_factor_FP4(const float global_amax) {
+__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) {
   using namespace detail;
   constexpr float fp8_max = TypeExtrema<fp8e4m3>::max; // 448.0f;
   constexpr float fp4_max = TypeExtrema<fp4e2m1>::max; // 6.0f;
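
The body of compute_global_encode_scaling_factor_FP4 between this hunk and the next is unchanged and therefore not shown, so the exact expression is not visible here. Given the fp8_max and fp4_max constants above, a plausible reading is the usual NVFP4 recipe that maps global_amax onto the largest value representable as an E4M3 block scale times an E2M1 element; the following is an assumed sketch, not the code in this file:

// Assumed recipe only: S_enc ~ (fp8_max * fp4_max) / global_amax.
// The zero-amax fallback below is also an assumption.
inline float global_encode_scale_sketch(float global_amax) {
  constexpr float fp8_max = 448.0f;  // E4M3 max, per the comment above
  constexpr float fp4_max = 6.0f;    // E2M1 max, per the comment above
  if (global_amax <= 0.0f) return 1.0f;
  return fp8_max * fp4_max / global_amax;
}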
@@ -97,8 +96,7 @@ compute_global_encode_scaling_factor_FP4(const float global_amax) {
   return global_encode_scale;
 }

-__device__ __forceinline__ uint32_t
-get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) {
+__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) {
   if (rnd_idx == 4) {
     rnd_idx = 0;
     curanddx::uniform_bits dist;
@@ -110,9 +108,8 @@ get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) {
   return rbits;
 }

-__device__ __forceinline__ fp4e2m1x4
-mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(const uint64_t in_4x, const float2 scale,
-                                                const uint32_t rbits) {
+__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(
+    const uint64_t in_4x, const float2 scale, const uint32_t rbits) {
   uint16_t out_4x = 0;
   constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
   if constexpr (has_rs) {
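
get_rbits, whose changed lines span this hunk and the previous one, hands out one 32-bit word of Philox output per call and refills the cached uint4 every fourth call; the refill itself sits in unchanged lines and is not shown. A generic sketch of that caching pattern follows; draw_uniform_bits_128 is a hypothetical stand-in for the curanddx::uniform_bits draw, not the actual API call:

// Sketch of the refill-every-four-draws pattern, not the library implementation.
__device__ __forceinline__ uint32_t get_rbits_sketch(uint4 &cache, int &idx) {
  if (idx == 4) {                     // all four 32-bit words consumed
    idx = 0;
    cache = draw_uniform_bits_128();  // hypothetical 128-bit refill from the RNG
  }
  const uint32_t rbits = reinterpret_cast<const uint32_t *>(&cache)[idx];
  ++idx;
  return rbits;
}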
@@ -144,14 +141,16 @@ mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(const uint64_t in_4x, const floa
        : "=h"(out_4x)
        : "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
   } else {
-    NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. "
-                      "Try recompiling with sm_XXXa instead of sm_XXX.");
+    NVTE_DEVICE_ERROR(
+        "FP4 cvt PTX instructions are architecture-specific. "
+        "Try recompiling with sm_XXXa instead of sm_XXX.");
   }
   return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
 }

-__device__ __forceinline__ fp4e2m1x4
-mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, const float2 scale, const uint32_t rbits) {
+__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x,
+                                                                     const float2 scale,
+                                                                     const uint32_t rbits) {
   constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
   uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
   if constexpr (is_blackwell) {
@@ -188,25 +187,26 @@ mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, const float2 scale, const u
        : "=r"(out_4x)
        : "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)));
   } else {
-    NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. "
-                      "Try recompiling with sm_XXXa instead of sm_XXX.");
+    NVTE_DEVICE_ERROR(
+        "FP4 cvt PTX instructions are architecture-specific. "
+        "Try recompiling with sm_XXXa instead of sm_XXX.");
   }
   return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
 }

 template <bool USE_STOCHASTIC_ROUNDING>
-__device__ __forceinline__ fp4e2m1x4
-mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x, const float2 scale, const uint32_t rbits) {
+__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x,
+                                                             const float2 scale,
+                                                             const uint32_t rbits) {
   if constexpr (USE_STOCHASTIC_ROUNDING) {
     return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits);
   } else {
     return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits);
   }
 }

-__device__ __forceinline__ fp4e2m1x4
-mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(const float2 in01, const float2 in23,
-                                                const float2 scale, const uint32_t rbits) {
+__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(
+    const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) {
   uint16_t out_4x = 0;
   constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
   if constexpr (has_rs) {
@@ -233,15 +233,17 @@ mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(const float2 in01, const float2
          "l"(reinterpret_cast<const uint64_t &>(in23)),
          "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
   } else {
-    NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. "
-                      "Try recompiling with sm_XXXa instead of sm_XXX.");
+    NVTE_DEVICE_ERROR(
+        "FP4 cvt PTX instructions are architecture-specific. "
+        "Try recompiling with sm_XXXa instead of sm_XXX.");
   }
   return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
 }

-__device__ __forceinline__ fp4e2m1x4
-mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01, const float2 in23, const float2 scale,
-                               const uint32_t rbits) {
+__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
+                                                                     const float2 in23,
+                                                                     const float2 scale,
+                                                                     const uint32_t rbits) {
   constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
   uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
   if constexpr (is_blackwell) {
@@ -273,16 +275,17 @@ mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01, const float2 in23, const float
          "l"(reinterpret_cast<const uint64_t &>(in23)),
          "l"(reinterpret_cast<const uint64_t &>(scale)));
   } else {
-    NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. "
-                      "Try recompiling with sm_XXXa instead of sm_XXX.");
+    NVTE_DEVICE_ERROR(
+        "FP4 cvt PTX instructions are architecture-specific. "
+        "Try recompiling with sm_XXXa instead of sm_XXX.");
   }
   return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
 }

 template <bool USE_STOCHASTIC_ROUNDING>
-__device__ __forceinline__ fp4e2m1x4
-mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, const float2 scale,
-                       const uint32_t rbits) {
+__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23,
+                                                             const float2 scale,
+                                                             const uint32_t rbits) {
   if constexpr (USE_STOCHASTIC_ROUNDING) {
     return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, scale, rbits);
   } else {
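
The mul_cvt_*_4x dispatchers rewrapped above choose between the stochastic-rounding and round-to-nearest PTX paths at compile time via the USE_STOCHASTIC_ROUNDING template flag; rbits is forwarded on both paths but, judging from the asm operand lists, only the stochastic-rounding variants consume it. A hypothetical call site, sketched with illustrative names that are not taken from the kernel:

// Sketch of a caller quantizing four bf16 values packed in a uint64_t.
template <bool USE_STOCHASTIC_ROUNDING>
__device__ void quantize_4x_sketch(uint64_t in_4x, float2 scale, RNG &rng,
                                   uint4 &rand_cache, int &rand_idx, fp4e2m1x4 &out) {
  const uint32_t rbits = get_rbits(rng, rand_cache, rand_idx);
  out = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(in_4x, scale, rbits);
}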
