
Using Linear Layout for Intel 2D Block Loads

The Intel Xe2/Xe3 Triton backend relies on the Dot Product Accumulate Systolic (DPAS) instruction for cooperative subgroup mma operations. Cooperative load instructions are important for maximizing performance when using DPAS to accelerate mma. A typical GEMM kernel generates several DPAS instructions per subgroup, and even with a 2D block load instruction, loading the data for each DPAS instruction separately would be inefficient. Therefore, when lowering to LLVM, the Intel Triton backend expands load instructions into 2D block loads that fetch the data for all contiguous blocks used by the DPAS instructions in a subgroup, up to hardware limits.

The lowering of tt.dot operations to DPAS instructions and of tensor loads to 2D block loads starts by identifying and lowering the tensor loads. This lowering occurs during the Triton GPU to LLVM conversion pass. For each TTGIR::LoadOp operation we emit both an LLVM IR function call to the appropriate 2D block load instruction(s) and a set of shuffle vectors that transform the output of the block load into the input required by the DPAS instruction. This transformation is expected to take place entirely in registers and is represented using LLVM IR virtual registers.

The 2D block load size is primarily determined by the tensor layout attached to the Triton tt.dot operation. However, numerous special cases and hardware limitations apply to the load. The algorithm starts with the tile size required by DPAS and expands it based on the parameters of the DPAS layout, subject to hardware limitations. Below we present example layouts for simple AxB and AxBT GEMM kernels.

AxB

We will first consider the AxB GEMM kernel with A matrix dimensions [1024, 5120] and B matrix dimensions [5120, 4096]. The Triton kernel is directed to use a block size of 256 in the M and N dimensions and a block size of 32 in the K dimension (K = 5120). The Triton backend automatically partitions the input and output matrices across workgroups and subgroups during the Triton GPU IR lowering phase.
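The partitioning above can be checked with a little arithmetic. The sketch below assumes the standard Triton matmul structure, with one program (workgroup) per BLOCK_M x BLOCK_N tile of C and a loop over K in BLOCK_K steps; the helper name `gemm_partition` is hypothetical and not part of the backend.

```python
import math


def gemm_partition(m, n, k, block_m, block_n, block_k):
    """Hypothetical helper: workgroups per dimension and K-loop trip count,
    assuming one program per BLOCK_M x BLOCK_N output tile."""
    grid_m = math.ceil(m / block_m)    # output tiles along M
    grid_n = math.ceil(n / block_n)    # output tiles along N
    k_steps = math.ceil(k / block_k)   # iterations of the K loop per program
    return grid_m, grid_n, k_steps


# A is [1024, 5120] and B is [5120, 4096], so M = 1024, N = 4096, K = 5120.
grid_m, grid_n, k_steps = gemm_partition(1024, 4096, 5120, 256, 256, 32)
print(grid_m, grid_n, grid_m * grid_n, k_steps)  # 4 16 64 160
```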

A matrix

We will first consider the A matrix load, which is the same for both the AxB and AxBT (transpose) cases because the transpose does not currently change the DPAS layout. The tensor type (layout) generated by the TTGIR dialect for the A matrix is:

tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>, kWidth = 1}>>

Note that the tensor type describes the expected data layout going into and out of the tt.dot / DPAS instructions. While this type is attached to the load instruction, the data retrieved by the load may not match the layout described above. Lowering the LoadOp to 2D block loads therefore also implies a conversion that makes the loaded data match the desired tensor type. This conversion is implemented using shuffle vectors.
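The A, B and C entries printed in the tensor type can be derived from the other DPAS parameters, and combined with warpsPerCTA they give the slice of A that each warp consumes. The sketch below encodes our reading of those numbers (for example A = [repeatCount * repCluster[0], systolicDepth * opsPerChan]); the function name and formulas are illustrative assumptions checked only against the values printed above, not code taken from the backend.

```python
def dpas_shapes(repeat_count, systolic_depth, execution_size, ops_per_chan, rep_cluster):
    """Assumed formulas for the A/B/C rep-cluster shapes of a DPAS layout."""
    k = systolic_depth * ops_per_chan               # K consumed by one DPAS
    a = [repeat_count * rep_cluster[0], k]          # A tile per rep cluster
    b = [k, execution_size * rep_cluster[1]]        # B tile per rep cluster
    c = [a[0], b[1]]                                # C tile per rep cluster
    return a, b, c


# Parameters from the #triton_intel_gpu.dpas attribute above.
repeat_count, systolic_depth, execution_size, ops_per_chan = 8, 8, 16, 2
a, b, c = dpas_shapes(repeat_count, systolic_depth, execution_size, ops_per_chan, [4, 2])
print(a, b, c)  # [32, 16] [16, 32] [32, 32]

# Split the 256x256 output block (BLOCK_K = 32) across warpsPerCTA = [8, 4].
block_m, block_n, block_k = 256, 256, 32
warps_per_cta = [8, 4]
warp_m = block_m // warps_per_cta[0]                # 32 rows of C (and of A) per warp
warp_n = block_n // warps_per_cta[1]                # 64 columns of C per warp
a_slice = [warp_m, block_k]                         # A data one warp consumes
k_per_dpas = systolic_depth * ops_per_chan
n_dpas = (warp_m // repeat_count) * (warp_n // execution_size) * (block_k // k_per_dpas)
print(a_slice, n_dpas)  # [32, 32] 32
```

The 32x32 A slice agrees with the Warp0 hardware view printed below, which covers rows 0-31 and columns 0-31.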

We can use the triton-tensor-layout utility to print the DPAS layout with a hardware-centric view (i.e. the register/lane/warp mapping to tensor coordinates) using the following command:

./python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-tensor-layout -l "#ttg.dot_op<{opIdx = 0, parent = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>, kWidth = 1}>" -t "tensor<256x32xf16>" -use-hw-view |& tee A_hw_view

which produces the following layout

Warp0:
(  0, 0), (  0, 1), (  0, 2), (  0, 3), (  0, 4), (  0, 5), (  0, 6), (  0, 7), (  0, 8), (  0, 9), (  0,10), (  0,11), (  0,12), (  0,13), (  0,14), (  0,15)
(  1, 0), (  1, 1), (  1, 2), (  1, 3), (  1, 4), (  1, 5), (  1, 6), (  1, 7), (  1, 8), (  1, 9), (  1,10), (  1,11), (  1,12), (  1,13), (  1,14), (  1,15)
(  2, 0), (  2, 1), (  2, 2), (  2, 3), (  2, 4), (  2, 5), (  2, 6), (  2, 7), (  2, 8), (  2, 9), (  2,10), (  2,11), (  2,12), (  2,13), (  2,14), (  2,15)
(  3, 0), (  3, 1), (  3, 2), (  3, 3), (  3, 4), (  3, 5), (  3, 6), (  3, 7), (  3, 8), (  3, 9), (  3,10), (  3,11), (  3,12), (  3,13), (  3,14), (  3,15)
(  4, 0), (  4, 1), (  4, 2), (  4, 3), (  4, 4), (  4, 5), (  4, 6), (  4, 7), (  4, 8), (  4, 9), (  4,10), (  4,11), (  4,12), (  4,13), (  4,14), (  4,15)
(  5, 0), (  5, 1), (  5, 2), (  5, 3), (  5, 4), (  5, 5), (  5, 6), (  5, 7), (  5, 8), (  5, 9), (  5,10), (  5,11), (  5,12), (  5,13), (  5,14), (  5,15)
(  6, 0), (  6, 1), (  6, 2), (  6, 3), (  6, 4), (  6, 5), (  6, 6), (  6, 7), (  6, 8), (  6, 9), (  6,10), (  6,11), (  6,12), (  6,13), (  6,14), (  6,15)
(  7, 0), (  7, 1), (  7, 2), (  7, 3), (  7, 4), (  7, 5), (  7, 6), (  7, 7), (  7, 8), (  7, 9), (  7,10), (  7,11), (  7,12), (  7,13), (  7,14), (  7,15)
(  8, 0), (  8, 1), (  8, 2), (  8, 3), (  8, 4), (  8, 5), (  8, 6), (  8, 7), (  8, 8), (  8, 9), (  8,10), (  8,11), (  8,12), (  8,13), (  8,14), (  8,15)
(  9, 0), (  9, 1), (  9, 2), (  9, 3), (  9, 4), (  9, 5), (  9, 6), (  9, 7), (  9, 8), (  9, 9), (  9,10), (  9,11), (  9,12), (  9,13), (  9,14), (  9,15)
( 10, 0), ( 10, 1), ( 10, 2), ( 10, 3), ( 10, 4), ( 10, 5), ( 10, 6), ( 10, 7), ( 10, 8), ( 10, 9), ( 10,10), ( 10,11), ( 10,12), ( 10,13), ( 10,14), ( 10,15)
( 11, 0), ( 11, 1), ( 11, 2), ( 11, 3), ( 11, 4), ( 11, 5), ( 11, 6), ( 11, 7), ( 11, 8), ( 11, 9), ( 11,10), ( 11,11), ( 11,12), ( 11,13), ( 11,14), ( 11,15)
( 12, 0), ( 12, 1), ( 12, 2), ( 12, 3), ( 12, 4), ( 12, 5), ( 12, 6), ( 12, 7), ( 12, 8), ( 12, 9), ( 12,10), ( 12,11), ( 12,12), ( 12,13), ( 12,14), ( 12,15)
( 13, 0), ( 13, 1), ( 13, 2), ( 13, 3), ( 13, 4), ( 13, 5), ( 13, 6), ( 13, 7), ( 13, 8), ( 13, 9), ( 13,10), ( 13,11), ( 13,12), ( 13,13), ( 13,14), ( 13,15)
( 14, 0), ( 14, 1), ( 14, 2), ( 14, 3), ( 14, 4), ( 14, 5), ( 14, 6), ( 14, 7), ( 14, 8), ( 14, 9), ( 14,10), ( 14,11), ( 14,12), ( 14,13), ( 14,14), ( 14,15)
( 15, 0), ( 15, 1), ( 15, 2), ( 15, 3), ( 15, 4), ( 15, 5), ( 15, 6), ( 15, 7), ( 15, 8), ( 15, 9), ( 15,10), ( 15,11), ( 15,12), ( 15,13), ( 15,14), ( 15,15)
( 16, 0), ( 16, 1), ( 16, 2), ( 16, 3), ( 16, 4), ( 16, 5), ( 16, 6), ( 16, 7), ( 16, 8), ( 16, 9), ( 16,10), ( 16,11), ( 16,12), ( 16,13), ( 16,14), ( 16,15)
( 17, 0), ( 17, 1), ( 17, 2), ( 17, 3), ( 17, 4), ( 17, 5), ( 17, 6), ( 17, 7), ( 17, 8), ( 17, 9), ( 17,10), ( 17,11), ( 17,12), ( 17,13), ( 17,14), ( 17,15)
( 18, 0), ( 18, 1), ( 18, 2), ( 18, 3), ( 18, 4), ( 18, 5), ( 18, 6), ( 18, 7), ( 18, 8), ( 18, 9), ( 18,10), ( 18,11), ( 18,12), ( 18,13), ( 18,14), ( 18,15)
( 19, 0), ( 19, 1), ( 19, 2), ( 19, 3), ( 19, 4), ( 19, 5), ( 19, 6), ( 19, 7), ( 19, 8), ( 19, 9), ( 19,10), ( 19,11), ( 19,12), ( 19,13), ( 19,14), ( 19,15)
( 20, 0), ( 20, 1), ( 20, 2), ( 20, 3), ( 20, 4), ( 20, 5), ( 20, 6), ( 20, 7), ( 20, 8), ( 20, 9), ( 20,10), ( 20,11), ( 20,12), ( 20,13), ( 20,14), ( 20,15)
( 21, 0), ( 21, 1), ( 21, 2), ( 21, 3), ( 21, 4), ( 21, 5), ( 21, 6), ( 21, 7), ( 21, 8), ( 21, 9), ( 21,10), ( 21,11), ( 21,12), ( 21,13), ( 21,14), ( 21,15)
( 22, 0), ( 22, 1), ( 22, 2), ( 22, 3), ( 22, 4), ( 22, 5), ( 22, 6), ( 22, 7), ( 22, 8), ( 22, 9), ( 22,10), ( 22,11), ( 22,12), ( 22,13), ( 22,14), ( 22,15)
( 23, 0), ( 23, 1), ( 23, 2), ( 23, 3), ( 23, 4), ( 23, 5), ( 23, 6), ( 23, 7), ( 23, 8), ( 23, 9), ( 23,10), ( 23,11), ( 23,12), ( 23,13), ( 23,14), ( 23,15)
( 24, 0), ( 24, 1), ( 24, 2), ( 24, 3), ( 24, 4), ( 24, 5), ( 24, 6), ( 24, 7), ( 24, 8), ( 24, 9), ( 24,10), ( 24,11), ( 24,12), ( 24,13), ( 24,14), ( 24,15)
( 25, 0), ( 25, 1), ( 25, 2), ( 25, 3), ( 25, 4), ( 25, 5), ( 25, 6), ( 25, 7), ( 25, 8), ( 25, 9), ( 25,10), ( 25,11), ( 25,12), ( 25,13), ( 25,14), ( 25,15)
( 26, 0), ( 26, 1), ( 26, 2), ( 26, 3), ( 26, 4), ( 26, 5), ( 26, 6), ( 26, 7), ( 26, 8), ( 26, 9), ( 26,10), ( 26,11), ( 26,12), ( 26,13), ( 26,14), ( 26,15)
( 27, 0), ( 27, 1), ( 27, 2), ( 27, 3), ( 27, 4), ( 27, 5), ( 27, 6), ( 27, 7), ( 27, 8), ( 27, 9), ( 27,10), ( 27,11), ( 27,12), ( 27,13), ( 27,14), ( 27,15)
( 28, 0), ( 28, 1), ( 28, 2), ( 28, 3), ( 28, 4), ( 28, 5), ( 28, 6), ( 28, 7), ( 28, 8), ( 28, 9), ( 28,10), ( 28,11), ( 28,12), ( 28,13), ( 28,14), ( 28,15)
( 29, 0), ( 29, 1), ( 29, 2), ( 29, 3), ( 29, 4), ( 29, 5), ( 29, 6), ( 29, 7), ( 29, 8), ( 29, 9), ( 29,10), ( 29,11), ( 29,12), ( 29,13), ( 29,14), ( 29,15)
( 30, 0), ( 30, 1), ( 30, 2), ( 30, 3), ( 30, 4), ( 30, 5), ( 30, 6), ( 30, 7), ( 30, 8), ( 30, 9), ( 30,10), ( 30,11), ( 30,12), ( 30,13), ( 30,14), ( 30,15)
( 31, 0), ( 31, 1), ( 31, 2), ( 31, 3), ( 31, 4), ( 31, 5), ( 31, 6), ( 31, 7), ( 31, 8), ( 31, 9), ( 31,10), ( 31,11), ( 31,12), ( 31,13), ( 31,14), ( 31,15)
(  0,16), (  0,17), (  0,18), (  0,19), (  0,20), (  0,21), (  0,22), (  0,23), (  0,24), (  0,25), (  0,26), (  0,27), (  0,28), (  0,29), (  0,30), (  0,31)
(  1,16), (  1,17), (  1,18), (  1,19), (  1,20), (  1,21), (  1,22), (  1,23), (  1,24), (  1,25), (  1,26), (  1,27), (  1,28), (  1,29), (  1,30), (  1,31)
(  2,16), (  2,17), (  2,18), (  2,19), (  2,20), (  2,21), (  2,22), (  2,23), (  2,24), (  2,25), (  2,26), (  2,27), (  2,28), (  2,29), (  2,30), (  2,31)
(  3,16), (  3,17), (  3,18), (  3,19), (  3,20), (  3,21), (  3,22), (  3,23), (  3,24), (  3,25), (  3,26), (  3,27), (  3,28), (  3,29), (  3,30), (  3,31)
(  4,16), (  4,17), (  4,18), (  4,19), (  4,20), (  4,21), (  4,22), (  4,23), (  4,24), (  4,25), (  4,26), (  4,27), (  4,28), (  4,29), (  4,30), (  4,31)
(  5,16), (  5,17), (  5,18), (  5,19), (  5,20), (  5,21), (  5,22), (  5,23), (  5,24), (  5,25), (  5,26), (  5,27), (  5,28), (  5,29), (  5,30), (  5,31)
(  6,16), (  6,17), (  6,18), (  6,19), (  6,20), (  6,21), (  6,22), (  6,23), (  6,24), (  6,25), (  6,26), (  6,27), (  6,28), (  6,29), (  6,30), (  6,31)
(  7,16), (  7,17), (  7,18), (  7,19), (  7,20), (  7,21), (  7,22), (  7,23), (  7,24), (  7,25), (  7,26), (  7,27), (  7,28), (  7,29), (  7,30), (  7,31)
(  8,16), (  8,17), (  8,18), (  8,19), (  8,20), (  8,21), (  8,22), (  8,23), (  8,24), (  8,25), (  8,26), (  8,27), (  8,28), (  8,29), (  8,30), (  8,31)
(  9,16), (  9,17), (  9,18), (  9,19), (  9,20), (  9,21), (  9,22), (  9,23), (  9,24), (  9,25), (  9,26), (  9,27), (  9,28), (  9,29), (  9,30), (  9,31)
( 10,16), ( 10,17), ( 10,18), ( 10,19), ( 10,20), ( 10,21), ( 10,22), ( 10,23), ( 10,24), ( 10,25), ( 10,26), ( 10,27), ( 10,28), ( 10,29), ( 10,30), ( 10,31)
( 11,16), ( 11,17), ( 11,18), ( 11,19), ( 11,20), ( 11,21), ( 11,22), ( 11,23), ( 11,24), ( 11,25), ( 11,26), ( 11,27), ( 11,28), ( 11,29), ( 11,30), ( 11,31)
( 12,16), ( 12,17), ( 12,18), ( 12,19), ( 12,20), ( 12,21), ( 12,22), ( 12,23), ( 12,24), ( 12,25), ( 12,26), ( 12,27), ( 12,28), ( 12,29), ( 12,30), ( 12,31)
( 13,16), ( 13,17), ( 13,18), ( 13,19), ( 13,20), ( 13,21), ( 13,22), ( 13,23), ( 13,24), ( 13,25), ( 13,26), ( 13,27), ( 13,28), ( 13,29), ( 13,30), ( 13,31)
( 14,16), ( 14,17), ( 14,18), ( 14,19), ( 14,20), ( 14,21), ( 14,22), ( 14,23), ( 14,24), ( 14,25), ( 14,26), ( 14,27), ( 14,28), ( 14,29), ( 14,30), ( 14,31)
( 15,16), ( 15,17), ( 15,18), ( 15,19), ( 15,20), ( 15,21), ( 15,22), ( 15,23), ( 15,24), ( 15,25), ( 15,26), ( 15,27), ( 15,28), ( 15,29), ( 15,30), ( 15,31)
( 16,16), ( 16,17), ( 16,18), ( 16,19), ( 16,20), ( 16,21), ( 16,22), ( 16,23), ( 16,24), ( 16,25), ( 16,26), ( 16,27), ( 16,28), ( 16,29), ( 16,30), ( 16,31)
( 17,16), ( 17,17), ( 17,18), ( 17,19), ( 17,20), ( 17,21), ( 17,22), ( 17,23), ( 17,24), ( 17,25), ( 17,26), ( 17,27), ( 17,28), ( 17,29), ( 17,30), ( 17,31)
( 18,16), ( 18,17), ( 18,18), ( 18,19), ( 18,20), ( 18,21), ( 18,22), ( 18,23), ( 18,24), ( 18,25), ( 18,26), ( 18,27), ( 18,28), ( 18,29), ( 18,30), ( 18,31)
( 19,16), ( 19,17), ( 19,18), ( 19,19), ( 19,20), ( 19,21), ( 19,22), ( 19,23), ( 19,24), ( 19,25), ( 19,26), ( 19,27), ( 19,28), ( 19,29), ( 19,30), ( 19,31)
( 20,16), ( 20,17), ( 20,18), ( 20,19), ( 20,20), ( 20,21), ( 20,22), ( 20,23), ( 20,24), ( 20,25), ( 20,26), ( 20,27), ( 20,28), ( 20,29), ( 20,30), ( 20,31)
( 21,16), ( 21,17), ( 21,18), ( 21,19), ( 21,20), ( 21,21), ( 21,22), ( 21,23), ( 21,24), ( 21,25), ( 21,26), ( 21,27), ( 21,28), ( 21,29), ( 21,30), ( 21,31)
( 22,16), ( 22,17), ( 22,18), ( 22,19), ( 22,20), ( 22,21), ( 22,22), ( 22,23), ( 22,24), ( 22,25), ( 22,26), ( 22,27), ( 22,28), ( 22,29), ( 22,30), ( 22,31)
( 23,16), ( 23,17), ( 23,18), ( 23,19), ( 23,20), ( 23,21), ( 23,22), ( 23,23), ( 23,24), ( 23,25), ( 23,26), ( 23,27), ( 23,28), ( 23,29), ( 23,30), ( 23,31)
( 24,16), ( 24,17), ( 24,18), ( 24,19), ( 24,20), ( 24,21), ( 24,22), ( 24,23), ( 24,24), ( 24,25), ( 24,26), ( 24,27), ( 24,28), ( 24,29), ( 24,30), ( 24,31)
( 25,16), ( 25,17), ( 25,18), ( 25,19), ( 25,20), ( 25,21), ( 25,22), ( 25,23), ( 25,24), ( 25,25), ( 25,26), ( 25,27), ( 25,28), ( 25,29), ( 25,30), ( 25,31)
( 26,16), ( 26,17), ( 26,18), ( 26,19), ( 26,20), ( 26,21), ( 26,22), ( 26,23), ( 26,24), ( 26,25), ( 26,26), ( 26,27), ( 26,28), ( 26,29), ( 26,30), ( 26,31)
( 27,16), ( 27,17), ( 27,18), ( 27,19), ( 27,20), ( 27,21), ( 27,22), ( 27,23), ( 27,24), ( 27,25), ( 27,26), ( 27,27), ( 27,28), ( 27,29), ( 27,30), ( 27,31)
( 28,16), ( 28,17), ( 28,18), ( 28,19), ( 28,20), ( 28,21), ( 28,22), ( 28,23), ( 28,24), ( 28,25), ( 28,26), ( 28,27), ( 28,28), ( 28,29), ( 28,30), ( 28,31)
( 29,16), ( 29,17), ( 29,18), ( 29,19), ( 29,20), ( 29,21), ( 29,22), ( 29,23), ( 29,24), ( 29,25), ( 29,26), ( 29,27), ( 29,28), ( 29,29), ( 29,30), ( 29,31)
( 30,16), ( 30,17), ( 30,18), ( 30,19), ( 30,20), ( 30,21), ( 30,22), ( 30,23), ( 30,24), ( 30,25), ( 30,26), ( 30,27), ( 30,28), ( 30,29), ( 30,30), ( 30,31)
( 31,16), ( 31,17), ( 31,18), ( 31,19), ( 31,20), ( 31,21), ( 31,22), ( 31,23), ( 31,24), ( 31,25), ( 31,26), ( 31,27), ( 31,28), ( 31,29), ( 31,30), ( 31,31)

Because we are only interested in a subgroup (warp), the other warps in the layout are omitted.
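The Warp0 dump can itself be written as a linear layout. Reading each printed row as one register and each column as one lane, the sketch below lists one basis vector per power-of-two register and lane index; any other index maps to the XOR of the basis vectors of its set bits, since linear layouts are linear over GF(2). The bases are read off the dump above, and the names are illustrative.

```python
# Basis vectors read off the Warp0 dump: registers 0..31 walk down dim0,
# register 32 jumps to column 16; lanes walk along dim1.
REGISTER_BASES = [(1, 0), (2, 0), (4, 0), (8, 0), (16, 0), (0, 16)]
LANE_BASES = [(0, 1), (0, 2), (0, 4), (0, 8)]


def apply_bases(index, bases):
    """XOR together the basis vectors selected by the set bits of index."""
    row, col = 0, 0
    for bit, (r, c) in enumerate(bases):
        if (index >> bit) & 1:
            row ^= r
            col ^= c
    return row, col


def warp0_coord(register, lane):
    """Tensor coordinate (dim0, dim1) held by this register of this lane."""
    r0, c0 = apply_bases(register, REGISTER_BASES)
    r1, c1 = apply_bases(lane, LANE_BASES)
    return r0 ^ r1, c0 ^ c1


print(warp0_coord(0, 15))   # (0, 15)
print(warp0_coord(31, 0))   # (31, 0)
print(warp0_coord(33, 5))   # (1, 21)
```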

For a 16-bit (f16 or bf16) DPAS with an fp32 accumulator, the A matrix input to each instruction is expected to be an 8x16 tile with the following subgroup layout (one column per work-item, WI):

WI 0    WI 1    WI 2    WI 3    WI 4    WI 5    WI 6    WI 7    WI 8    WI 9    WI 10   WI 11   WI 12   WI 13   WI 14   WI 15
(0, 0)  (0, 1)  (0, 2)  (0, 3)  (0, 4)  (0, 5)  (0, 6)  (0, 7)  (0, 8)  (0, 9)  (0, 10) (0, 11) (0, 12) (0, 13) (0, 14) (0, 15)
(1, 0)  (1, 1)  (1, 2)  (1, 3)  (1, 4)  (1, 5)  (1, 6)  (1, 7)  (1, 8)  (1, 9)  (1, 10) (1, 11) (1, 12) (1, 13) (1, 14) (1, 15)
(2, 0)  (2, 1)  (2, 2)  (2, 3)  (2, 4)  (2, 5)  (2, 6)  (2, 7)  (2, 8)  (2, 9)  (2, 10) (2, 11) (2, 12) (2, 13) (2, 14) (2, 15)
(3, 0)  (3, 1)  (3, 2)  (3, 3)  (3, 4)  (3, 5)  (3, 6)  (3, 7)  (3, 8)  (3, 9)  (3, 10) (3, 11) (3, 12) (3, 13) (3, 14) (3, 15)
(4, 0)  (4, 1)  (4, 2)  (4, 3)  (4, 4)  (4, 5)  (4, 6)  (4, 7)  (4, 8)  (4, 9)  (4, 10) (4, 11) (4, 12) (4, 13) (4, 14) (4, 15)
(5, 0)  (5, 1)  (5, 2)  (5, 3)  (5, 4)  (5, 5)  (5, 6)  (5, 7)  (5, 8)  (5, 9)  (5, 10) (5, 11) (5, 12) (5, 13) (5, 14) (5, 15)
(6, 0)  (6, 1)  (6, 2)  (6, 3)  (6, 4)  (6, 5)  (6, 6)  (6, 7)  (6, 8)  (6, 9)  (6, 10) (6, 11) (6, 12) (6, 13) (6, 14) (6, 15)
(7, 0)  (7, 1)  (7, 2)  (7, 3)  (7, 4)  (7, 5)  (7, 6)  (7, 7)  (7, 8)  (7, 9)  (7, 10) (7, 11) (7, 12) (7, 13) (7, 14) (7, 15)

We now have all the information we need to determine both the number of 2D block loads for the A matrix and their sizes.

Up to this point the load for the A matrix is represented as a single load per block, with the TTGIR layouts mapping that block to the underlying hardware primitives (workgroups and subgroups/warps). We need to translate that load into one or more load instructions per subgroup (warp). We start with the DPAS tile size, which, as shown above, is 8x16. We create a linear layout that takes a single input dimension, offset, which is a 1D row-major index into the tile, and outputs the 2D tensor coordinate corresponding to that index:

Block load tile layout:
 - offset=1 -> (0, 1)
   offset=2 -> (0, 2)
   offset=4 -> (0, 4)
   offset=8 -> (0, 8)
   offset=16 -> (1, 0)
   offset=32 -> (2, 0)
   offset=64 -> (4, 0)
where out dims are: [dim0 (size 8), dim1 (size 16)]

The indices for the first SIMD lane / work-item (every 16th offset) are printed below:

0 : 0, 0
16 : 1, 0
32 : 2, 0
48 : 3, 0
64 : 4, 0
80 : 5, 0
96 : 6, 0
112 : 7, 0
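The tile layout above can be evaluated directly: the set bits of an offset select basis vectors, which are XORed together. The sketch below reproduces the first work-item's offsets listed above and checks that offset 16 * r + w lands at (r, w), i.e. work-item w holds column w of the 8x16 tile, matching the DPAS A operand table. Names are illustrative.

```python
# Basis vectors of the 8x16 block load tile layout, one per offset bit.
OFFSET_BASES = [(0, 1), (0, 2), (0, 4), (0, 8), (1, 0), (2, 0), (4, 0)]


def apply_bases(index, bases):
    """XOR together the basis vectors selected by the set bits of index."""
    row, col = 0, 0
    for bit, (r, c) in enumerate(bases):
        if (index >> bit) & 1:
            row ^= r
            col ^= c
    return row, col


# Every 16th offset belongs to the first work-item.
for offset in range(0, 128, 16):
    print(offset, ":", apply_bases(offset, OFFSET_BASES))

# Row-major tile: offset 16 * r + w maps to row r, column w.
assert all(apply_bases(16 * r + w, OFFSET_BASES) == (r, w)
           for r in range(8) for w in range(16))
```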

During lowering of the tt.dot operation, multiple DPAS instructions are generated according to the TTGIR DPAS layout. We need to enlarge the 2D block load tile created above so that it loads as much data as possible for the DPAS instructions in the subgroup. We do this by adding input dimensions to the layout, starting with iteration. Each iteration corresponds to one DPAS instruction, and each DPAS instruction operates on one DPAS tile. Specifically, for each iteration we generate a shuffle vector per work-item that outputs the registers in the order DPAS expects.

The number of iterations is determined by the maximum contiguous tile size for the DPAS instructions in the subgroup, subject to hardware limitations. Once the iterations are computed, we know the maximum tile size that a single 2D block load instruction can load, and from that we compute the number of required loads. For the A matrix of the GEMM kernel there are 4 iterations along the outer dimension (dim0) and 2 iterations along the inner dimension (dim1), for 8 iterations in total. After adding iterations, the load has grown from the DPAS tile size (8x16) to 32x32.

Block load tile layout after adding iterations:
 - offset=1 -> (0, 1)
   offset=2 -> (0, 2)
   offset=4 -> (0, 4)
   offset=8 -> (0, 8)
   offset=16 -> (1, 0)
   offset=32 -> (2, 0)
   offset=64 -> (4, 0)
 - iteration=1 -> (8, 0)
   iteration=2 -> (16, 0)
   iteration=4 -> (0, 16)
where out dims are: [dim0 (size 32), dim1 (size 32)]

The DPAS layout is replicated first along the outer dimension and then along the inner dimension. Referring to the DPAS layout, the first contiguous block processed by DPAS starts at (0,0) and ends at (31,15). The block load layout follows the same order: iteration 3 ends at (31,15) and iteration 4 starts at (0,16). The first and last offsets of each iteration are listed below as iteration, offset : dim0, dim1.

0, 0 : 0, 0
0, 127 : 7, 15
1, 0 : 8, 0
1, 127 : 15, 15
2, 0 : 16, 0
2, 127 : 23, 15
3, 0 : 24, 0
3, 127 : 31, 15
4, 0 : 0, 16
4, 127 : 7, 31
5, 0 : 8, 16
5, 127 : 15, 31
6, 0 : 16, 16
6, 127 : 23, 31
7, 0 : 24, 16
7, 127 : 31, 31
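Adding the iteration dimension only appends basis vectors for a second input. The sketch below combines the offset and iteration bases from the layout above, reproduces the first and last coordinate of each iteration, and checks that the 8 iterations of 128 offsets cover every element of the 32x32 tile exactly once. Function names are illustrative.

```python
OFFSET_BASES = [(0, 1), (0, 2), (0, 4), (0, 8), (1, 0), (2, 0), (4, 0)]
ITERATION_BASES = [(8, 0), (16, 0), (0, 16)]


def apply_bases(index, bases):
    """XOR together the basis vectors selected by the set bits of index."""
    row, col = 0, 0
    for bit, (r, c) in enumerate(bases):
        if (index >> bit) & 1:
            row ^= r
            col ^= c
    return row, col


def block_coord(iteration, offset):
    """Coordinate in the 32x32 block load tile for (iteration, offset)."""
    r0, c0 = apply_bases(offset, OFFSET_BASES)
    r1, c1 = apply_bases(iteration, ITERATION_BASES)
    return r0 ^ r1, c0 ^ c1


# First and last element of each DPAS tile (compare with the table above).
for it in range(8):
    print(it, block_coord(it, 0), block_coord(it, 127))

# The layout is a bijection onto the 32x32 tile.
coords = {block_coord(it, off) for it in range(8) for off in range(128)}
assert coords == {(r, c) for r in range(32) for c in range(32)}
```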

In some cases we may need multiple block loads, for example if the tile exceeds the hardware restrictions on load size or if the layout contains non-contiguous replications. For the A matrix there is only one load:

Block load tile layout after adding loads:
 - offset=1 -> (0, 1)
   offset=2 -> (0, 2)
   offset=4 -> (0, 4)
   offset=8 -> (0, 8)
   offset=16 -> (1, 0)
   offset=32 -> (2, 0)
   offset=64 -> (4, 0)
 - iteration=1 -> (8, 0)
   iteration=2 -> (16, 0)
   iteration=4 -> (0, 16)
 - load is a size 1 dimension
where out dims are: [dim0 (size 32), dim1 (size 32)]

We now have a linear layout that maps each DPAS instruction (iteration) and element offset to a location in the 2D load output. We can use this layout to determine the global offset of each load, to compute the local load offsets for each DPAS instruction, and to compute the shuffle vectors that translate the output of the load into the layout required by DPAS.
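As an illustration of how the layout can be consumed, the sketch below finds where each DPAS tile starts inside the 32x32 load result, computes a global element offset for a row-major A matrix (leading dimension 5120), and lists the coordinates a work-item supplies, in DPAS register order, for one iteration. The warp's tile origin (base_row, base_col) and all function names are hypothetical; the sketch works purely in tile coordinates and does not model the actual register layout returned by the block load or the exact shuffle masks the backend emits.

```python
OFFSET_BASES = [(0, 1), (0, 2), (0, 4), (0, 8), (1, 0), (2, 0), (4, 0)]
ITERATION_BASES = [(8, 0), (16, 0), (0, 16)]
SUBGROUP_SIZE = 16
OFFSETS_PER_TILE = 128  # one 8x16 DPAS tile


def apply_bases(index, bases):
    """XOR together the basis vectors selected by the set bits of index."""
    row, col = 0, 0
    for bit, (r, c) in enumerate(bases):
        if (index >> bit) & 1:
            row ^= r
            col ^= c
    return row, col


def block_coord(iteration, offset):
    r0, c0 = apply_bases(offset, OFFSET_BASES)
    r1, c1 = apply_bases(iteration, ITERATION_BASES)
    return r0 ^ r1, c0 ^ c1


def dpas_tile_origins(num_iterations=8):
    """Local (dim0, dim1) origin of each DPAS tile in the load result."""
    return [block_coord(it, 0) for it in range(num_iterations)]


def global_offset(base_row, base_col, local, lda):
    """Element offset in a row-major matrix with leading dimension lda."""
    r, c = local
    return (base_row + r) * lda + (base_col + c)


def dpas_operand_coords(iteration, work_item):
    """Coordinates this work-item feeds to DPAS for one iteration, in
    register order: every 16th offset starting at the work-item's lane."""
    return [block_coord(iteration, off)
            for off in range(work_item, OFFSETS_PER_TILE, SUBGROUP_SIZE)]


print(dpas_tile_origins())
# [(0, 0), (8, 0), (16, 0), (24, 0), (0, 16), (8, 16), (16, 16), (24, 16)]
# Hypothetical warp origin at row 32, column 0 of A.
print(global_offset(32, 0, dpas_tile_origins()[4], lda=5120))  # 163856
print(dpas_operand_coords(4, 3))
# [(0, 19), (1, 19), (2, 19), (3, 19), (4, 19), (5, 19), (6, 19), (7, 19)]
```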

B matrix

As for the A matrix, we use the triton-tensor-layout utility to print the hardware view of the DPAS layout for the B operand. The tensor type for the B matrix is:

tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>, kWidth = 2}>>

Warp0:
( 0,  0), ( 0,  1), ( 0,  2), ( 0,  3), ( 0,  4), ( 0,  5), ( 0,  6), ( 0,  7), ( 0,  8), ( 0,  9), ( 0, 10), ( 0, 11), ( 0, 12), ( 0, 13), ( 0, 14), ( 0, 15)
( 1,  0), ( 1,  1), ( 1,  2), ( 1,  3), ( 1,  4), ( 1,  5), ( 1,  6), ( 1,  7), ( 1,  8), ( 1,  9), ( 1, 10), ( 1, 11), ( 1, 12), ( 1, 13), ( 1, 14), ( 1, 15)
( 2,  0), ( 2,  1), ( 2,  2), ( 2,  3), ( 2,  4), ( 2,  5), ( 2,  6), ( 2,  7), ( 2,  8), ( 2,  9), ( 2, 10), ( 2, 11), ( 2, 12), ( 2, 13), ( 2, 14), ( 2, 15)
( 3,  0), ( 3,  1), ( 3,  2), ( 3,  3), ( 3,  4), ( 3,  5), ( 3,  6), ( 3,  7), ( 3,  8), ( 3,  9), ( 3, 10), ( 3, 11), ( 3, 12), ( 3, 13), ( 3, 14), ( 3, 15)
( 4,  0), ( 4,  1), ( 4,  2), ( 4,  3), ( 4,  4), ( 4,  5), ( 4,  6), ( 4,  7), ( 4,  8), ( 4,  9), ( 4, 10), ( 4, 11), ( 4, 12), ( 4, 13), ( 4, 14), ( 4, 15)
( 5,  0), ( 5,  1), ( 5,  2), ( 5,  3), ( 5,  4), ( 5,  5), ( 5,  6), ( 5,  7), ( 5,  8), ( 5,  9), ( 5, 10), ( 5, 11), ( 5, 12), ( 5, 13), ( 5, 14), ( 5, 15)
( 6,  0), ( 6,  1), ( 6,  2), ( 6,  3), ( 6,  4), ( 6,  5), ( 6,  6), ( 6,  7), ( 6,  8), ( 6,  9), ( 6, 10), ( 6, 11), ( 6, 12), ( 6, 13), ( 6, 14), ( 6, 15)
( 7,  0), ( 7,  1), ( 7,  2), ( 7,  3), ( 7,  4), ( 7,  5), ( 7,  6), ( 7,  7), ( 7,  8), ( 7,  9), ( 7, 10), ( 7, 11), ( 7, 12), ( 7, 13), ( 7, 14), ( 7, 15)
( 8,  0), ( 8,  1), ( 8,  2), ( 8,  3), ( 8,  4), ( 8,  5), ( 8,  6), ( 8,  7), ( 8,  8), ( 8,  9), ( 8, 10), ( 8, 11), ( 8, 12), ( 8, 13), ( 8, 14), ( 8, 15)
( 9,  0), ( 9,  1), ( 9,  2), ( 9,  3), ( 9,  4), ( 9,  5), ( 9,  6), ( 9,  7), ( 9,  8), ( 9,  9), ( 9, 10), ( 9, 11), ( 9, 12), ( 9, 13), ( 9, 14), ( 9, 15)
(10,  0), (10,  1), (10,  2), (10,  3), (10,  4), (10,  5), (10,  6), (10,  7), (10,  8), (10,  9), (10, 10), (10, 11), (10, 12), (10, 13), (10, 14), (10, 15)
(11,  0), (11,  1), (11,  2), (11,  3), (11,  4), (11,  5), (11,  6), (11,  7), (11,  8), (11,  9), (11, 10), (11, 11), (11, 12), (11, 13), (11, 14), (11, 15)
(12,  0), (12,  1), (12,  2), (12,  3), (12,  4), (12,  5), (12,  6), (12,  7), (12,  8), (12,  9), (12, 10), (12, 11), (12, 12), (12, 13), (12, 14), (12, 15)
(13,  0), (13,  1), (13,  2), (13,  3), (13,  4), (13,  5), (13,  6), (13,  7), (13,  8), (13,  9), (13, 10), (13, 11), (13, 12), (13, 13), (13, 14), (13, 15)
(14,  0), (14,  1), (14,  2), (14,  3), (14,  4), (14,  5), (14,  6), (14,  7), (14,  8), (14,  9), (14, 10), (14, 11), (14, 12), (14, 13), (14, 14), (14, 15)
(15,  0), (15,  1), (15,  2), (15,  3), (15,  4), (15,  5), (15,  6), (15,  7), (15,  8), (15,  9), (15, 10), (15, 11), (15, 12), (15, 13), (15, 14), (15, 15)
( 0, 16), ( 0, 17), ( 0, 18), ( 0, 19), ( 0, 20), ( 0, 21), ( 0, 22), ( 0, 23), ( 0, 24), ( 0, 25), ( 0, 26), ( 0, 27), ( 0, 28), ( 0, 29), ( 0, 30), ( 0, 31)
( 1, 16), ( 1, 17), ( 1, 18), ( 1, 19), ( 1, 20), ( 1, 21), ( 1, 22), ( 1, 23), ( 1, 24), ( 1, 25), ( 1, 26), ( 1, 27), ( 1, 28), ( 1, 29), ( 1, 30), ( 1, 31)
( 2, 16), ( 2, 17), ( 2, 18), ( 2, 19), ( 2, 20), ( 2, 21), ( 2, 22), ( 2, 23), ( 2, 24), ( 2, 25), ( 2, 26), ( 2, 27), ( 2, 28), ( 2, 29), ( 2, 30), ( 2, 31)
( 3, 16), ( 3, 17), ( 3, 18), ( 3, 19), ( 3, 20), ( 3, 21), ( 3, 22), ( 3, 23), ( 3, 24), ( 3, 25), ( 3, 26), ( 3, 27), ( 3, 28), ( 3, 29), ( 3, 30), ( 3, 31)
( 4, 16), ( 4, 17), ( 4, 18), ( 4, 19), ( 4, 20), ( 4, 21), ( 4, 22), ( 4, 23), ( 4, 24), ( 4, 25), ( 4, 26), ( 4, 27), ( 4, 28), ( 4, 29), ( 4, 30), ( 4, 31)
( 5, 16), ( 5, 17), ( 5, 18), ( 5, 19), ( 5, 20), ( 5, 21), ( 5, 22), ( 5, 23), ( 5, 24), ( 5, 25), ( 5, 26), ( 5, 27), ( 5, 28), ( 5, 29), ( 5, 30), ( 5, 31)
( 6, 16), ( 6, 17), ( 6, 18), ( 6, 19), ( 6, 20), ( 6, 21), ( 6, 22), ( 6, 23), ( 6, 24), ( 6, 25), ( 6, 26), ( 6, 27), ( 6, 28), ( 6, 29), ( 6, 30), ( 6, 31)
( 7, 16), ( 7, 17), ( 7, 18), ( 7, 19), ( 7, 20), ( 7, 21), ( 7, 22), ( 7, 23), ( 7, 24), ( 7, 25), ( 7, 26), ( 7, 27), ( 7, 28), ( 7, 29), ( 7, 30), ( 7, 31)
( 8, 16), ( 8, 17), ( 8, 18), ( 8, 19), ( 8, 20), ( 8, 21), ( 8, 22), ( 8, 23), ( 8, 24), ( 8, 25), ( 8, 26), ( 8, 27), ( 8, 28), ( 8, 29), ( 8, 30), ( 8, 31)
( 9, 16), ( 9, 17), ( 9, 18), ( 9, 19), ( 9, 20), ( 9, 21), ( 9, 22), ( 9, 23), ( 9, 24), ( 9, 25), ( 9, 26), ( 9, 27), ( 9, 28), ( 9, 29), ( 9, 30), ( 9, 31)
(10, 16), (10, 17), (10, 18), (10, 19), (10, 20), (10, 21), (10, 22), (10, 23), (10, 24), (10, 25), (10, 26), (10, 27), (10, 28), (10, 29), (10, 30), (10, 31)
(11, 16), (11, 17), (11, 18), (11, 19), (11, 20), (11, 21), (11, 22), (11, 23), (11, 24), (11, 25), (11, 26), (11, 27), (11, 28), (11, 29), (11, 30), (11, 31)
(12, 16), (12, 17), (12, 18), (12, 19), (12, 20), (12, 21), (12, 22), (12, 23), (12, 24), (12, 25), (12, 26), (12, 27), (12, 28), (12, 29), (12, 30), (12, 31)
(13, 16), (13, 17), (13, 18), (13, 19), (13, 20), (13, 21), (13, 22), (13, 23), (13, 24), (13, 25), (13, 26), (13, 27), (13, 28), (13, 29), (13, 30), (13, 31)
(14, 16), (14, 17), (14, 18), (14, 19), (14, 20), (14, 21), (14, 22), (14, 23), (14, 24), (14, 25), (14, 26), (14, 27), (14, 28), (14, 29), (14, 30), (14, 31)
(15, 16), (15, 17), (15, 18), (15, 19), (15, 20), (15, 21), (15, 22), (15, 23), (15, 24), (15, 25), (15, 26), (15, 27), (15, 28), (15, 29), (15, 30), (15, 31)
(16,  0), (16,  1), (16,  2), (16,  3), (16,  4), (16,  5), (16,  6), (16,  7), (16,  8), (16,  9), (16, 10), (16, 11), (16, 12), (16, 13), (16, 14), (16, 15)
(17,  0), (17,  1), (17,  2), (17,  3), (17,  4), (17,  5), (17,  6), (17,  7), (17,  8), (17,  9), (17, 10), (17, 11), (17, 12), (17, 13), (17, 14), (17, 15)
(18,  0), (18,  1), (18,  2), (18,  3), (18,  4), (18,  5), (18,  6), (18,  7), (18,  8), (18,  9), (18, 10), (18, 11), (18, 12), (18, 13), (18, 14), (18, 15)
(19,  0), (19,  1), (19,  2), (19,  3), (19,  4), (19,  5), (19,  6), (19,  7), (19,  8), (19,  9), (19, 10), (19, 11), (19, 12), (19, 13), (19, 14), (19, 15)
(20,  0), (20,  1), (20,  2), (20,  3), (20,  4), (20,  5), (20,  6), (20,  7), (20,  8), (20,  9), (20, 10), (20, 11), (20, 12), (20, 13), (20, 14), (20, 15)
(21,  0), (21,  1), (21,  2), (21,  3), (21,  4), (21,  5), (21,  6), (21,  7), (21,  8), (21,  9), (21, 10), (21, 11), (21, 12), (21, 13), (21, 14), (21, 15)
(22,  0), (22,  1), (22,  2), (22,  3), (22,  4), (22,  5), (22,  6), (22,  7), (22,  8), (22,  9), (22, 10), (22, 11), (22, 12), (22, 13), (22, 14), (22, 15)
(23,  0), (23,  1), (23,  2), (23,  3), (23,  4), (23,  5), (23,  6), (23,  7), (23,  8), (23,  9), (23, 10), (23, 11), (23, 12), (23, 13), (23, 14), (23, 15)
(24,  0), (24,  1), (24,  2), (24,  3), (24,  4), (24,  5), (24,  6), (24,  7), (24,  8), (24,  9), (24, 10), (24, 11), (24, 12), (24, 13), (24, 14), (24, 15)
(25,  0), (25,  1), (25,  2), (25,  3), (25,  4), (25,  5), (25,  6), (25,  7), (25,  8), (25,  9), (25, 10), (25, 11), (25, 12), (25, 13), (25, 14), (25, 15)
(26,  0), (26,  1), (26,  2), (26,  3), (26,  4), (26,  5), (26,  6), (26,  7), (26,  8), (26,  9), (26, 10), (26, 11), (26, 12), (26, 13), (26, 14), (26, 15)
(27,  0), (27,  1), (27,  2), (27,  3), (27,  4), (27,  5), (27,  6), (27,  7), (27,  8), (27,  9), (27, 10), (27, 11), (27, 12), (27, 13), (27, 14), (27, 15)
(28,  0), (28,  1), (28,  2), (28,  3), (28,  4), (28,  5), (28,  6), (28,  7), (28,  8), (28,  9), (28, 10), (28, 11), (28, 12), (28, 13), (28, 14), (28, 15)
(29,  0), (29,  1), (29,  2), (29,  3), (29,  4), (29,  5), (29,  6), (29,  7), (29,  8), (29,  9), (29, 10), (29, 11), (29, 12), (29, 13), (29, 14), (29, 15)
(30,  0), (30,  1), (30,  2), (30,  3), (30,  4), (30,  5), (30,  6), (30,  7), (30,  8), (30,  9), (30, 10), (30, 11), (30, 12), (30, 13), (30, 14), (30, 15)
(31,  0), (31,  1), (31,  2), (31,  3), (31,  4), (31,  5), (31,  6), (31,  7), (31,  8), (31,  9), (31, 10), (31, 11), (31, 12), (31, 13), (31, 14), (31, 15)
(16, 16), (16, 17), (16, 18), (16, 19), (16, 20), (16, 21), (16, 22), (16, 23), (16, 24), (16, 25), (16, 26), (16, 27), (16, 28), (16, 29), (16, 30), (16, 31)
(17, 16), (17, 17), (17, 18), (17, 19), (17, 20), (17, 21), (17, 22), (17, 23), (17, 24), (17, 25), (17, 26), (17, 27), (17, 28), (17, 29), (17, 30), (17, 31)
(18, 16), (18, 17), (18, 18), (18, 19), (18, 20), (18, 21), (18, 22), (18, 23), (18, 24), (18, 25), (18, 26), (18, 27), (18, 28), (18, 29), (18, 30), (18, 31)
(19, 16), (19, 17), (19, 18), (19, 19), (19, 20), (19, 21), (19, 22), (19, 23), (19, 24), (19, 25), (19, 26), (19, 27), (19, 28), (19, 29), (19, 30), (19, 31)
(20, 16), (20, 17), (20, 18), (20, 19), (20, 20), (20, 21), (20, 22), (20, 23), (20, 24), (20, 25), (20, 26), (20, 27), (20, 28), (20, 29), (20, 30), (20, 31)
(21, 16), (21, 17), (21, 18), (21, 19), (21, 20), (21, 21), (21, 22), (21, 23), (21, 24), (21, 25), (21, 26), (21, 27), (21, 28), (21, 29), (21, 30), (21, 31)
(22, 16), (22, 17), (22, 18), (22, 19), (22, 20), (22, 21), (22, 22), (22, 23), (22, 24), (22, 25), (22, 26), (22, 27), (22, 28), (22, 29), (22, 30), (22, 31)
(23, 16), (23, 17), (23, 18), (23, 19), (23, 20), (23, 21), (23, 22), (23, 23), (23, 24), (23, 25), (23, 26), (23, 27), (23, 28), (23, 29), (23, 30), (23, 31)
(24, 16), (24, 17), (24, 18), (24, 19), (24, 20), (24, 21), (24, 22), (24, 23), (24, 24), (24, 25), (24, 26), (24, 27), (24, 28), (24, 29), (24, 30), (24, 31)
(25, 16), (25, 17), (25, 18), (25, 19), (25, 20), (25, 21), (25, 22), (25, 23), (25, 24), (25, 25), (25, 26), (25, 27), (25, 28), (25, 29), (25, 30), (25, 31)
(26, 16), (26, 17), (26, 18), (26, 19), (26, 20), (26, 21), (26, 22), (26, 23), (26, 24), (26, 25), (26, 26), (26, 27), (26, 28), (26, 29), (26, 30), (26, 31)
(27, 16), (27, 17), (27, 18), (27, 19), (27, 20), (27, 21), (27, 22), (27, 23), (27, 24), (27, 25), (27, 26), (27, 27), (27, 28), (27, 29), (27, 30), (27, 31)
(28, 16), (28, 17), (28, 18), (28, 19), (28, 20), (28, 21), (28, 22), (28, 23), (28, 24), (28, 25), (28, 26), (28, 27), (28, 28), (28, 29), (28, 30), (28, 31)
(29, 16), (29, 17), (29, 18), (29, 19), (29, 20), (29, 21), (29, 22), (29, 23), (29, 24), (29, 25), (29, 26), (29, 27), (29, 28), (29, 29), (29, 30), (29, 31)
(30, 16), (30, 17), (30, 18), (30, 19), (30, 20), (30, 21), (30, 22), (30, 23), (30, 24), (30, 25), (30, 26), (30, 27), (30, 28), (30, 29), (30, 30), (30, 31)
(31, 16), (31, 17), (31, 18), (31, 19), (31, 20), (31, 21), (31, 22), (31, 23), (31, 24), (31, 25), (31, 26), (31, 27), (31, 28), (31, 29), (31, 30), (31, 31)
( 0,128), ( 0,129), ( 0,130), ( 0,131), ( 0,132), ( 0,133), ( 0,134), ( 0,135), ( 0,136), ( 0,137), ( 0,138), ( 0,139), ( 0,140), ( 0,141), ( 0,142), ( 0,143)
( 1,128), ( 1,129), ( 1,130), ( 1,131), ( 1,132), ( 1,133), ( 1,134), ( 1,135), ( 1,136), ( 1,137), ( 1,138), ( 1,139), ( 1,140), ( 1,141), ( 1,142), ( 1,143)
( 2,128), ( 2,129), ( 2,130), ( 2,131), ( 2,132), ( 2,133), ( 2,134), ( 2,135), ( 2,136), ( 2,137), ( 2,138), ( 2,139), ( 2,140), ( 2,141), ( 2,142), ( 2,143)
( 3,128), ( 3,129), ( 3,130), ( 3,131), ( 3,132), ( 3,133), ( 3,134), ( 3,135), ( 3,136), ( 3,137), ( 3,138), ( 3,139), ( 3,140), ( 3,141), ( 3,142), ( 3,143)
( 4,128), ( 4,129), ( 4,130), ( 4,131), ( 4,132), ( 4,133), ( 4,134), ( 4,135), ( 4,136), ( 4,137), ( 4,138), ( 4,139), ( 4,140), ( 4,141), ( 4,142), ( 4,143)
( 5,128), ( 5,129), ( 5,130), ( 5,131), ( 5,132), ( 5,133), ( 5,134), ( 5,135), ( 5,136), ( 5,137), ( 5,138), ( 5,139), ( 5,140), ( 5,141), ( 5,142), ( 5,143)
( 6,128), ( 6,129), ( 6,130), ( 6,131), ( 6,132), ( 6,133), ( 6,134), ( 6,135), ( 6,136), ( 6,137), ( 6,138), ( 6,139), ( 6,140), ( 6,141), ( 6,142), ( 6,143)
( 7,128), ( 7,129), ( 7,130), ( 7,131), ( 7,132), ( 7,133), ( 7,134), ( 7,135), ( 7,136), ( 7,137), ( 7,138), ( 7,139), ( 7,140), ( 7,141), ( 7,142), ( 7,143)
( 8,128), ( 8,129), ( 8,130), ( 8,131), ( 8,132), ( 8,133), ( 8,134), ( 8,135), ( 8,136), ( 8,137), ( 8,138), ( 8,139), ( 8,140), ( 8,141), ( 8,142), ( 8,143)
( 9,128), ( 9,129), ( 9,130), ( 9,131), ( 9,132), ( 9,133), ( 9,134), ( 9,135), ( 9,136), ( 9,137), ( 9,138), ( 9,139), ( 9,140), ( 9,141), ( 9,142), ( 9,143)
(10,128), (10,129), (10,130), (10,131), (10,132), (10,133), (10,134), (10,135), (10,136), (10,137), (10,138), (10,139), (10,140), (10,141), (10,142), (10,143)
(11,128), (11,129), (11,130), (11,131), (11,132), (11,133), (11,134), (11,135), (11,136), (11,137), (11,138), (11,139), (11,140), (11,141), (11,142), (11,143)
(12,128), (12,129), (12,130), (12,131), (12,132), (12,133), (12,134), (12,135), (12,136), (12,137), (12,138), (12,139), (12,140), (12,141), (12,142), (12,143)
(13,128), (13,129), (13,130), (13,131), (13,132), (13,133), (13,134), (13,135), (13,136), (13,137), (13,138), (13,139), (13,140), (13,141), (13,142), (13,143)
(14,128), (14,129), (14,130), (14,131), (14,132), (14,133), (14,134), (14,135), (14,136), (14,137), (14,138), (14,139), (14,140), (14,141), (14,142), (14,143)
(15,128), (15,129), (15,130), (15,131), (15,132), (15,133), (15,134), (15,135), (15,136), (15,137), (15,138), (15,139), (15,140), (15,141), (15,142), (15,143)
( 0,144), ( 0,145), ( 0,146), ( 0,147), ( 0,148), ( 0,149), ( 0,150), ( 0,151), ( 0,152), ( 0,153), ( 0,154), ( 0,155), ( 0,156), ( 0,157), ( 0,158), ( 0,159)
( 1,144), ( 1,145), ( 1,146), ( 1,147), ( 1,148), ( 1,149), ( 1,150), ( 1,151), ( 1,152), ( 1,153), ( 1,154), ( 1,155), ( 1,156), ( 1,157), ( 1,158), ( 1,159)
( 2,144), ( 2,145), ( 2,146), ( 2,147), ( 2,148), ( 2,149), ( 2,150), ( 2,151), ( 2,152), ( 2,153), ( 2,154), ( 2,155), ( 2,156), ( 2,157), ( 2,158), ( 2,159)
( 3,144), ( 3,145), ( 3,146), ( 3,147), ( 3,148), ( 3,149), ( 3,150), ( 3,151), ( 3,152), ( 3,153), ( 3,154), ( 3,155), ( 3,156), ( 3,157), ( 3,158), ( 3,159)
( 4,144), ( 4,145), ( 4,146), ( 4,147), ( 4,148), ( 4,149), ( 4,150), ( 4,151), ( 4,152), ( 4,153), ( 4,154), ( 4,155), ( 4,156), ( 4,157), ( 4,158), ( 4,159)
( 5,144), ( 5,145), ( 5,146), ( 5,147), ( 5,148), ( 5,149), ( 5,150), ( 5,151), ( 5,152), ( 5,153), ( 5,154), ( 5,155), ( 5,156), ( 5,157), ( 5,158), ( 5,159)
( 6,144), ( 6,145), ( 6,146), ( 6,147), ( 6,148), ( 6,149), ( 6,150), ( 6,151), ( 6,152), ( 6,153), ( 6,154), ( 6,155), ( 6,156), ( 6,157), ( 6,158), ( 6,159)
( 7,144), ( 7,145), ( 7,146), ( 7,147), ( 7,148), ( 7,149), ( 7,150), ( 7,151), ( 7,152), ( 7,153), ( 7,154), ( 7,155), ( 7,156), ( 7,157), ( 7,158), ( 7,159)
( 8,144), ( 8,145), ( 8,146), ( 8,147), ( 8,148), ( 8,149), ( 8,150), ( 8,151), ( 8,152), ( 8,153), ( 8,154), ( 8,155), ( 8,156), ( 8,157), ( 8,158), ( 8,159)
( 9,144), ( 9,145), ( 9,146), ( 9,147), ( 9,148), ( 9,149), ( 9,150), ( 9,151), ( 9,152), ( 9,153), ( 9,154), ( 9,155), ( 9,156), ( 9,157), ( 9,158), ( 9,159)
(10,144), (10,145), (10,146), (10,147), (10,148), (10,149), (10,150), (10,151), (10,152), (10,153), (10,154), (10,155), (10,156), (10,157), (10,158), (10,159)
(11,144), (11,145), (11,146), (11,147), (11,148), (11,149), (11,150), (11,151), (11,152), (11,153), (11,154), (11,155), (11,156), (11,157), (11,158), (11,159)
(12,144), (12,145), (12,146), (12,147), (12,148), (12,149), (12,150), (12,151), (12,152), (12,153), (12,154), (12,155), (12,156), (12,157), (12,158), (12,159)
(13,144), (13,145), (13,146), (13,147), (13,148), (13,149), (13,150), (13,151), (13,152), (13,153), (13,154), (13,155), (13,156), (13,157), (13,158), (13,159)
(14,144), (14,145), (14,146), (14,147), (14,148), (14,149), (14,150), (14,151), (14,152), (14,153), (14,154), (14,155), (14,156), (14,157), (14,158), (14,159)
(15,144), (15,145), (15,146), (15,147), (15,148), (15,149), (15,150), (15,151), (15,152), (15,153), (15,154), (15,155), (15,156), (15,157), (15,158), (15,159)
(16,128), (16,129), (16,130), (16,131), (16,132), (16,133), (16,134), (16,135), (16,136), (16,137), (16,138), (16,139), (16,140), (16,141), (16,142), (16,143)
(17,128), (17,129), (17,130), (17,131), (17,132), (17,133), (17,134), (17,135), (17,136), (17,137), (17,138), (17,139), (17,140), (17,141), (17,142), (17,143)
(18,128), (18,129), (18,130), (18,131), (18,132), (18,133), (18,134), (18,135), (18,136), (18,137), (18,138), (18,139), (18,140), (18,141), (18,142), (18,143)
(19,128), (19,129), (19,130), (19,131), (19,132), (19,133), (19,134), (19,135), (19,136), (19,137), (19,138), (19,139), (19,140), (19,141), (19,142), (19,143)
(20,128), (20,129), (20,130), (20,131), (20,132), (20,133), (20,134), (20,135), (20,136), (20,137), (20,138), (20,139), (20,140), (20,141), (20,142), (20,143)
(21,128), (21,129), (21,130), (21,131), (21,132), (21,133), (21,134), (21,135), (21,136), (21,137), (21,138), (21,139), (21,140), (21,141), (21,142), (21,143)
(22,128), (22,129), (22,130), (22,131), (22,132), (22,133), (22,134), (22,135), (22,136), (22,137), (22,138), (22,139), (22,140), (22,141), (22,142), (22,143)
(23,128), (23,129), (23,130), (23,131), (23,132), (23,133), (23,134), (23,135), (23,136), (23,137), (23,138), (23,139), (23,140), (23,141), (23,142), (23,143)
(24,128), (24,129), (24,130), (24,131), (24,132), (24,133), (24,134), (24,135), (24,136), (24,137), (24,138), (24,139), (24,140), (24,141), (24,142), (24,143)
(25,128), (25,129), (25,130), (25,131), (25,132), (25,133), (25,134), (25,135), (25,136), (25,137), (25,138), (25,139), (25,140), (25,141), (25,142), (25,143)
(26,128), (26,129), (26,130), (26,131), (26,132), (26,133), (26,134), (26,135), (26,136), (26,137), (26,138), (26,139), (26,140), (26,141), (26,142), (26,143)
(27,128), (27,129), (27,130), (27,131), (27,132), (27,133), (27,134), (27,135), (27,136), (27,137), (27,138), (27,139), (27,140), (27,141), (27,142), (27,143)
(28,128), (28,129), (28,130), (28,131), (28,132), (28,133), (28,134), (28,135), (28,136), (28,137), (28,138), (28,139), (28,140), (28,141), (28,142), (28,143)
(29,128), (29,129), (29,130), (29,131), (29,132), (29,133), (29,134), (29,135), (29,136), (29,137), (29,138), (29,139), (29,140), (29,141), (29,142), (29,143)
(30,128), (30,129), (30,130), (30,131), (30,132), (30,133), (30,134), (30,135), (30,136), (30,137), (30,138), (30,139), (30,140), (30,141), (30,142), (30,143)
(31,128), (31,129), (31,130), (31,131), (31,132), (31,133), (31,134), (31,135), (31,136), (31,137), (31,138), (31,139), (31,140), (31,141), (31,142), (31,143)
(16,144), (16,145), (16,146), (16,147), (16,148), (16,149), (16,150), (16,151), (16,152), (16,153), (16,154), (16,155), (16,156), (16,157), (16,158), (16,159)
(17,144), (17,145), (17,146), (17,147), (17,148), (17,149), (17,150), (17,151), (17,152), (17,153), (17,154), (17,155), (17,156), (17,157), (17,158), (17,159)
(18,144), (18,145), (18,146), (18,147), (18,148), (18,149), (18,150), (18,151), (18,152), (18,153), (18,154), (18,155), (18,156), (18,157), (18,158), (18,159)
(19,144), (19,145), (19,146), (19,147), (19,148), (19,149), (19,150), (19,151), (19,152), (19,153), (19,154), (19,155), (19,156), (19,157), (19,158), (19,159)
(20,144), (20,145), (20,146), (20,147), (20,148), (20,149), (20,150), (20,151), (20,152), (20,153), (20,154), (20,155), (20,156), (20,157), (20,158), (20,159)
(21,144), (21,145), (21,146), (21,147), (21,148), (21,149), (21,150), (21,151), (21,152), (21,153), (21,154), (21,155), (21,156), (21,157), (21,158), (21,159)
(22,144), (22,145), (22,146), (22,147), (22,148), (22,149), (22,150), (22,151), (22,152), (22,153), (22,154), (22,155), (22,156), (22,157), (22,158), (22,159)
(23,144), (23,145), (23,146), (23,147), (23,148), (23,149), (23,150), (23,151), (23,152), (23,153), (23,154), (23,155), (23,156), (23,157), (23,158), (23,159)
(24,144), (24,145), (24,146), (24,147), (24,148), (24,149), (24,150), (24,151), (24,152), (24,153), (24,154), (24,155), (24,156), (24,157), (24,158), (24,159)
(25,144), (25,145), (25,146), (25,147), (25,148), (25,149), (25,150), (25,151), (25,152), (25,153), (25,154), (25,155), (25,156), (25,157), (25,158), (25,159)
(26,144), (26,145), (26,146), (26,147), (26,148), (26,149), (26,150), (26,151), (26,152), (26,153), (26,154), (26,155), (26,156), (26,157), (26,158), (26,159)
(27,144), (27,145), (27,146), (27,147), (27,148), (27,149), (27,150), (27,151), (27,152), (27,153), (27,154), (27,155), (27,156), (27,157), (27,158), (27,159)
(28,144), (28,145), (28,146), (28,147), (28,148), (28,149), (28,150), (28,151), (28,152), (28,153), (28,154), (28,155), (28,156), (28,157), (28,158), (28,159)
(29,144), (29,145), (29,146), (29,147), (29,148), (29,149), (29,150), (29,151), (29,152), (29,153), (29,154), (29,155), (29,156), (29,157), (29,158), (29,159)
(30,144), (30,145), (30,146), (30,147), (30,148), (30,149), (30,150), (30,151), (30,152), (30,153), (30,154), (30,155), (30,156), (30,157), (30,158), (30,159)
(31,144), (31,145), (31,146), (31,147), (31,148), (31,149), (31,150), (31,151), (31,152), (31,153), (31,154), (31,155), (31,156), (31,157), (31,158), (31,159)

Here we again have two blocks, but this time we have a discontinuity between the first and second contiguous blocks.

The DPAS tile layout for B also differs from A. The B tile is expected to be 16x16, but, unlike the A matrix, each register for the B matrix packs two 16-bit values into a single 32-bit value. This packing is referred to as the "vnni transform" and is handled automatically during the load.

| WI 0 | WI 1 | WI 2 | WI 3 | WI 4 | WI 5 | WI 6 | WI 7 | WI 8 | WI 9 | WI 10 | WI 11 | WI 12 | WI 13 | WI 14 | WI 15 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0, 0 : 1, 0 | 0, 1 : 1, 1 | 0, 2 : 1, 2 | 0, 3 : 1, 3 | 0, 4 : 1, 4 | 0, 5 : 1, 5 | 0, 6 : 1, 6 | 0, 7 : 1, 7 | 0, 8 : 1, 8 | 0, 9 : 1, 9 | 0, 10 : 1, 10 | 0, 11 : 1, 11 | 0, 12 : 1, 12 | 0, 13 : 1, 13 | 0, 14 : 1, 14 | 0, 15 : 1, 15 |
| 2, 0 : 3, 0 | 2, 1 : 3, 1 | 2, 2 : 3, 2 | 2, 3 : 3, 3 | 2, 4 : 3, 4 | 2, 5 : 3, 5 | 2, 6 : 3, 6 | 2, 7 : 3, 7 | 2, 8 : 3, 8 | 2, 9 : 3, 9 | 2, 10 : 3, 10 | 2, 11 : 3, 11 | 2, 12 : 3, 12 | 2, 13 : 3, 13 | 2, 14 : 3, 14 | 2, 15 : 3, 15 |
| 4, 0 : 5, 0 | 4, 1 : 5, 1 | 4, 2 : 5, 2 | 4, 3 : 5, 3 | 4, 4 : 5, 4 | 4, 5 : 5, 5 | 4, 6 : 5, 6 | 4, 7 : 5, 7 | 4, 8 : 5, 8 | 4, 9 : 5, 9 | 4, 10 : 5, 10 | 4, 11 : 5, 11 | 4, 12 : 5, 12 | 4, 13 : 5, 13 | 4, 14 : 5, 14 | 4, 15 : 5, 15 |
| 6, 0 : 7, 0 | 6, 1 : 7, 1 | 6, 2 : 7, 2 | 6, 3 : 7, 3 | 6, 4 : 7, 4 | 6, 5 : 7, 5 | 6, 6 : 7, 6 | 6, 7 : 7, 7 | 6, 8 : 7, 8 | 6, 9 : 7, 9 | 6, 10 : 7, 10 | 6, 11 : 7, 11 | 6, 12 : 7, 12 | 6, 13 : 7, 13 | 6, 14 : 7, 14 | 6, 15 : 7, 15 |
| 8, 0 : 9, 0 | 8, 1 : 9, 1 | 8, 2 : 9, 2 | 8, 3 : 9, 3 | 8, 4 : 9, 4 | 8, 5 : 9, 5 | 8, 6 : 9, 6 | 8, 7 : 9, 7 | 8, 8 : 9, 8 | 8, 9 : 9, 9 | 8, 10 : 9, 10 | 8, 11 : 9, 11 | 8, 12 : 9, 12 | 8, 13 : 9, 13 | 8, 14 : 9, 14 | 8, 15 : 9, 15 |
| 10, 0 : 11, 0 | 10, 1 : 11, 1 | 10, 2 : 11, 2 | 10, 3 : 11, 3 | 10, 4 : 11, 4 | 10, 5 : 11, 5 | 10, 6 : 11, 6 | 10, 7 : 11, 7 | 10, 8 : 11, 8 | 10, 9 : 11, 9 | 10, 10 : 11, 10 | 10, 11 : 11, 11 | 10, 12 : 11, 12 | 10, 13 : 11, 13 | 10, 14 : 11, 14 | 10, 15 : 11, 15 |
| 12, 0 : 13, 0 | 12, 1 : 13, 1 | 12, 2 : 13, 2 | 12, 3 : 13, 3 | 12, 4 : 13, 4 | 12, 5 : 13, 5 | 12, 6 : 13, 6 | 12, 7 : 13, 7 | 12, 8 : 13, 8 | 12, 9 : 13, 9 | 12, 10 : 13, 10 | 12, 11 : 13, 11 | 12, 12 : 13, 12 | 12, 13 : 13, 13 | 12, 14 : 13, 14 | 12, 15 : 13, 15 |
| 14, 0 : 15, 0 | 14, 1 : 15, 1 | 14, 2 : 15, 2 | 14, 3 : 15, 3 | 14, 4 : 15, 4 | 14, 5 : 15, 5 | 14, 6 : 15, 6 | 14, 7 : 15, 7 | 14, 8 : 15, 8 | 14, 9 : 15, 9 | 14, 10 : 15, 10 | 14, 11 : 15, 11 | 14, 12 : 15, 12 | 14, 13 : 15, 13 | 14, 14 : 15, 14 | 14, 15 : 15, 15 |
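The table above follows a simple rule, sketched below in Python: work item n owns column n of the 16x16 B tile, and its register r packs rows 2r and 2r+1. This is an illustrative model read off the table, assuming a 16-wide subgroup and bf16 data; the helper `vnni_b_register` is hypothetical and not part of the Triton codebase.

```python
# Hypothetical helper that reproduces the vnni-packed B tile table above.
# Assumes a 16x16 bf16 B tile, 16 work items (WI) and 8 registers per WI,
# where each 32-bit register packs two 16-bit values from adjacent rows.
TILE_ROWS = 16
TILE_COLS = 16
PACK = 2  # 16-bit values per 32-bit register


def vnni_b_register(wi: int, reg: int) -> tuple[tuple[int, int], ...]:
    """Return the (row, col) tensor indices packed into register `reg` of work item `wi`."""
    assert 0 <= wi < TILE_COLS and 0 <= reg < TILE_ROWS // PACK
    return tuple((reg * PACK + k, wi) for k in range(PACK))


# Print one line per register, in the same order as the table rows.
for reg in range(TILE_ROWS // PACK):
    cells = [vnni_b_register(wi, reg) for wi in range(TILE_COLS)]
    print(" | ".join(f"{a[0]}, {a[1]} : {b[0]}, {b[1]}" for a, b in cells))
```

Each work item therefore holds one full column of the tile, spread over 8 registers.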

As before, we start with the block load layout corresponding to a single DPAS tile:

Block load tile layout:
 - offset=1 -> (0, 1)
   offset=2 -> (0, 2)
   offset=4 -> (0, 4)
   offset=8 -> (0, 8)
   offset=16 -> (1, 0)
   offset=32 -> (2, 0)
   offset=64 -> (4, 0)
where out dims are: [dim0 (size 8), dim1 (size 16)]

The DPAS tile size is 16x16, but the load layout tile size is 8x16 because of the vnni transform: each slot in dim0 packs two consecutive rows. However, to maintain surjectivity of the layout, the offset dimension is still a linear index into the row-wise loaded data. To convert the layout output for a given offset into tensor indices, we multiply by the number of values packed in that dimension. For example, the 17th element (offset = 16) maps to index (1, 0), but the corresponding slot actually holds the tensor indices (2, 0) and (3, 0). To convert to tensor indices we multiply the first value by the number of packed elements (two here) and add 1 for each subsequent packed element.

0 : 0, 0
16 : 1, 0
32 : 2, 0
48 : 3, 0
64 : 4, 0
80 : 5, 0
96 : 6, 0
112 : 7, 0
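The conversion described above can be sketched in Python. A linear layout maps an input to the XOR of the basis vectors of its set bits (it is linear over GF(2)); because each basis here has a single set bit, XOR and addition coincide. The bases are copied from the single-tile dump; the helper names `apply_layout` and `to_tensor_indices` are illustrative only.

```python
# Sketch: evaluate the single-tile B load layout and unpack its output.
# Basis vectors are copied from the "Block load tile layout" dump above.
OFFSET_BASES = [(0, 1), (0, 2), (0, 4), (0, 8), (1, 0), (2, 0), (4, 0)]
PACK = 2  # bf16 values packed into each slot along dim0 (vnni transform)


def apply_layout(offset: int) -> tuple[int, int]:
    """XOR together the basis vector of every set bit of `offset`."""
    d0 = d1 = 0
    for bit, (b0, b1) in enumerate(OFFSET_BASES):
        if (offset >> bit) & 1:
            d0 ^= b0
            d1 ^= b1
    return d0, d1


def to_tensor_indices(d0: int, d1: int) -> list[tuple[int, int]]:
    """Scale the packed dim0 coordinate and list every tensor value in the slot."""
    return [(d0 * PACK + k, d1) for k in range(PACK)]


# Reproduce the offset table above and show the unpacked tensor indices.
for offset in range(0, 128, 16):
    d0, d1 = apply_layout(offset)
    print(offset, ":", (d0, d1), "->", to_tensor_indices(d0, d1))
```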

We compute iterations as before. This layout has four iterations in dim0 and two iterations in dim1:

Block load tile layout after adding iterations:
 - offset=1 -> (0, 1)
   offset=2 -> (0, 2)
   offset=4 -> (0, 4)
   offset=8 -> (0, 8)
   offset=16 -> (1, 0)
   offset=32 -> (2, 0)
   offset=64 -> (4, 0)
 - iteration=1 -> (8, 0)
   iteration=2 -> (16, 0)
   iteration=4 -> (0, 16)
where out dims are: [dim0 (size 32), dim1 (size 32)]
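The iteration count can be checked with a small sketch: enumerate every iteration index, XOR the selected iteration bases to get the origin of each 8x16 tile, and measure how far the tiles reach. The bases are copied from the dump above; the names are illustrative.

```python
# Sketch: tile origins produced by the iteration bases above.
TILE_SHAPE = (8, 16)  # packed tile size of the single-tile layout
ITERATION_BASES = [(8, 0), (16, 0), (0, 16)]


def iteration_origin(it: int) -> tuple[int, int]:
    """XOR the iteration basis vectors selected by the bits of `it`."""
    d0 = d1 = 0
    for bit, (b0, b1) in enumerate(ITERATION_BASES):
        if (it >> bit) & 1:
            d0 ^= b0
            d1 ^= b1
    return d0, d1


origins = [iteration_origin(it) for it in range(1 << len(ITERATION_BASES))]
dim0_origins = sorted({o[0] for o in origins})
dim1_origins = sorted({o[1] for o in origins})
# The covered extent is the farthest origin plus one tile in each dimension.
extent = (max(dim0_origins) + TILE_SHAPE[0], max(dim1_origins) + TILE_SHAPE[1])
print(len(dim0_origins), "iterations in dim0,", len(dim1_origins), "in dim1, extent", extent)
```

Four distinct dim0 origins and two dim1 origins give the 32 x 32 out dims reported in the dump.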

Finally, as the DPAS layout suggested, there are two contiguous blocks separated by a stride, so we need multiple loads. Hardware limitations are not a factor in this case, so we generate two loads separated by the outer stride of 128 elements:

Block load tile layout after adding loads:
 - offset=1 -> (0, 1)
   offset=2 -> (0, 2)
   offset=4 -> (0, 4)
   offset=8 -> (0, 8)
   offset=16 -> (1, 0)
   offset=32 -> (2, 0)
   offset=64 -> (4, 0)
 - iteration=1 -> (0, 16)
   iteration=2 -> (8, 0)
 - load=1 -> (128, 0)
where out dims are: [dim0 (size 256), dim1 (size 32)]
0, 0, 0 : 0, 0
0, 0, 127 : 7, 15
0, 1, 0 : 0, 16
0, 1, 127 : 7, 31
0, 2, 0 : 8, 0
0, 2, 127 : 15, 15
0, 3, 0 : 8, 16
0, 3, 127 : 15, 31

1, 0, 0 : 128, 0
1, 0, 127 : 135, 15
1, 1, 0 : 128, 16
1, 1, 127 : 135, 31
1, 2, 0 : 136, 0
1, 2, 127 : 143, 15
1, 3, 0 : 136, 16
1, 3, 127 : 143, 31
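Each line of the listing above reads as `load, iteration, offset : dim0, dim1`. As a sketch, the listing can be regenerated from the bases alone by XOR-ing the basis vector of every set bit in each input dimension. The bases are copied from the "after adding loads" dump; `apply_layout` is a hypothetical helper, not Triton's `LinearLayout` API.

```python
# Sketch: regenerate the (load, iteration, offset) -> (dim0, dim1) listing.
BASES = {
    "offset": [(0, 1), (0, 2), (0, 4), (0, 8), (1, 0), (2, 0), (4, 0)],
    "iteration": [(0, 16), (8, 0)],
    "load": [(128, 0)],
}


def apply_layout(**inputs: int) -> tuple[int, int]:
    """XOR the bases selected by the bits of each named input dimension."""
    d0 = d1 = 0
    for name, value in inputs.items():
        for bit, (b0, b1) in enumerate(BASES[name]):
            if (value >> bit) & 1:
                d0 ^= b0
                d1 ^= b1
    return d0, d1


for load in range(2):
    for iteration in range(4):
        for offset in (0, 127):  # first and last offset of each iteration
            d0, d1 = apply_layout(load=load, iteration=iteration, offset=offset)
            print(f"{load}, {iteration}, {offset} : {d0}, {d1}")
```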

Note that this layout is not surjective. Linear layouts are assumed to be surjective by default; that is, all output values should be covered by some input value. But non-surjective layouts can be created, albeit with some limitations around invertibility and use in other APIs. For the case of block loads, a non-surjective layout is necessary to accurately model the loaded data in the global input coordinate space.
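The lack of surjectivity can be confirmed by brute force: enumerate every (load, iteration, offset) input, apply the layout, and compare the image with the 256 x 32 output space. This is a Python sketch using the bases from the dump above, not Triton's own surjectivity check.

```python
# Sketch: brute-force surjectivity check for the B load layout above.
OFFSET_BASES = [(0, 1), (0, 2), (0, 4), (0, 8), (1, 0), (2, 0), (4, 0)]
ITERATION_BASES = [(0, 16), (8, 0)]
LOAD_BASES = [(128, 0)]
OUT_SHAPE = (256, 32)

# Concatenating the bases treats (load, iteration, offset) as one flattened
# 10-bit input: offset in bits 0-6, iteration in bits 7-8, load in bit 9.
BASES = OFFSET_BASES + ITERATION_BASES + LOAD_BASES


def apply_layout(flat_input: int) -> tuple[int, int]:
    d0 = d1 = 0
    for bit, (b0, b1) in enumerate(BASES):
        if (flat_input >> bit) & 1:
            d0 ^= b0
            d1 ^= b1
    return d0, d1


image = {apply_layout(i) for i in range(1 << len(BASES))}
num_outputs = OUT_SHAPE[0] * OUT_SHAPE[1]
covered_rows = sorted({d0 for d0, _ in image})

# Every input hits a distinct output, but most of the output space is never hit.
print(len(image), "of", num_outputs, "outputs covered")
print("dim0 rows covered:", covered_rows[0], "-", covered_rows[15],
      "and", covered_rows[16], "-", covered_rows[-1])
```

Only rows 0-15 and 128-143 of dim0 are reached, which is exactly the two-block pattern the global coordinates require.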

AxBT

Now we consider the AxBT GEMM kernel. We keep the A matrix size fixed and transpose the B matrix: the B matrix input becomes [4096, 5120] and is transposed during execution of the kernel. Because the DPAS instruction has no transpose variant, we use the same layouts for the A and B matrices as before. We rely on the 2D block load to perform the transpose and on the shuffle vectors to place the transposed data into the right registers for DPAS. In the future we may modify the pass pipeline to adjust the overall layout for a dot operation with a transposed A or B matrix, but for this example it is acceptable to leave the layouts fixed so we can examine the loads.

The A matrix is not transposed and its DPAS layout is the same as in the non-transposed case, so its loads are unchanged.

For the BT matrix, the transpose is computed during the load, which requires a different load layout. As in the non-transposed case, we start from the DPAS tile size. However, the transposed 2D block load only supports 32-bit matrix elements. Our bf16 elements are 16 bits, so we account for this by halving the inner dimension of the load:

Block load tile layout:
 - offset=1 -> (0, 1)
   offset=2 -> (0, 2)
   offset=4 -> (0, 4)
   offset=8 -> (1, 0)
   offset=16 -> (2, 0)
   offset=32 -> (4, 0)
   offset=64 -> (8, 0)
where out dims are: [dim0 (size 16), dim1 (size 8)]

Note that the output data of the transposed load is not explicitly vnni transformed. However, because each 32-bit slot holds two bf16 elements, the load effectively vnni transforms the data.
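A minimal sketch of the transposed tile, assuming coordinates are in the pre-transpose memory layout of BT and that the two bf16 values in a slot are adjacent along the inner (contiguous) dimension. It applies the single-tile bases above with XOR and expands each 32-bit slot into its two bf16 coordinates. The helper names are hypothetical, and which half of the 32-bit slot each value occupies is not modeled.

```python
# Sketch of the transposed BT load tile: 32-bit slots holding two bf16 values.
OFFSET_BASES = [(0, 1), (0, 2), (0, 4), (1, 0), (2, 0), (4, 0), (8, 0)]
ELEMS_PER_SLOT = 32 // 16  # 32-bit slot / 16-bit bf16 element


def apply_layout(offset: int) -> tuple[int, int]:
    """XOR the basis vectors selected by the bits of `offset`."""
    d0 = d1 = 0
    for bit, (b0, b1) in enumerate(OFFSET_BASES):
        if (offset >> bit) & 1:
            d0 ^= b0
            d1 ^= b1
    return d0, d1


def bf16_elements(offset: int) -> list[tuple[int, int]]:
    """Expand a 32-bit slot (d0, d1) into its bf16 coordinates, in inner-index order."""
    d0, d1 = apply_layout(offset)
    return [(d0, d1 * ELEMS_PER_SLOT + k) for k in range(ELEMS_PER_SLOT)]


# 128 slots x 2 values cover the full 16 x 16 bf16 DPAS tile.
covered = {e for off in range(128) for e in bf16_elements(off)}
print(len(covered), "bf16 elements, max index", max(covered))
```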

For this load we have two iterations in the outer dimension:

Block load tile layout after adding iterations:
 - offset=1 -> (0, 1)
   offset=2 -> (0, 2)
   offset=4 -> (0, 4)
   offset=8 -> (1, 0)
   offset=16 -> (2, 0)
   offset=32 -> (4, 0)
   offset=64 -> (8, 0)
 - iteration=1 -> (16, 0)
where out dims are: [dim0 (size 32), dim1 (size 8)]

Finally, we compute the loads. Because the transposed load can fetch less data per instruction than the non-transposed load, we need more loads: two load bases, and therefore four loads instead of two.

Block load tile layout after adding loads:
 - offset=1 -> (0, 1)
   offset=2 -> (0, 2)
   offset=4 -> (0, 4)
   offset=8 -> (1, 0)
   offset=16 -> (2, 0)
   offset=32 -> (4, 0)
   offset=64 -> (8, 0)
 - iteration=1 -> (16, 0)
 - load=1 -> (0, 16)
   load=2 -> (128, 0)
where out dims are: [dim0 (size 256), dim1 (size 32)]
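The two load bases above can be read the same way as the offset and iteration bases. The sketch below lists where each of the four loads starts and checks the reported out dims. It assumes, as the dumps suggest, that each out-dim size is the smallest power of two covering every reachable coordinate; coordinates are taken exactly as they appear in the dump, and the helper names are illustrative.

```python
# Sketch: origins of the four transposed BT loads and the resulting out dims.
OFFSET_BASES = [(0, 1), (0, 2), (0, 4), (1, 0), (2, 0), (4, 0), (8, 0)]
ITERATION_BASES = [(16, 0)]
LOAD_BASES = [(0, 16), (128, 0)]


def xor_bases(value: int, bases: list[tuple[int, int]]) -> tuple[int, int]:
    """Linear layout application: XOR the bases selected by the bits of `value`."""
    d0 = d1 = 0
    for bit, (b0, b1) in enumerate(bases):
        if (value >> bit) & 1:
            d0 ^= b0
            d1 ^= b1
    return d0, d1


def next_pow2(n: int) -> int:
    """Smallest power of two >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


num_loads = 1 << len(LOAD_BASES)
load_origins = [xor_bases(load, LOAD_BASES) for load in range(num_loads)]

# Enumerate every (load, iteration, offset) input as one flattened index.
all_bases = OFFSET_BASES + ITERATION_BASES + LOAD_BASES
image = [xor_bases(i, all_bases) for i in range(1 << len(all_bases))]
max0 = max(d0 for d0, _ in image)
max1 = max(d1 for _, d1 in image)
out_dims = (next_pow2(max0 + 1), next_pow2(max1 + 1))
print(num_loads, "loads starting at", load_origins, "out dims", out_dims)
```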

The block load layout has size 256 x 32 after adding loads. As a sanity check, for layouts with more than one load the block load layout size should match the kernel's block parameters (in the single-load case we do not modify the layout dimensions). In other words, the layout can represent any value the load uses for the target operand in this block. Note that the load layout has no concept of a "warp", so it may not represent every used value (see the surjectivity discussion above). We are concerned only with the data loaded by an arbitrary call to the 2D block load hardware extensions; the per-warp offsets are added later (e.g. when generating the 2D block load instruction). By replicating the layout over all warps (subgroups) in the block (workgroup) and applying an offset determined by the DPAS layout, we can cover all values in the block dimensions (256 x 32).