Changes from all commits
Commits
39 commits
c15d93b
Current scaling: two-stage amax kernel
matthiasdiener Nov 12, 2025
51fab36
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 13, 2025
ae35e4c
bugfix graph capture
matthiasdiener Nov 13, 2025
77a68a7
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 17, 2025
c0d8e73
outline workspace allocation
matthiasdiener Nov 17, 2025
6c3507d
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 18, 2025
3c9de07
Proper allocation of workspace
matthiasdiener Nov 18, 2025
91249cc
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 19, 2025
be0e0c8
add a test to compare the accuracy of both amax implementations
matthiasdiener Nov 19, 2025
bce34da
add possibility to force using previous (atomic) kernel
matthiasdiener Nov 19, 2025
8c388cc
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 19, 2025
73c8d4e
2-stage Triton amax
matthiasdiener Nov 20, 2025
6388604
add copyrights
matthiasdiener Nov 20, 2025
9e6586f
don't add extra template to kernel
matthiasdiener Nov 20, 2025
18292bf
make amax_kernel_threads usable in pytorch
matthiasdiener Nov 21, 2025
a389455
update remaining calls to nvte_compute_amax
matthiasdiener Nov 21, 2025
d87ab8a
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 24, 2025
7d9ee16
Merge branch 'speedup-amax-kernel' into speedup-amax-triton
matthiasdiener Nov 24, 2025
fd5dead
additional copyrights
matthiasdiener Nov 24, 2025
16d3bf9
avoid workspace allocations if NVTE_USE_ATOMIC_AMAX is set
matthiasdiener Nov 24, 2025
50b34aa
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 25, 2025
ef532b1
remove use_block_amax parameter, more cleanups
matthiasdiener Nov 25, 2025
f933ef3
Factor workspace allocation into function
matthiasdiener Nov 25, 2025
7d4054e
expand test slightly
matthiasdiener Nov 25, 2025
63cff98
Revert "expand test slightly"
Nov 25, 2025
c7d44a7
guard by HIP macro, address review comments
matthiasdiener Nov 26, 2025
f92b926
bugfix workspace.data.dptr
matthiasdiener Nov 26, 2025
eba552e
various cleanups
matthiasdiener Nov 26, 2025
0d6a177
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 26, 2025
19901a0
Merge branch 'speedup-amax-kernel' into speedup-amax-triton
matthiasdiener Nov 26, 2025
8eda427
simplify types in allocate_amax_workspace
matthiasdiener Nov 26, 2025
be6496b
Merge branch 'speedup-amax-kernel' into speedup-amax-triton
matthiasdiener Nov 26, 2025
ed1a54b
Fixes
matthiasdiener Nov 26, 2025
c8d5bb4
add support for NVTE_USE_ATOMIC_AMAX
matthiasdiener Nov 26, 2025
5a9086a
Fuse amax_reduce + compute_scale kernels
matthiasdiener Nov 26, 2025
6990928
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Dec 1, 2025
9ee618f
fix indentation
matthiasdiener Dec 1, 2025
853bb77
Merge branch 'speedup-amax-kernel' into speedup-amax-triton
matthiasdiener Dec 1, 2025
cf402b1
undo non-triton changes
matthiasdiener Dec 1, 2025
162 changes: 143 additions & 19 deletions transformer_engine/pytorch/triton_kernels/cast_transpose.py
@@ -12,6 +12,8 @@
te_dtype_to_torch_dtype,
get_fp8_max,
)
import os

##########################################
#### cast_transpose
##########################################
@@ -189,6 +191,101 @@ def _cast_transpose_triton_current_scaling(A, C, T, stride_am, stride_an, stride
tl.store(T, fp8_a, mask=mask)


AMAX_STAGE1_CONFIGS = [
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 1}, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 8}, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=8),
]

@triton.autotune(
configs=AMAX_STAGE1_CONFIGS,
key=['M', 'N'],
)
@triton.jit
def _amax_reduce_triton_stage1(
A,
stride_am, stride_an,
M, N,
block_amax, # float32[workspace_size]
num_blocks, # int32[1]
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
GROUP_M: tl.constexpr,
):
pid = tl.program_id(0)

grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N

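# Remap the linear program id to a grouped (pid_m, pid_n) ordering so that
# consecutive programs cover adjacent row blocks of the same column block
# (better L2 reuse), as in the standard Triton grouped-launch idiom.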
width = GROUP_M * grid_n
group_id = pid // width
group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // group_size

rm = pid_m.to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n.to(tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)

A_ptrs = A + rm[:, None] * stride_am + rn[None, :] * stride_an
mask = (rm < M)[:, None] & (rn < N)[None, :]

a = tl.load(A_ptrs, mask=mask, other=0).to(tl.float32)
tile_amax = tl.max(tl.abs(a))

# Store per-program amax in workspace
tl.store(block_amax + pid, tile_amax)

if pid == 0:
tl.store(num_blocks, tl.num_programs(0))

@triton.jit
def _amax_reduce_and_compute_scale_triton(
block_amax, # float32[num_blocks]
num_blocks, # int32[1]
amax_ptr, # float32[1]
scale_ptr, # float32[1]
inv_ptr, # float32[1]
max_fp8, # scalar (float32)
epsilon, # scalar (float32)
value_for_inf, # scalar (float32)
FORCE_POW_2_SCALES: tl.constexpr,
BLOCKSIZE: tl.constexpr,
):
# Reduce per-block amaxes
a = tl.full((), -float('inf'), tl.float32)
offset = 0
num_blocks = tl.load(num_blocks)

while offset < num_blocks:
idx = offset + tl.arange(0, BLOCKSIZE)
mask = idx < num_blocks
vals = tl.load(block_amax + idx, mask=mask, other=-float('inf'))
a = tl.maximum(a, tl.max(vals))
offset += BLOCKSIZE

tl.store(amax_ptr, a)

# Compute scale + inv_scale from amax

# amax < epsilon -> epsilon (NaNs pass through)
a = tl.where(a < epsilon, epsilon, a)

# bad amax (NaN, inf, 0.0) -> scale = 1.0
bad = (a != a) | (tl.abs(a) == float('inf')) | (a == 0.0)

if bad:
s = tl.full((), 1.0, tl.float32)
else:
s = max_fp8 / a
# scale overflowed to inf -> scale = value_for_inf
s = tl.where(tl.abs(s) == float('inf'), value_for_inf, s)
if FORCE_POW_2_SCALES:
s = tl.math.exp2(tl.floor(tl.log2(s)))

tl.store(scale_ptr, s)
tl.store(inv_ptr, 1.0 / s)


FP32_EXPONENT_BIAS = tl.constexpr(127)
FP32_MANTISSA_BITS = tl.constexpr(23)
@triton.jit
@@ -376,28 +473,55 @@ def te_cast_transpose_noop_triton(input, noop_flag, input_scale, cast_out, trans
grid = lambda META: (triton.cdiv(num_rows, META['BLOCK_M']) * triton.cdiv(row_length, META['BLOCK_N']),)

if current_scaling:
# Current scaling:
# 1) global amax reduction
# 2) compute current scale
# 3) cast+transpose with that current scale (otherwise same as delayed)

# global amax
amax_out.fill_(-float("inf"))
_amax_reduce_triton[grid](
input_2d_view,
input_stride_M, input_stride_N,
num_rows, row_length,
amax_out,
)

# Compute scale
fp8_max = get_fp8_max(otype)

_compute_scale_from_amax_triton[(1,)](
amax_out, input_scale, scale_inv_out,
fp8_max, eps, torch.finfo(torch.float32).max,
FORCE_POW_2_SCALES=force_pow_2_scales,
)
nvte_use_atomic_amax = bool(int(os.environ.get('NVTE_USE_ATOMIC_AMAX', '0')))

if nvte_use_atomic_amax:
# Compute global amax
_amax_reduce_triton[grid](
input_2d_view,
input_stride_M, input_stride_N,
num_rows, row_length,
amax_out,
)

# Compute scale
_compute_scale_from_amax_triton[(1,)](
amax_out, input_scale, scale_inv_out,
fp8_max, eps, torch.finfo(torch.float32).max,
FORCE_POW_2_SCALES=force_pow_2_scales,
)
else:
# 2-stage amax
max_num_amax_stage1_programs = max(
triton.cdiv(num_rows, cfg.kwargs['BLOCK_M']) *
triton.cdiv(row_length, cfg.kwargs['BLOCK_N'])
for cfg in AMAX_STAGE1_CONFIGS
)
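# Size the workspace for the worst case over all autotune configs, since the
# chosen config (and hence the launch grid) is only known once the autotuner runs.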

block_amax = torch.empty(max_num_amax_stage1_programs, device=input.device,
dtype=torch.float32)

num_blocks = torch.empty(1, device=input.device, dtype=torch.int32)

# Stage 1: per-program tile amax
_amax_reduce_triton_stage1[grid](
input_2d_view,
input_stride_M, input_stride_N,
num_rows, row_length,
block_amax, num_blocks,
)

# Stage 2: reduce per-program maxima into amax_out and compute scale
_amax_reduce_and_compute_scale_triton[(1,)](
block_amax, num_blocks,
amax_out, input_scale, scale_inv_out,
fp8_max, eps, torch.finfo(torch.float32).max,
FORCE_POW_2_SCALES=force_pow_2_scales,
BLOCKSIZE=512,
)

_cast_transpose_triton_current_scaling[grid](input_2d_view, triton.reinterpret(cast_out_2d_view, tl_dtype), triton.reinterpret(trans_out_2d_view, tl_dtype), input_stride_M, input_stride_N, trans_out_stride_M, trans_out_stride_N, num_rows, row_length, input_scale, get_fp8_max(otype))
else:
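For reference, the scale derivation done by the new stage-2 kernel follows the usual current-scaling recipe. Below is a minimal host-side sketch of that math; it is illustrative only, and the helper name reference_amax_and_scale and its use of torch/math are not part of the diff.

import math
import torch

def reference_amax_and_scale(a, max_fp8, epsilon, value_for_inf, force_pow_2_scales):
    # Stage 1 (collapsed to a single reduction here): global amax of |a|.
    amax = a.abs().amax().item()
    # Stage 2: clamp small amax to epsilon, then derive scale and inverse scale.
    x = epsilon if amax < epsilon else amax  # NaN < epsilon is False, so NaN passes through
    if math.isnan(x) or math.isinf(x) or x == 0.0:
        scale = 1.0
    else:
        scale = max_fp8 / x
        if math.isinf(scale):
            scale = value_for_inf
        if force_pow_2_scales:
            scale = 2.0 ** math.floor(math.log2(scale))
    return amax, scale, 1.0 / scale

Compared with the single-kernel path selected by NVTE_USE_ATOMIC_AMAX=1, which reduces directly into amax_out, the two-stage variant writes per-program maxima into the block_amax workspace and lets one follow-up program perform the final reduction together with the scale computation.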