From 7e9af75edfc744e8740feafb1bfcdb20276f4968 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev@google.com>
Date: Fri, 14 Mar 2025 11:26:05 +0000
Subject: [PATCH] Support the latest Triton API changes

This is a backport of the fix proposed by mattjj@ in jax-ml/jax-triton#333.

I also fixed a few linter warnings to make sure at least our lint CI is green.

Note that the tests will still be red once this PR lands, because Triton
does not have a release incorporating all the recent API changes. I manually
verified that the tests pass with Triton compiled from their main branch.
---
 README.md                 |  5 +++--
 jax_triton/triton_lib.py  | 12 +++++++++---
 tests/triton_call_test.py |  1 -
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 88dbb33..7765a2a 100644
--- a/README.md
+++ b/README.md
@@ -76,13 +76,14 @@ ipynb](https://github.com/jax-ml/jax-triton/blob/main/examples/JAX_%2B_Triton_Fl
 $ pip install jax-triton
 ```
 
-You can either use a stable release of `triton` or a nightly release.
-
 Make sure you have a CUDA-compatible `jax` installed. For example you could run:
 ```bash
 $ pip install "jax[cuda12]"
 ```
 
+`jax-triton` currently requires building the latest version of `triton`
+[from source](https://triton-lang.org/main/getting-started/installation.html#from-source).
+
 ## Development
 
 To develop `jax-triton`, you can clone the repo with:
diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py
index db02dd1..7f0ae8a 100644
--- a/jax_triton/triton_lib.py
+++ b/jax_triton/triton_lib.py
@@ -20,7 +20,6 @@
 import copy
 import dataclasses
 import functools
-import inspect
 import os
 import pprint
 import tempfile
@@ -363,12 +362,19 @@ def get_or_create_triton_kernel(
   alignments = [16] * len(arg_dtypes)
   for i, _, value in scalar_args:
     alignments[i] = value
+  specialize_extra = backend.get_arg_specialization
+  if specialize_impl := getattr(triton.runtime.jit, "specialize_impl", None):
+    # TODO(slebedev): Remove this branch once Triton 3.3 is released.
+    specialize_impl = functools.partial(
+        specialize_impl, specialize_extra=specialize_extra
+    )
+  else:
+    specialize_impl = triton.runtime.jit.create_specialize_impl(specialize_extra)
   specialization = [
-      triton.runtime.jit.specialize_impl(
+      specialize_impl(
           types.SimpleNamespace(
               data_ptr=lambda: alignment, dtype=arg_dtype.removeprefix("*")
           ),
-          backend.get_arg_specialization,
       )
       for arg_dtype, alignment in zip(arg_dtypes, alignments)
   ]
diff --git a/tests/triton_call_test.py b/tests/triton_call_test.py
index 239b090..d009056 100644
--- a/tests/triton_call_test.py
+++ b/tests/triton_call_test.py
@@ -24,7 +24,6 @@
 import jax_triton as jt
 import numpy as np
 import triton
-from triton.compiler import code_generator as code_gen
 import triton.language as tl
 
 config.parse_flags_with_absl()