File tree 3 files changed +12
-6
lines changed
3 files changed +12
-6
lines changed Original file line number Diff line number Diff line change @@ -76,13 +76,14 @@ ipynb](https://github.com/jax-ml/jax-triton/blob/main/examples/JAX_%2B_Triton_Fl
76
76
$ pip install jax-triton
77
77
```
78
78
79
- You can either use a stable release of ` triton ` or a nightly release.
80
-
81
79
Make sure you have a CUDA-compatible ` jax ` installed. For example you could run:
82
80
``` bash
83
81
$ pip install " jax[cuda12]"
84
82
```
85
83
84
+ ` jax-triton ` currently requires building the latest version of ` triton `
85
+ [ from source] ( https://triton-lang.org/main/getting-started/installation.html#from-source ) .
86
+
86
87
## Development
87
88
88
89
To develop ` jax-triton ` , you can clone the repo with:
Original file line number Diff line number Diff line change 20
20
import copy
21
21
import dataclasses
22
22
import functools
23
- import inspect
24
23
import os
25
24
import pprint
26
25
import tempfile
@@ -363,12 +362,19 @@ def get_or_create_triton_kernel(
363
362
alignments = [16 ] * len (arg_dtypes )
364
363
for i , _ , value in scalar_args :
365
364
alignments [i ] = value
365
+ specialize_extra = backend .get_arg_specialization
366
+ if specialize_impl := getattr (triton .runtime .jit , "specialize_impl" , None ):
367
+ # TODO(slebedev): Remove this branch once Triton 3.3 is released.
368
+ specialize_impl = functools .partial (
369
+ specialize_impl , specialize_extra = specialize_extra
370
+ )
371
+ else :
372
+ specialize_impl = triton .runtime .jit .create_specialize_impl (specialize_extra )
366
373
specialization = [
367
- triton . runtime . jit . specialize_impl (
374
+ specialize_impl (
368
375
types .SimpleNamespace (
369
376
data_ptr = lambda : alignment , dtype = arg_dtype .removeprefix ("*" )
370
377
),
371
- backend .get_arg_specialization ,
372
378
)
373
379
for arg_dtype , alignment in zip (arg_dtypes , alignments )
374
380
]
Original file line number Diff line number Diff line change 24
24
import jax_triton as jt
25
25
import numpy as np
26
26
import triton
27
- from triton .compiler import code_generator as code_gen
28
27
import triton .language as tl
29
28
30
29
config .parse_flags_with_absl ()
You can’t perform that action at this time.
0 commit comments