Skip to content

Commit 55daa46

Browse files
Add support for scratch buffers in jax_triton.

This is required to use device-side TMA descriptors.

PiperOrigin-RevId: 736650031
1 parent 2d4e2eb commit 55daa46

File tree

1 file changed

+19
-30
lines changed

1 file changed

+19
-30
lines changed

jax_triton/triton_lib.py

+19 −30
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import copy
2121
import dataclasses
2222
import functools
23+
import inspect
2324
import os
2425
import pprint
2526
import tempfile
@@ -34,7 +35,6 @@
3435
from jax._src import state
3536
from jax._src import util
3637
from jax._src.lib import gpu_triton as triton_kernel_call_lib
37-
from jax._src.lib import version as jaxlib_version
3838
from jax._src.lib.mlir import ir
3939
import jax.dlpack
4040
import jax.extend as jex
@@ -176,7 +176,6 @@ class CompilationResult:
176176
binary: str
177177
name: str
178178
shared_mem_bytes: int
179-
global_scratch_bytes: int
180179
cluster_dims: tuple
181180
ttgir: str | None
182181
llir: str | None
@@ -252,14 +251,15 @@ def compile_ttir_to_ptx_inplace(
252251
)
253252
if cuda_options.debug:
254253
print(ptx)
254+
name = metadata["name"]
255+
cluster_dims = metadata["cluster_dims"]
255256
ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None
256257
llir = str(llir) if _JAX_TRITON_DUMP_DIR else None
257258
return CompilationResult(
258259
binary=ptx,
259-
name=metadata["name"],
260+
name=name,
260261
shared_mem_bytes=shared_mem_bytes,
261-
global_scratch_bytes=metadata["global_scratch_size"],
262-
cluster_dims=metadata["cluster_dims"],
262+
cluster_dims=cluster_dims,
263263
ttgir=ttgir,
264264
llir=llir,
265265
)
@@ -318,7 +318,6 @@ def compile_ttir_to_hsaco_inplace(
318318
binary=hsaco_path,
319319
name=name,
320320
shared_mem_bytes=shared_mem_bytes,
321-
global_scratch_bytes=0,
322321
cluster_dims=cluster_dims,
323322
ttgir=ttgir,
324323
llir=llir,
@@ -341,7 +340,7 @@ def get_or_create_triton_kernel(
341340
enable_fp_fusion,
342341
metaparams,
343342
dump: bool,
344-
) -> tuple[triton_kernel_call_lib.TritonKernel, Any, int]:
343+
) -> tuple[triton_kernel_call_lib.TritonKernel, Any]:
345344
if num_warps is None:
346345
num_warps = 4
347346
if num_stages is None:
@@ -395,7 +394,7 @@ def get_or_create_triton_kernel(
395394
compute_capability,
396395
enable_fp_fusion,
397396
)
398-
kernel, scratch_bytes = _COMPILED_KERNEL_CACHE.get(cache_key, (None, 0))
397+
kernel = _COMPILED_KERNEL_CACHE.get(cache_key)
399398

400399
if kernel is None:
401400
opts = {
@@ -474,10 +473,10 @@ def get_or_create_triton_kernel(
474473
compute_capability,
475474
*compilation_result.cluster_dims,
476475
)
477-
scratch_bytes = compilation_result.global_scratch_bytes
478-
_COMPILED_KERNEL_CACHE[cache_key] = (kernel, scratch_bytes)
479476

480-
return kernel, attrs, scratch_bytes
477+
_COMPILED_KERNEL_CACHE[cache_key] = kernel
478+
479+
return kernel, attrs
481480

482481

483482
def triton_kernel_call_lowering(
@@ -597,10 +596,8 @@ def prune_configs(configs, named_args, **kwargs):
597596
)
598597

599598
kernel_calls = []
600-
max_scratch_bytes = 0
601599
for params in config_params:
602-
grid_x, grid_y, grid_z = params["grid"]
603-
kernel, specialization_attr, scratch_bytes = get_or_create_triton_kernel(
600+
kernel, specialization_attr = get_or_create_triton_kernel(
604601
backend_init_func,
605602
ctx.module_context.platforms[0],
606603
fn,
@@ -614,8 +611,6 @@ def prune_configs(configs, named_args, **kwargs):
614611
metaparams=dict(params["metaparams"]),
615612
dump=debug,
616613
)
617-
scratch_bytes *= grid_x * grid_y * grid_z
618-
max_scratch_bytes = max(max_scratch_bytes, scratch_bytes)
619614

620615
kernel_params = []
621616
zeroed_params_with_sizes = dict(params["zeroed_params_with_sizes"])
@@ -636,15 +631,14 @@ def prune_configs(configs, named_args, **kwargs):
636631

637632
kernel_calls.append(
638633
triton_kernel_call_lib.TritonKernelCall(
639-
kernel, grid_x, grid_y, grid_z, kernel_params
634+
kernel,
635+
params["grid"][0],
636+
params["grid"][1],
637+
params["grid"][2],
638+
kernel_params,
640639
)
641640
)
642641

643-
if max_scratch_bytes > 0 and jaxlib_version < (0, 5, 3):
644-
raise NotImplementedError(
645-
"Triton kernels with scratch buffers are not supported in JAX < 0.5.3."
646-
)
647-
648642
if len(kernel_calls) > 1:
649643
named_scalar_args = {fn.arg_names[i]: v for i, _, v in scalar_args}
650644
input_output_aliases_with_sizes = tuple(
@@ -663,21 +657,16 @@ def prune_configs(configs, named_args, **kwargs):
663657
ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
664658
for shape in out_shapes
665659
]
666-
667-
u8 = mlir.dtype_to_ir_type(jnp.dtype(jnp.uint8))
668-
scratch_type = ir.RankedTensorType.get([max_scratch_bytes], u8)
669-
scratch_layout = [0]
670660
call_proto = kernel_call.to_proto(kernel_call_name, serialized_metadata)
671-
results = mlir.custom_call(
661+
return mlir.custom_call(
672662
call_target_name=custom_call_target_name,
673-
result_types=out_types + [scratch_type],
663+
result_types=out_types,
674664
operands=array_args,
675665
backend_config=zlib.compress(call_proto),
676666
operand_layouts=avals_to_layouts(ctx.avals_in),
677-
result_layouts=avals_to_layouts(ctx.avals_out) + [scratch_layout],
667+
result_layouts=avals_to_layouts(ctx.avals_out),
678668
operand_output_aliases=dict(input_output_aliases),
679669
).results
680-
return results[:-1] # Remove scratch buffer.
681670

682671
mlir.register_lowering(
683672
triton_kernel_call_p,

0 commit comments

Comments (0)