From 7e9af75edfc744e8740feafb1bfcdb20276f4968 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev <slebedev@google.com> Date: Fri, 14 Mar 2025 11:26:05 +0000 Subject: [PATCH] Support the latest Triton API changes This is a backport of the fix proposed by mattjj@ in jax-ml/jax-triton#333. I also fixed a few linter warnings to make sure at least our lint CI is green. Note that the tests will still be red once this PR lands, because Triton does not have a release incorporating all the recent API changes. I manually verified that the tests pass with Triton compiled from their main branch. --- README.md | 5 +++-- jax_triton/triton_lib.py | 12 +++++++++--- tests/triton_call_test.py | 1 - 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 88dbb33..7765a2a 100644 --- a/README.md +++ b/README.md @@ -76,13 +76,14 @@ ipynb](https://github.com/jax-ml/jax-triton/blob/main/examples/JAX_%2B_Triton_Fl $ pip install jax-triton ``` -You can either use a stable release of `triton` or a nightly release. - Make sure you have a CUDA-compatible `jax` installed. For example you could run: ```bash $ pip install "jax[cuda12]" ``` +`jax-triton` currently requires building the latest version of `triton` +[from source](https://triton-lang.org/main/getting-started/installation.html#from-source). + ## Development To develop `jax-triton`, you can clone the repo with: diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index db02dd1..7f0ae8a 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -20,7 +20,6 @@ import copy import dataclasses import functools -import inspect import os import pprint import tempfile @@ -363,12 +362,19 @@ def get_or_create_triton_kernel( alignments = [16] * len(arg_dtypes) for i, _, value in scalar_args: alignments[i] = value + specialize_extra = backend.get_arg_specialization + if specialize_impl := getattr(triton.runtime.jit, "specialize_impl", None): + # TODO(slebedev): Remove this branch once Triton 3.3 is released. + specialize_impl = functools.partial( + specialize_impl, specialize_extra=specialize_extra + ) + else: + specialize_impl = triton.runtime.jit.create_specialize_impl(specialize_extra) specialization = [ - triton.runtime.jit.specialize_impl( + specialize_impl( types.SimpleNamespace( data_ptr=lambda: alignment, dtype=arg_dtype.removeprefix("*") ), - backend.get_arg_specialization, ) for arg_dtype, alignment in zip(arg_dtypes, alignments) ] diff --git a/tests/triton_call_test.py b/tests/triton_call_test.py index 239b090..d009056 100644 --- a/tests/triton_call_test.py +++ b/tests/triton_call_test.py @@ -24,7 +24,6 @@ import jax_triton as jt import numpy as np import triton -from triton.compiler import code_generator as code_gen import triton.language as tl config.parse_flags_with_absl()