20
20
import copy
21
21
import dataclasses
22
22
import functools
23
+ import inspect
23
24
import os
24
25
import pprint
25
26
import tempfile
34
35
from jax ._src import state
35
36
from jax ._src import util
36
37
from jax ._src .lib import gpu_triton as triton_kernel_call_lib
37
- from jax ._src .lib import version as jaxlib_version
38
38
from jax ._src .lib .mlir import ir
39
39
import jax .dlpack
40
40
import jax .extend as jex
@@ -176,7 +176,6 @@ class CompilationResult:
176
176
binary : str
177
177
name : str
178
178
shared_mem_bytes : int
179
- global_scratch_bytes : int
180
179
cluster_dims : tuple
181
180
ttgir : str | None
182
181
llir : str | None
@@ -252,14 +251,15 @@ def compile_ttir_to_ptx_inplace(
252
251
)
253
252
if cuda_options .debug :
254
253
print (ptx )
254
+ name = metadata ["name" ]
255
+ cluster_dims = metadata ["cluster_dims" ]
255
256
ttgir = str (ttgir ) if _JAX_TRITON_DUMP_DIR else None
256
257
llir = str (llir ) if _JAX_TRITON_DUMP_DIR else None
257
258
return CompilationResult (
258
259
binary = ptx ,
259
- name = metadata [ " name" ] ,
260
+ name = name ,
260
261
shared_mem_bytes = shared_mem_bytes ,
261
- global_scratch_bytes = metadata ["global_scratch_size" ],
262
- cluster_dims = metadata ["cluster_dims" ],
262
+ cluster_dims = cluster_dims ,
263
263
ttgir = ttgir ,
264
264
llir = llir ,
265
265
)
@@ -318,7 +318,6 @@ def compile_ttir_to_hsaco_inplace(
318
318
binary = hsaco_path ,
319
319
name = name ,
320
320
shared_mem_bytes = shared_mem_bytes ,
321
- global_scratch_bytes = 0 ,
322
321
cluster_dims = cluster_dims ,
323
322
ttgir = ttgir ,
324
323
llir = llir ,
@@ -341,7 +340,7 @@ def get_or_create_triton_kernel(
341
340
enable_fp_fusion ,
342
341
metaparams ,
343
342
dump : bool ,
344
- ) -> tuple [triton_kernel_call_lib .TritonKernel , Any , int ]:
343
+ ) -> tuple [triton_kernel_call_lib .TritonKernel , Any ]:
345
344
if num_warps is None :
346
345
num_warps = 4
347
346
if num_stages is None :
@@ -395,7 +394,7 @@ def get_or_create_triton_kernel(
395
394
compute_capability ,
396
395
enable_fp_fusion ,
397
396
)
398
- kernel , scratch_bytes = _COMPILED_KERNEL_CACHE .get (cache_key , ( None , 0 ) )
397
+ kernel = _COMPILED_KERNEL_CACHE .get (cache_key )
399
398
400
399
if kernel is None :
401
400
opts = {
@@ -474,10 +473,10 @@ def get_or_create_triton_kernel(
474
473
compute_capability ,
475
474
* compilation_result .cluster_dims ,
476
475
)
477
- scratch_bytes = compilation_result .global_scratch_bytes
478
- _COMPILED_KERNEL_CACHE [cache_key ] = (kernel , scratch_bytes )
479
476
480
- return kernel , attrs , scratch_bytes
477
+ _COMPILED_KERNEL_CACHE [cache_key ] = kernel
478
+
479
+ return kernel , attrs
481
480
482
481
483
482
def triton_kernel_call_lowering (
@@ -597,10 +596,8 @@ def prune_configs(configs, named_args, **kwargs):
597
596
)
598
597
599
598
kernel_calls = []
600
- max_scratch_bytes = 0
601
599
for params in config_params :
602
- grid_x , grid_y , grid_z = params ["grid" ]
603
- kernel , specialization_attr , scratch_bytes = get_or_create_triton_kernel (
600
+ kernel , specialization_attr = get_or_create_triton_kernel (
604
601
backend_init_func ,
605
602
ctx .module_context .platforms [0 ],
606
603
fn ,
@@ -614,8 +611,6 @@ def prune_configs(configs, named_args, **kwargs):
614
611
metaparams = dict (params ["metaparams" ]),
615
612
dump = debug ,
616
613
)
617
- scratch_bytes *= grid_x * grid_y * grid_z
618
- max_scratch_bytes = max (max_scratch_bytes , scratch_bytes )
619
614
620
615
kernel_params = []
621
616
zeroed_params_with_sizes = dict (params ["zeroed_params_with_sizes" ])
@@ -636,15 +631,14 @@ def prune_configs(configs, named_args, **kwargs):
636
631
637
632
kernel_calls .append (
638
633
triton_kernel_call_lib .TritonKernelCall (
639
- kernel , grid_x , grid_y , grid_z , kernel_params
634
+ kernel ,
635
+ params ["grid" ][0 ],
636
+ params ["grid" ][1 ],
637
+ params ["grid" ][2 ],
638
+ kernel_params ,
640
639
)
641
640
)
642
641
643
- if max_scratch_bytes > 0 and jaxlib_version < (0 , 5 , 3 ):
644
- raise NotImplementedError (
645
- "Triton kernels with scratch buffers are not supported in JAX < 0.5.3."
646
- )
647
-
648
642
if len (kernel_calls ) > 1 :
649
643
named_scalar_args = {fn .arg_names [i ]: v for i , _ , v in scalar_args }
650
644
input_output_aliases_with_sizes = tuple (
@@ -663,21 +657,16 @@ def prune_configs(configs, named_args, **kwargs):
663
657
ir .RankedTensorType .get (shape .shape , mlir .dtype_to_ir_type (shape .dtype ))
664
658
for shape in out_shapes
665
659
]
666
-
667
- u8 = mlir .dtype_to_ir_type (jnp .dtype (jnp .uint8 ))
668
- scratch_type = ir .RankedTensorType .get ([max_scratch_bytes ], u8 )
669
- scratch_layout = [0 ]
670
660
call_proto = kernel_call .to_proto (kernel_call_name , serialized_metadata )
671
- results = mlir .custom_call (
661
+ return mlir .custom_call (
672
662
call_target_name = custom_call_target_name ,
673
- result_types = out_types + [ scratch_type ] ,
663
+ result_types = out_types ,
674
664
operands = array_args ,
675
665
backend_config = zlib .compress (call_proto ),
676
666
operand_layouts = avals_to_layouts (ctx .avals_in ),
677
- result_layouts = avals_to_layouts (ctx .avals_out ) + [ scratch_layout ] ,
667
+ result_layouts = avals_to_layouts (ctx .avals_out ),
678
668
operand_output_aliases = dict (input_output_aliases ),
679
669
).results
680
- return results [:- 1 ] # Remove scratch buffer.
681
670
682
671
mlir .register_lowering (
683
672
triton_kernel_call_p ,
0 commit comments