Commit 87b1723

Use Linear Layout to describe 2D block loads (#3708)
This PR introduces a new linear layout in the Triton XPU Load-to-LLVM lowering for block loads. I split the creation of the layouts out of the larger PR and focused on using them to compute the `(x, y)` offsets for the 2D block load instructions, to validate the layout's correctness. The shuffle vectors are still generated using the existing loop variables.

The layout describes the block load in terms of three input parameters:

* `offset`: the 1D offset into the loaded data for a single DPAS invocation inside a sub-group
* `iteration`: identifies the DPAS invocation when multiple DPAS invocations share a single load
* `load`: identifies the load index when multiple loads occur for a given operand

The output of the layout function is the global `(x, y)` tensor coordinate within a given load. This design allows composing the DPAS layout with the load layout to go from `offset, iteration, load` to `block, warp, lane, register`, or vice versa (a toy sketch of this mapping follows below).

Currently the block load / tile layout is implemented within the existing loop structure, but it was designed so the 2D block loads can eventually be generated by iterating over the layout parameters instead. The existing loop structure remains in place, and debug output can be enabled to print both the previously generated values and the linear-layout values for easy comparison. Next, I plan to generate the shuffle vectors via composition of the DPAS layout and the load layout.

The linear layout is used by default but can be disabled via a flag for debugging.

cc #3008; supersedes #3487
1 parent 75c82bd commit 87b1723
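
The mapping described in the commit message can be illustrated with a small, self-contained sketch. This is not the PR's implementation (which lives in the C++ Load-to-LLVM lowering); the basis vectors below are hypothetical, chosen only to show how a linear layout XORs together the bases selected by the bits of `offset`, `iteration`, and `load` to produce a global `(x, y)` coordinate:

```python
# Hypothetical sketch of a linear layout mapping the three input dimensions
# named in this commit to a global (x, y) tile coordinate. The basis vectors
# are made up for illustration; the real bases are constructed in the XPU
# Load-to-LLVM lowering.

def apply_linear_layout(bases, inputs):
    """XOR together the basis vectors selected by the set bits of each input."""
    x = y = 0
    for dim, value in inputs.items():
        for bit, (bx, by) in enumerate(bases[dim]):
            if (value >> bit) & 1:
                x ^= bx
                y ^= by
    return x, y

# Hypothetical bases: a 16-wide row per DPAS invocation, a second invocation
# 8 rows down, and a second load another 16 rows down.
bases = {
    "offset": [(1, 0), (2, 0), (4, 0), (8, 0)],  # offset bits walk the x axis
    "iteration": [(0, 8)],                       # DPAS invocation index within one load
    "load": [(0, 16)],                           # load index for this operand
}

print(apply_linear_layout(bases, {"offset": 5, "iteration": 1, "load": 1}))
# -> (5, 24): x comes from the 1D offset, y from the iteration and load indices
```

Because each input dimension contributes its own basis vectors, the same machinery can in principle be composed with a DPAS layout's `block, warp, lane, register` mapping, which is what the commit message plans to use for the shuffle vectors.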

File tree: 9 files changed, +859 −37 lines


docs/BLOCK_LOADS_LAYOUT.md

Lines changed: 483 additions & 0 deletions (large diff not rendered by default)

python/test/unit/intel/test_block_load.py

Lines changed: 6 additions & 2 deletions
@@ -6,7 +6,7 @@
 from triton._internal_testing import is_xpu
 
 
-@pytest.mark.parametrize("M, N", [[256, 64], [256, 32], [128, 32], [64, 64], [64, 32], [32, 32]])
+@pytest.mark.parametrize("M, N", [[256, 64], [256, 32], [128, 32], [128, 16], [128, 8], [64, 64], [64, 32], [32, 32]])
 @pytest.mark.parametrize("dtype_str", ["float32", "float16", "int8"])
 @pytest.mark.parametrize("transpose", [True, False])
 @pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend")
@@ -15,6 +15,8 @@
 def test_block_load_dpas_layout(M, N, dtype_str, transpose, device, tmp_path: pathlib.Path):
     # modify the layouts to ensure the correct OCL/SPIRV intrinsic is called for each datatype
     if dtype_str == "int8":
+        if M == 128 and (N == 16 or N == 8):
+            pytest.skip("TODO: test fails verification")
         A_width = 2
         B_width = 4
         layouts = "#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 2], A = [8, 32], B = [32, 32], C = [8, 32]}>"
@@ -23,6 +25,8 @@ def test_block_load_dpas_layout(M, N, dtype_str, transpose, device, tmp_path: pa
         B_width = 1
         layouts = "#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>"
     else:
+        if M == 128 and N == 8:
+            pytest.skip("TODO: test fails verification")
         A_width = 1
         B_width = 2
         layouts = "#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>"
@@ -73,5 +77,5 @@ def test_block_load_dpas_layout(M, N, dtype_str, transpose, device, tmp_path: pa
     kernel = triton.compile(str(temp_file))
 
     kernel[(1, 1, 1)](a, x, b, y)
-
+    #import pdb; pdb.set_trace()
     assert torch.equal(a, x) and torch.equal(b.T if transpose else b, y)

third_party/intel/backend/compiler.py

Lines changed: 4 additions & 1 deletion
@@ -60,6 +60,7 @@ class XPUOptions:
     generate_native_code: bool = False
     advanced_path: bool = False
     one_matrix_per_load_for_bt: bool = False
+    enable_tile_load_linear_layout: bool = True
 
     def __post_init__(self):
         default_libdir = Path(__file__).parent / 'lib'
@@ -187,6 +188,7 @@ def parse_target(self, tgt_prop) -> dict:
     def parse_options(self, opts) -> Any:
         args = {k: opts[k] for k in XPUOptions.__dataclass_fields__.keys() if k in opts}
         args["allow_fp8e4nv"] = True
+        args["enable_tile_load_linear_layout"] = os.getenv("TRITON_XPU_ENABLE_TILE_LOAD_LINEAR_LAYOUT", "1") == "1"
         return XPUOptions(**args)
 
     def pack_metadata(self, metadata):
@@ -344,7 +346,8 @@ def make_llir(src, metadata, options):
     # being used, e.g., convert_layout.
     if os.getenv("TRITON_INTEL_REDUCE_TRANSPOSE", "0") != "1":
         passes.ttgpuir.add_allocate_shared_memory(pm)
-    intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path, options.one_matrix_per_load_for_bt)
+    intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path, options.one_matrix_per_load_for_bt,
+                                       options.enable_tile_load_linear_layout)
     intel.passes.ttgpuir.add_rewrite_stack_ptr(pm)
     passes.convert.add_arith_to_llvmir(pm)
     passes.common.add_canonicalizer(pm)
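
The option defaults to enabled and is overridden through the environment, so the legacy path stays one variable away. A minimal sketch, assuming only that the variable is set before `parse_options` reads it at compile time:

```python
import os

# Disable the linear-layout tile-load path (default "1", enabled) to fall
# back to the previously generated loop values for comparison/debugging.
os.environ["TRITON_XPU_ENABLE_TILE_LOAD_LINEAR_LAYOUT"] = "0"

import triton
# Compile as usual; parse_options picks up the override, e.g.:
# kernel = triton.compile(str(temp_file))
```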

third_party/intel/include/TritonIntelGPUToLLVM/Passes.td

Lines changed: 3 additions & 0 deletions
@@ -22,6 +22,9 @@ def ConvertTritonIntelGPUToLLVM
     Option<"oneMatrixPerLoadForBT", "one_matrix_per_load_for_bt",
            "bool", /*default*/"false",
            "Only load one DPAS operand per load for transposed B matrix">,
+    Option<"useTileLoadLinearLayout", "use_tile_load_linear_layout",
+           "bool", /*default*/"true",
+           "Use linear layouts to generate the tile load sizes and offsets">
   ];
 }
