Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions transformer_engine/common/util/ptx.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -867,19 +867,19 @@ __device__ __forceinline__ void fma_f32_bf16(float &out, uint16_t const &a, uint
}

__device__ __forceinline__ void reduce_sync_max_abs_f32(float &out, float const &in) {
#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
(__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
asm volatile("redux.sync.max.abs.f32 %0, %1, 0xFFFFFFFF;" : "=f"(out) : "f"(in));
#else
asm volatile(
"{\n\t"
".reg.b32 val;\n"
"abs.f32 val, %1;\n"
"redux.sync.max.u32 %0, val, 0xFFFFFFFF;\n"
"}\n\t"
: "=r"(reinterpret_cast<uint32_t &>(out))
: "f"(in));
#endif
constexpr bool is_sm_100f = NVTE_CUDA_ARCH_MATCHES(ptx::FamilySpecific<100>);
if constexpr (is_sm_100f) {
asm volatile("redux.sync.max.abs.f32 %0, %1, 0xFFFFFFFF;" : "=f"(out) : "f"(in));
} else {
asm volatile(
"{\n\t"
".reg.b32 val;\n"
"abs.f32 val, %1;\n"
"redux.sync.max.u32 %0, val, 0xFFFFFFFF;\n"
"}\n\t"
: "=r"(reinterpret_cast<uint32_t &>(out))
: "f"(in));
}
}

__device__ __forceinline__ bf16 get_amax(bf16 a, bf16 b) {
Expand Down