Skip to content

Commit 7e9af75

Browse files
committedMar 14, 2025·
Support the latest Triton API changes
This is a backport of the fix proposed by mattjj@ in #333. I also fixed a few linter warnings to make sure at least our lint CI is green. Note that the tests will still be red once this PR lands, because Triton does not have a release incorporating all the recent API changes. I manually verified that the tests pass with Triton compiled from their main branch.
1 parent 55daa46 commit 7e9af75

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed
 

‎README.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,14 @@ ipynb](https://github.com/jax-ml/jax-triton/blob/main/examples/JAX_%2B_Triton_Fl
7676
$ pip install jax-triton
7777
```
7878

79-
You can either use a stable release of `triton` or a nightly release.
80-
8179
Make sure you have a CUDA-compatible `jax` installed. For example you could run:
8280
```bash
8381
$ pip install "jax[cuda12]"
8482
```
8583

84+
`jax-triton` currently requires building the latest version of `triton`
85+
[from source](https://triton-lang.org/main/getting-started/installation.html#from-source).
86+
8687
## Development
8788

8889
To develop `jax-triton`, you can clone the repo with:

‎jax_triton/triton_lib.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import copy
2121
import dataclasses
2222
import functools
23-
import inspect
2423
import os
2524
import pprint
2625
import tempfile
@@ -363,12 +362,19 @@ def get_or_create_triton_kernel(
363362
alignments = [16] * len(arg_dtypes)
364363
for i, _, value in scalar_args:
365364
alignments[i] = value
365+
specialize_extra = backend.get_arg_specialization
366+
if specialize_impl := getattr(triton.runtime.jit, "specialize_impl", None):
367+
# TODO(slebedev): Remove this branch once Triton 3.3 is released.
368+
specialize_impl = functools.partial(
369+
specialize_impl, specialize_extra=specialize_extra
370+
)
371+
else:
372+
specialize_impl = triton.runtime.jit.create_specialize_impl(specialize_extra)
366373
specialization = [
367-
triton.runtime.jit.specialize_impl(
374+
specialize_impl(
368375
types.SimpleNamespace(
369376
data_ptr=lambda: alignment, dtype=arg_dtype.removeprefix("*")
370377
),
371-
backend.get_arg_specialization,
372378
)
373379
for arg_dtype, alignment in zip(arg_dtypes, alignments)
374380
]

‎tests/triton_call_test.py

-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import jax_triton as jt
2525
import numpy as np
2626
import triton
27-
from triton.compiler import code_generator as code_gen
2827
import triton.language as tl
2928

3029
config.parse_flags_with_absl()

0 commit comments

Comments
 (0)
Please sign in to comment.