
Commit 59f5703

chsigg authored and Google-ML-Automation committed
PiperOrigin-RevId: 717875236
1 parent 907555f commit 59f5703

File tree

2 files changed: +36 -91 lines changed

jax_triton/triton_lib.py  +36 -53
@@ -88,8 +88,7 @@
     jnp.dtype("uint32"): "u32",
     jnp.dtype("uint16"): "u16",
     jnp.dtype("uint8"): "u8",
-    # Triton defines a 'B' type, which is an alias for both i1 and bool.
-    jnp.dtype("bool"): "B",
+    jnp.dtype("bool"): "i1",
 }

 Grid = Union[int, tuple[int], tuple[int, int], tuple[int, int, int]]
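
Note on this hunk: JAX booleans now lower to Triton's "i1" type rather than the old "B" alias. A minimal sketch of how such a dtype table can be consumed; the three unsigned entries come from the hunk above, while the lookup helper is an illustrative assumption, not the jax-triton API:

import jax.numpy as jnp

# Illustrative-only table mirroring the entries visible in this hunk.
_DTYPE_TO_TRITON = {
    jnp.dtype("uint32"): "u32",
    jnp.dtype("uint16"): "u16",
    jnp.dtype("uint8"): "u8",
    jnp.dtype("bool"): "i1",  # previously "B", an alias for i1/bool
}

def triton_scalar_type(x):  # hypothetical helper, for demonstration only
  return _DTYPE_TO_TRITON[jnp.asarray(x).dtype]

assert triton_scalar_type(jnp.uint8(3)) == "u8"
assert triton_scalar_type(True) == "i1"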
@@ -353,30 +352,36 @@ def get_or_create_triton_kernel(
   if num_ctas > 1 and compute_capability < 90:
     raise ValueError("num_ctas > 1 unsupported before Hopper.")

+  backend = backend_init_func(device, compute_capability)
+
   signature = {fn.arg_names[i]: v for i, v in enumerate(arg_dtypes)}
   # TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers
   # We assume that all arrays are aligned to 16 bytes, and Triton may use this
   # assumption, unless array args are include in the `do_not_specialize` list.
-  # We replace array arguments with mock Torch tensors, to allow us to use
-  # `JITFunction._get_config` to get the specialization_attr.
-  mock_torch_tensor = types.SimpleNamespace(data_ptr=lambda: 16)
-  args_for_specialization_attr = [mock_torch_tensor] * len(arg_dtypes)
-  backend = backend_init_func(device, compute_capability)
-  for i, _, v in scalar_args:
-    args_for_specialization_attr[i] = v
-
-  specialization_attr = backend.get_attrs_descriptor(fn.params[:len(args_for_specialization_attr)], args_for_specialization_attr)  # pylint: disable=protected-access
+  specialization = [
+      triton.runtime.jit.specialize_impl(
+          types.SimpleNamespace(
+              data_ptr=lambda: 16, dtype=arg_dtype.removeprefix("*")
+          ),
+          backend.get_arg_specialization,
+      )
+      for arg_dtype in arg_dtypes
+  ]
+  attrs = {
+      fn.arg_names[i]: backend.parse_attr(attr)
+      for i, (_, attr) in enumerate(specialization)
+  }
   constants = dict(metaparams)
   constants.update({k: None for _, k, v in scalar_args if v is None})
-  constants.update({fn.arg_names[i]: 1 for (i,) in specialization_attr.equal_to_1})
+  constants.update({fn.arg_names[i]: 1 for i, _, v in scalar_args if v == 1})
   for constant in constants:
     signature[constant] = "constexpr"

   # Cache key should contain any parameter that can affect the compiler output.
   cache_key = (
       fn,
       tuple(signature.items()),
-      tuple(specialization_attr.get_fn_attrs()),
+      tuple(specialization),
       tuple(constants.items()),
       num_warps,
       num_stages,
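
This hunk replaces backend.get_attrs_descriptor over mock Torch tensors with triton.runtime.jit.specialize_impl over duck-typed stand-ins, and derives equal-to-1 constants directly from scalar_args. A self-contained sketch of the idea, with a fake specialization function standing in for the real backend hook (everything named fake_* or stride_xn is hypothetical, for illustration only):

import types

def fake_get_arg_specialization(arg):
  # Stand-in for the backend hook: report 16-byte divisibility because
  # jax-triton assumes every array is 16-byte aligned (data_ptr() == 16).
  attrs = []
  if hasattr(arg, "data_ptr") and arg.data_ptr() % 16 == 0:
    attrs.append(["tt.divisibility", 16])
  return attrs

arg_dtypes = ["*fp32", "*fp32", "i32"]  # pointer dtypes carry a "*" prefix
specialization = [
    fake_get_arg_specialization(
        types.SimpleNamespace(data_ptr=lambda: 16, dtype=dt.removeprefix("*"))
    )
    for dt in arg_dtypes
]
print(specialization)  # every argument reports 16-byte divisibility here

# Scalars equal to 1 become constexpr constants straight from scalar_args,
# which holds (index, name, value) triples.
scalar_args = [(2, "stride_xn", 1)]
equal_to_1 = {i for i, _, v in scalar_args if v == 1}
print(equal_to_1)  # -> {2}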
@@ -408,46 +413,22 @@ def get_or_create_triton_kernel(
     context = _triton.ir.context()
     _triton.ir.load_dialects(context)
     backend.load_dialects(context)
-    codegen_fns = backend.get_codegen_implementation()
-
-    module = (
-        code_gen.ast_to_ttir(
-            fn,
-            specialization=tc.ASTSource(
-                fn,
-                constexprs=constants,
-                signature=signature,
-                attrs=specialization_attr,
-            ),
-            options=options,
-            codegen_fns=codegen_fns,
-            context=context,
-            module_map=backend.get_module_map(),
-        )
-        if "module_map" in inspect.getfullargspec(code_gen.ast_to_ttir).args
-        # Triton changes ASTSource.ast_to_ttir to include module_map. Handle
-        # backward compatibility here.
-        else code_gen.ast_to_ttir(
-            fn,
-            specialization=tc.ASTSource(
-                fn,
-                constexprs=constants,
-                signature=signature,
-                attrs=specialization_attr,
-            ),
-            options=options,
-            codegen_fns=codegen_fns,
-            context=context,
-        )
+    codegen_fns = backend.get_codegen_implementation(options)
+
+    module = code_gen.ast_to_ttir(
+        fn,
+        tc.ASTSource(
+            fn, constexprs=constants, signature=signature, attrs=attrs
+        ),
+        options=options,
+        codegen_fns=codegen_fns,
+        context=context,
+        module_map=backend.get_module_map(),
     )
     ttir = str(module)

     compilation_result = compile_ttir_inplace(
-        module,
-        backend,
-        options,
-        compute_capability,
-        platform
+        module, backend, options, compute_capability, platform
     )

     kernel_name = compilation_result.name
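
This hunk drops the backward-compatibility branch that probed code_gen.ast_to_ttir for a module_map parameter; the newer Triton signature is now assumed unconditionally. For reference, a self-contained sketch of the generic feature-detection pattern that was removed (plain Python, not tied to Triton; newer_api and call_with_optional_kwarg are hypothetical names):

import inspect

def newer_api(x, *, module_map=None):
  return (x, module_map)

def call_with_optional_kwarg(fn, x, module_map):
  # Pass module_map only if the callee accepts it, as the old code did for
  # code_gen.ast_to_ttir via inspect.getfullargspec.
  spec = inspect.getfullargspec(fn)
  if "module_map" in spec.args + spec.kwonlyargs:
    return fn(x, module_map=module_map)
  return fn(x)

print(call_with_optional_kwarg(newer_api, 1, module_map={}))  # -> (1, {})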
@@ -459,7 +440,7 @@ def get_or_create_triton_kernel(
       with open(
           f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/{kernel_name}.ptx", "w"
       ) as f:
-        f.write(compilation_result.ptx)
+        f.write(compilation_result.binary)
       with open(
           f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/{kernel_name}.ttgir", "w"
       ) as f:
@@ -490,7 +471,7 @@ def get_or_create_triton_kernel(

   _COMPILED_KERNEL_CACHE[cache_key] = kernel

-  return kernel, specialization_attr
+  return kernel, attrs


 def triton_kernel_call_lowering(
@@ -628,15 +609,17 @@ def prune_configs(configs, named_args, **kwargs):

   kernel_params = []
   zeroed_params_with_sizes = dict(params["zeroed_params_with_sizes"])
+  equal_to_1 = {i for i, _, v in scalar_args if v == 1}
   for i, (arg, dtype) in enumerate(zip(args, arg_dtypes)):
     if isinstance(arg, core.ShapedArray):
+      arg_attrs = specialization_attr[fn.arg_names[i]]
       kernel_params.append(
           triton_kernel_call_lib.create_array_parameter(
               zeroed_params_with_sizes.get(i, 0),
-              16 if (i in specialization_attr.divisibility_16) else 0,
+              16 if (["tt.divisibility", 16] in arg_attrs) else 0,
           )
       )
-    elif (i,) not in specialization_attr.equal_to_1:
+    elif i not in equal_to_1:
       kernel_params.append(
           triton_kernel_call_lib.create_scalar_parameter(arg, dtype)
       )
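
The last triton_lib.py hunk swaps the old index-based divisibility_16/equal_to_1 sets for per-argument attribute lists keyed by argument name. A small self-contained illustration of the new membership test (the attribute values and argument names here are made up for demonstration):

# Hypothetical per-argument attributes, shaped like the parsed backend attrs.
specialization_attr = {
    "x_ptr": [["tt.divisibility", 16]],
    "n_elements": [],
}

def array_alignment(arg_name):
  # Mirrors the updated check: 16 if the argument carries a
  # ["tt.divisibility", 16] attribute, else 0.
  attrs = specialization_attr[arg_name]
  return 16 if ["tt.divisibility", 16] in attrs else 0

assert array_alignment("x_ptr") == 16
assert array_alignment("n_elements") == 0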

tests/triton_call_test.py  -38
@@ -531,44 +531,6 @@ def test_autotune_with_input_output_aliasing(self):
     out = add(x, y, kernel=kernel, input_output_aliases={0: 0})
     np.testing.assert_allclose(out, expected)

-  def test_specialization(self):
-    do_not_specialize = (
-        0,  # a_ptr
-        2,  # M
-        6,  # stride_ak
-        7,  # stride_bk
-        11,  # c_ptr
-    )
-    kernel = triton.jit(do_not_specialize=do_not_specialize)(matmul_kernel.fn)
-
-    m, n, k = 128, 128, 99
-    x, y = create_random_inputs([m, k], [k, n])
-
-    with mock.patch.object(code_gen, "ast_to_ttir") as mock_compile:
-      try:
-        _ = matmul(
-            x,
-            y,
-            kernel=kernel,
-            BLOCK_SIZE_M=32,
-            BLOCK_SIZE_N=32,
-            BLOCK_SIZE_K=32,
-            # K_EXACTLY_DIVISIBLE_BY_BLOCK=False,
-        )
-      except TypeError:
-        pass  # Error thrown as the mocked method's return value is invalid.
-
-    mock_compile.assert_called_once()
-    specialization = mock_compile.call_args[1]['specialization']
-
-    # Pointers are assumed to divide by 16, as do `M`, `N`, `stride_{bk,cm}`.
-    # However, we've marked `a_ptr`, `M`, `stride_bk`, and `c_ptr` as "do not
-    # specialize", leaving `b_ptr`, `N`, and `stride_cm`.
-    self.assertEqual(specialization.attrs.divisibility_16, [(1,), (3,), (9,)])
-    # `stride_{ak,bn,cn}` equal 1, but we've marked `stride_ak` as "do not
-    # specialize" leaving `stride_{bn,cn}`.
-    self.assertEqual(specialization.attrs.equal_to_1, [(8,), (10,)])
-

 if __name__ == "__main__":
   os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

0 commit comments
