Flash Attention#

TileLang FlashAttention Tutorial#

Part 1: Forward Pass#

1. Overview#
This tutorial explains the forward-pass implementation of FlashAttention in TileLang, a parallel programming language for high-performance GPU computing. The kernel achieves a >1.3x speedup over FlashAttention-2, reaching 630 TFLOPS for 4K sequence lengths on NVIDIA H100 GPUs. Key optimizations include:

- Tiled computation with configurable block sizes
- Pipelined memory operations to hide latency
- Numerically stable softmax using base-2 exponentials
- Split-KV attention for efficient memory usage
- Autotuning support for optimal performance
Note

This is Part 1 of the complete FlashAttention tutorial, focusing on the forward pass implementation. Part 2 will cover the backward pass and gradient computation.
2. Forward Pass Kernel Code#
@autotune(configs=get_configs())
@jit(out_idx=[-2, -1])
def mha_fwd(
    batch,
    seq_len,
    seq_len_kv,
    heads,
    dim,
    sm_scale,
    is_causal,
    block_M=128,
    block_N=128,
    num_stages=1,
    threads=256,
):
    sm_scale = (1.0 / dim)**0.5 * 1.44269504

    q_shape = [batch, heads, seq_len, dim]
    k_shape = [batch, heads, seq_len_kv, dim]
    v_shape = [batch, heads, seq_len_kv, dim]
    o_shape = [batch, heads, seq_len, dim]
    lse_shape = [batch, heads, seq_len]
    dtype = "bfloat16"
    accum_dtype = "float"
    q_start_id = seq_len_kv - seq_len

    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),
        K: T.Tensor(k_shape, dtype),
        V: T.Tensor(v_shape, dtype),
        Output: T.Tensor(o_shape, dtype),
        Lse: T.Tensor(lse_shape, accum_dtype),
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
            # Memory allocations
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([block_M, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            # Load Q tile
            T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :dim], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -2**30)

            # Determine loop range with causal masking
            loop_range = (
                T.min(T.ceildiv(seq_len_kv, block_N),
                      T.ceildiv(seq_len_kv - seq_len + (bx + 1) * block_M, block_N))
                if is_causal else T.ceildiv(seq_len_kv, block_N))

            # Main attention loop
            for k in T.Pipelined(loop_range, num_stages=num_stages):
                # Load K tile
                T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :dim], K_shared)

                # Initialize scores (apply the causal mask only when enabled)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        q_idx = bx * block_M + i + q_start_id
                        k_idx = k * block_N + j
                        acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
                else:
                    T.clear(acc_s)

                # Compute attention scores
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

                # Softmax preparation
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * sm_scale - scores_max[i] * sm_scale)

                # Compute exp scores
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * sm_scale - scores_max[i] * sm_scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)

                # Update running statistics
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] *= scores_scale[i]

                # Accumulate output
                T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            # Final normalization
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

            # Save logsumexp
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * sm_scale
            T.copy(logsum, Lse[bz, by, bx * block_M:(bx + 1) * block_M])

    return main

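For correctness checking, the same forward pass can be written directly in PyTorch. The sketch below is not part of the tutorial's code; it assumes the [batch, heads, seq, dim] layout used above and reproduces the kernel's causal offset and log₂-domain LSE.

import math
import torch


def mha_fwd_reference(Q, K, V, is_causal):
    """Plain-PyTorch reference for the kernel's Output and Lse (Lse is in the log2 domain)."""
    batch, heads, seq_len, dim = Q.shape
    seq_len_kv = K.shape[2]
    scale = 1.0 / math.sqrt(dim)

    scores = torch.einsum("bhqd,bhkd->bhqk", Q.float(), K.float()) * scale
    if is_causal:
        # Same offset as q_start_id in the kernel: queries occupy the last
        # seq_len positions of the seq_len_kv-long timeline.
        q_idx = torch.arange(seq_len, device=Q.device) + (seq_len_kv - seq_len)
        k_idx = torch.arange(seq_len_kv, device=Q.device)
        scores = scores.masked_fill(q_idx[:, None] < k_idx[None, :], float("-inf"))

    out = torch.softmax(scores, dim=-1) @ V.float()
    lse = torch.logsumexp(scores, dim=-1) * 1.44269504  # natural log -> log2, as stored by the kernel
    return out.to(Q.dtype), lse
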
3. Forward Pass Components Explained#

Kernel Configuration#
@autotune(configs=get_configs())
@jit(out_idx=[-2, -1])

- Autotuning: Dynamically selects the best-performing tile sizes (block_M, block_N), pipeline depth, and thread count from the candidate configurations returned by get_configs()
- Output Allocation: out_idx=[-2, -1] marks the last two kernel arguments (Output and Lse) as outputs, so the JIT-compiled kernel allocates and returns them rather than requiring the caller to pass them in
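The helper get_configs() is not shown in this part of the tutorial. Below is a minimal, hypothetical sketch of what it could return, assuming the autotuner accepts a list of keyword-argument dictionaries matching mha_fwd's tunable parameters; the real search space and config format may differ.

import itertools


def get_configs():
    """Hypothetical candidate set for the autotuner (names match mha_fwd's parameters)."""
    block_M = [64, 128]
    block_N = [64, 128]
    num_stages = [1, 2, 3]
    threads = [128, 256]
    return [
        {"block_M": m, "block_N": n, "num_stages": s, "threads": t}
        for m, n, s, t in itertools.product(block_M, block_N, num_stages, threads)
    ]
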
Parameters#

sm_scale = (1.0 / dim)**0.5 * 1.44269504

- Precomputed Scaling: Folds the 1/√dk attention normalization together with log₂(e) ≈ 1.44269504, so the softmax can use hardware exp2 instructions instead of exp (a quick numeric check follows)
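A small self-contained check of that identity; the score value and dim below are arbitrary choices for illustration.

import math

dim = 128
score = 3.7                                  # arbitrary raw attention score
scale = 1.0 / math.sqrt(dim)                 # standard 1/sqrt(d_k) factor
sm_scale = scale * 1.44269504                # fold in log2(e)

# 2**(x * log2(e)) == e**x, so exp2 with the folded scale matches exp with the plain scale.
assert abs(math.exp(score * scale) - 2.0 ** (score * sm_scale)) < 1e-6
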
Memory Management#

Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)

- Tiled Loading: Copies blocks of Q/K/V into fast on-chip shared memory, while the score and output accumulators (acc_s, acc_o) live in register fragments; a rough sizing estimate is given below
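As a rough sizing exercise (assuming the default block_M = block_N = 128, dim = 128, and 2-byte bfloat16 elements), the shared-memory footprint of the four tiles per thread block works out to:

block_M = block_N = dim = 128
bytes_per_elem = 2  # bfloat16

tile_elems = {
    "Q_shared": block_M * dim,
    "K_shared": block_N * dim,
    "V_shared": block_N * dim,
    "O_shared": block_M * dim,
}
total_bytes = sum(n * bytes_per_elem for n in tile_elems.values())
print(total_bytes // 1024, "KiB")  # 128 KiB; the fragments (acc_s, acc_o, ...) live in registers
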
Causal Masking#

acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))

- Block-Level Masking: Initializes each score tile with 0 for allowed positions and -inf for masked ones, so the subsequent GEMM and exp2 automatically zero out disallowed attention weights
- Position Tracking: Uses q_start_id = seq_len_kv - seq_len to align query positions with key positions when the key/value sequence is longer than the query sequence (e.g., encoder-decoder or KV-cache decoding setups); see the toy example below
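A toy illustration of the mask predicate using plain Python, with arbitrary small sizes (4 queries attending over 6 keys, as when decoding the last 4 positions):

seq_len, seq_len_kv = 4, 6
q_start_id = seq_len_kv - seq_len  # queries are the last 4 positions of the 6-long timeline

for i in range(seq_len):
    row = ""
    for j in range(seq_len_kv):
        q_idx = i + q_start_id
        k_idx = j
        row += "x" if q_idx >= k_idx else "."   # x = attend, . = masked (-inf)
    print(row)

# Prints:
# xxx...
# xxxx..
# xxxxx.
# xxxxxx
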
Attention Computation#

T.gemm(Q_shared, K_shared, acc_s, transpose_B=True)

- Tiled GEMM: Computes the raw score block Q·Kᵀ and accumulates it on top of the 0/-inf mask initialization in acc_s; transpose_B=True supplies Kᵀ without an explicit transpose in memory
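In NumPy terms (a conceptual analogy rather than the kernel's actual semantics), the initialize-then-accumulate pattern for one tile looks roughly like this:

import numpy as np

block_M, block_N, dim = 4, 4, 8
rng = np.random.default_rng(0)
Q_tile = rng.standard_normal((block_M, dim))
K_tile = rng.standard_normal((block_N, dim))

# 0 where attention is allowed, -inf where masked (here a lower-triangular causal block)
acc_s = np.where(np.tri(block_M, block_N, dtype=bool), 0.0, -np.inf)

# The GEMM accumulates Q @ K^T on top of that initialization,
# so masked entries stay at -inf and vanish after exponentiation.
acc_s += Q_tile @ K_tile.T
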
Numerically Stable Softmax#

# Compute max per row
T.reduce_max(acc_s, scores_max, dim=1, clear=False)

# Rescale previous statistics
scores_scale[i] = T.exp2(scores_max_prev[i]*sm_scale - scores_max[i]*sm_scale)

# Compute exp scores
acc_s[i, j] = T.exp2(acc_s[i, j]*sm_scale - scores_max[i]*sm_scale)

# Update running sum
logsum[i] = logsum[i]*scores_scale[i] + scores_sum[i]

- Base-2 Exponentiation: Leverages hardware-optimized exp2 instructions (the log₂(e) factor is already folded into sm_scale)
- Online Rescaling: Whenever a new block raises the running row maximum, the previously accumulated sums are rescaled by exp2((old_max - new_max)·sm_scale), keeping every exponent non-positive and preventing overflow
- Incremental Updates: The output accumulator acc_o is rescaled by the same scores_scale factor in each iteration, so the running statistics and the partial output stay consistent block by block (see the sketch after this list)
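A minimal NumPy sketch of the running max/sum update for one row of already-scaled scores, split into two blocks (natural exp is used for simplicity; the kernel uses exp2 with log₂(e) folded into sm_scale):

import numpy as np

scores = np.array([1.0, 3.0, -2.0, 0.5, 4.0, 2.5])  # one row of scaled scores
blocks = [scores[:3], scores[3:]]

running_max, running_sum = -np.inf, 0.0
for blk in blocks:
    new_max = max(running_max, blk.max())
    rescale = np.exp(running_max - new_max)      # rescale what was accumulated so far
    running_sum = running_sum * rescale + np.exp(blk - new_max).sum()
    running_max = new_max

# Matches the one-shot numerically stable softmax denominator.
assert np.isclose(running_sum, np.exp(scores - scores.max()).sum())
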
Output Accumulation#

T.gemm(acc_s_cast, V_shared, acc_o)
for i, j in T.Parallel(block_M, dim):
    acc_o[i, j] /= logsum[i]

- Fused GEMM: Multiplies the unnormalized exponentiated scores (cast back to bfloat16 as acc_s_cast) with the value tile, accumulating into acc_o
- Delayed Normalization: The division by the softmax denominator is applied once, after all KV blocks have been accumulated, rather than per block; the sketch below shows this matches ordinary softmax attention
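Putting the pieces together, a compact NumPy reference (arbitrary small sizes, natural exp instead of exp2) confirms that block-wise accumulation with delayed normalization reproduces ordinary softmax attention:

import numpy as np

rng = np.random.default_rng(0)
M, N, D, block = 8, 32, 16, 8                 # queries, keys, head dim, KV block size
scale = 1.0 / np.sqrt(D)
Q, K, V = (rng.standard_normal(shape) for shape in [(M, D), (N, D), (N, D)])

acc_o = np.zeros((M, D))
row_max = np.full(M, -np.inf)
row_sum = np.zeros(M)

for start in range(0, N, block):
    s = Q @ K[start:start + block].T * scale              # score tile
    new_max = np.maximum(row_max, s.max(axis=1))
    rescale = np.exp(row_max - new_max)
    p = np.exp(s - new_max[:, None])
    row_sum = row_sum * rescale + p.sum(axis=1)
    acc_o = acc_o * rescale[:, None] + p @ V[start:start + block]
    row_max = new_max

out = acc_o / row_sum[:, None]                             # delayed normalization

S = Q @ K.T * scale
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(out, ref)
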
Pipelining#

for k in T.Pipelined(loop_range, num_stages=num_stages):

- Overlapped Execution: Hides memory latency by prefetching the next K/V tiles while the current block is being computed
- Configurable Stages: num_stages controls the pipeline depth (the number of in-flight tile loads), which can be tuned for different hardware
Log-Sum-Exp Output#

logsum[i] = T.log2(logsum[i]) + scores_max[i]*sm_scale
T.copy(logsum, Lse)

- Final Logarithm: Converts the running max-shifted sum back to log space; adding scores_max·sm_scale undoes the shift, so each row stores log₂ of its softmax denominator (a numeric check follows)
- Gradient Preparation: Stores exactly the per-row statistic the backward pass needs to recompute the softmax weights without materializing the full score matrix
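A quick numeric check of that identity for one row of raw scores (values are arbitrary; sm_scale is folded exactly as in the kernel):

import numpy as np

dim = 64
sm_scale = (1.0 / dim) ** 0.5 * 1.44269504
scores = np.array([0.3, 2.1, -1.7, 4.2])      # raw scores for one query row

row_max = scores.max()
running_sum = np.exp2(scores * sm_scale - row_max * sm_scale).sum()
lse = np.log2(running_sum) + row_max * sm_scale

# Equals log2 of the softmax denominator computed without the max shift.
assert np.isclose(lse, np.log2(np.exp2(scores * sm_scale).sum()))
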
4. Performance Notes#

- Block Sizing: Optimal block_M/block_N vary with sequence length (128 works well for 4K)
- Thread Count: 256 threads per block balance occupancy and register/shared-memory usage
- Mixed Precision: bfloat16 storage and GEMM inputs with float accumulation maintain precision
- Causal Adaptation: The causal loop-range adjustment skips fully masked KV blocks, avoiding unnecessary computation