Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support the latest Triton API changes #334

Merged
merged 1 commit into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,14 @@ ipynb](https://github.com/jax-ml/jax-triton/blob/main/examples/JAX_%2B_Triton_Fl
$ pip install jax-triton
```

You can either use a stable release of `triton` or a nightly release.

Make sure you have a CUDA-compatible `jax` installed. For example you could run:
```bash
$ pip install "jax[cuda12]"
```

`jax-triton` currently requires building the latest version of `triton`
[from source](https://triton-lang.org/main/getting-started/installation.html#from-source).

## Development

To develop `jax-triton`, you can clone the repo with:
Expand Down
12 changes: 9 additions & 3 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import copy
import dataclasses
import functools
import inspect
import os
import pprint
import tempfile
Expand Down Expand Up @@ -363,12 +362,19 @@ def get_or_create_triton_kernel(
alignments = [16] * len(arg_dtypes)
for i, _, value in scalar_args:
alignments[i] = value
specialize_extra = backend.get_arg_specialization
if specialize_impl := getattr(triton.runtime.jit, "specialize_impl", None):
# TODO(slebedev): Remove this branch once Triton 3.3 is released.
specialize_impl = functools.partial(
specialize_impl, specialize_extra=specialize_extra
)
else:
specialize_impl = triton.runtime.jit.create_specialize_impl(specialize_extra)
specialization = [
triton.runtime.jit.specialize_impl(
specialize_impl(
types.SimpleNamespace(
data_ptr=lambda: alignment, dtype=arg_dtype.removeprefix("*")
),
backend.get_arg_specialization,
)
for arg_dtype, alignment in zip(arg_dtypes, alignments)
]
Expand Down
1 change: 0 additions & 1 deletion tests/triton_call_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import jax_triton as jt
import numpy as np
import triton
from triton.compiler import code_generator as code_gen
import triton.language as tl

config.parse_flags_with_absl()
Expand Down
Loading