     jnp.dtype("uint32"): "u32",
     jnp.dtype("uint16"): "u16",
     jnp.dtype("uint8"): "u8",
-    # Triton defines a 'B' type, which is an alias for both i1 and bool.
-    jnp.dtype("bool"): "B",
+    jnp.dtype("bool"): "i1",
 }

 Grid = Union[int, tuple[int], tuple[int, int], tuple[int, int, int]]
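Note on the hunk above: the table maps JAX dtypes to Triton type strings, and bool now maps to plain "i1" rather than Triton's old 'B' alias. The specialization code further down strips a "*" prefix from argument dtypes, which suggests array arguments carry a pointer prefix in the signature. A minimal illustrative sketch of that convention; type_map and to_triton_arg_type are made-up names for illustration, not jax-triton API:

import jax.numpy as jnp

# Illustrative subset of the dtype table above (hypothetical helper).
type_map = {
    jnp.dtype("uint8"): "u8",
    jnp.dtype("bool"): "i1",
}

def to_triton_arg_type(dtype, is_array):
  # Array arguments are assumed to become pointer types ("*i1"); scalars use
  # the bare type string.
  t = type_map[jnp.dtype(dtype)]
  return "*" + t if is_array else t

print(to_triton_arg_type("bool", is_array=True))    # *i1
print(to_triton_arg_type("uint8", is_array=False))  # u8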
@@ -353,30 +352,36 @@ def get_or_create_triton_kernel(
   if num_ctas > 1 and compute_capability < 90:
     raise ValueError("num_ctas > 1 unsupported before Hopper.")

+  backend = backend_init_func(device, compute_capability)
+
   signature = {fn.arg_names[i]: v for i, v in enumerate(arg_dtypes)}
   # TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers
   # We assume that all arrays are aligned to 16 bytes, and Triton may use this
   # assumption, unless array args are include in the `do_not_specialize` list.
-  # We replace array arguments with mock Torch tensors, to allow us to use
-  # `JITFunction._get_config` to get the specialization_attr.
-  mock_torch_tensor = types.SimpleNamespace(data_ptr=lambda: 16)
-  args_for_specialization_attr = [mock_torch_tensor] * len(arg_dtypes)
-  backend = backend_init_func(device, compute_capability)
-  for i, _, v in scalar_args:
-    args_for_specialization_attr[i] = v
-
-  specialization_attr = backend.get_attrs_descriptor(fn.params[:len(args_for_specialization_attr)], args_for_specialization_attr)  # pylint: disable=protected-access
+  specialization = [
+      triton.runtime.jit.specialize_impl(
+          types.SimpleNamespace(
+              data_ptr=lambda: 16, dtype=arg_dtype.removeprefix("*")
+          ),
+          backend.get_arg_specialization,
+      )
+      for arg_dtype in arg_dtypes
+  ]
+  attrs = {
+      fn.arg_names[i]: backend.parse_attr(attr)
+      for i, (_, attr) in enumerate(specialization)
+  }
   constants = dict(metaparams)
   constants.update({k: None for _, k, v in scalar_args if v is None})
-  constants.update({fn.arg_names[i]: 1 for (i,) in specialization_attr.equal_to_1})
+  constants.update({fn.arg_names[i]: 1 for i, _, v in scalar_args if v == 1})
   for constant in constants:
     signature[constant] = "constexpr"

   # Cache key should contain any parameter that can affect the compiler output.
   cache_key = (
       fn,
       tuple(signature.items()),
-      tuple(specialization_attr.get_fn_attrs()),
+      tuple(specialization),
       tuple(constants.items()),
       num_warps,
       num_stages,
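The rewritten block above keeps the same underlying trick as the removed code: each array argument is replaced by a lightweight stand-in object whose data_ptr() returns 16, so the backend's argument specialization sees a pointer that looks 16-byte aligned and emits the corresponding divisibility attribute. A rough sketch of the idea, where fake_get_arg_specialization is a hypothetical stand-in for backend.get_arg_specialization, used only to show the mechanism:

import types

# Hypothetical stand-in for backend.get_arg_specialization: it inspects the
# (mock) pointer value, and 16 % 16 == 0 yields a 16-byte divisibility hint.
def fake_get_arg_specialization(arg):
  attrs = []
  if arg.data_ptr() % 16 == 0:
    attrs.append(["tt.divisibility", 16])
  return attrs

mock_arg = types.SimpleNamespace(data_ptr=lambda: 16, dtype="fp32")
print(fake_get_arg_specialization(mock_arg))  # [['tt.divisibility', 16]]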
@@ -408,46 +413,22 @@ def get_or_create_triton_kernel(
   context = _triton.ir.context()
   _triton.ir.load_dialects(context)
   backend.load_dialects(context)
-  codegen_fns = backend.get_codegen_implementation()
-
-  module = (
-      code_gen.ast_to_ttir(
-          fn,
-          specialization=tc.ASTSource(
-              fn,
-              constexprs=constants,
-              signature=signature,
-              attrs=specialization_attr,
-          ),
-          options=options,
-          codegen_fns=codegen_fns,
-          context=context,
-          module_map=backend.get_module_map(),
-      )
-      if "module_map" in inspect.getfullargspec(code_gen.ast_to_ttir).args
-      # Triton changes ASTSource.ast_to_ttir to include module_map. Handle
-      # backward compatibility here.
-      else code_gen.ast_to_ttir(
-          fn,
-          specialization=tc.ASTSource(
-              fn,
-              constexprs=constants,
-              signature=signature,
-              attrs=specialization_attr,
-          ),
-          options=options,
-          codegen_fns=codegen_fns,
-          context=context,
-      )
+  codegen_fns = backend.get_codegen_implementation(options)
+
+  module = code_gen.ast_to_ttir(
+      fn,
+      tc.ASTSource(
+          fn, constexprs=constants, signature=signature, attrs=attrs
+      ),
+      options=options,
+      codegen_fns=codegen_fns,
+      context=context,
+      module_map=backend.get_module_map(),
   )
   ttir = str(module)

   compilation_result = compile_ttir_inplace(
-      module,
-      backend,
-      options,
-      compute_capability,
-      platform
+      module, backend, options, compute_capability, platform
   )

   kernel_name = compilation_result.name
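The deleted branch above existed only to cope with Triton versions whose code_gen.ast_to_ttir did not yet accept a module_map argument; the new code assumes a version that always does. The compatibility check it used was a plain signature inspection, roughly as in this sketch, where ast_to_ttir is a hypothetical stand-in function, not the real Triton one:

import inspect

# Hypothetical stand-in with the shape of the newer code_gen.ast_to_ttir.
def ast_to_ttir(fn, specialization, options, codegen_fns, context, module_map=None):
  return "ttir-module"

# The removed branch passed module_map only when the installed Triton accepted it.
supports_module_map = "module_map" in inspect.getfullargspec(ast_to_ttir).args
print(supports_module_map)  # True for the stand-in above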
@@ -459,7 +440,7 @@ def get_or_create_triton_kernel(
     with open(
         f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/{kernel_name}.ptx", "w"
     ) as f:
-      f.write(compilation_result.ptx)
+      f.write(compilation_result.binary)
     with open(
         f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/{kernel_name}.ttgir", "w"
     ) as f:
@@ -490,7 +471,7 @@ def get_or_create_triton_kernel(

   _COMPILED_KERNEL_CACHE[cache_key] = kernel

-  return kernel, specialization_attr
+  return kernel, attrs


 def triton_kernel_call_lowering(
@@ -628,15 +609,17 @@ def prune_configs(configs, named_args, **kwargs):

   kernel_params = []
   zeroed_params_with_sizes = dict(params["zeroed_params_with_sizes"])
+  equal_to_1 = {i for i, _, v in scalar_args if v == 1}
   for i, (arg, dtype) in enumerate(zip(args, arg_dtypes)):
     if isinstance(arg, core.ShapedArray):
+      arg_attrs = specialization_attr[fn.arg_names[i]]
       kernel_params.append(
           triton_kernel_call_lib.create_array_parameter(
               zeroed_params_with_sizes.get(i, 0),
-              16 if (i in specialization_attr.divisibility_16) else 0,
+              16 if (["tt.divisibility", 16] in arg_attrs) else 0,
           )
       )
-    elif (i,) not in specialization_attr.equal_to_1:
+    elif i not in equal_to_1:
       kernel_params.append(
           triton_kernel_call_lib.create_scalar_parameter(arg, dtype)
       )
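For reference, after this change the attributes returned by get_or_create_triton_kernel are a mapping from argument name to a list of parsed attributes, so the lowering checks membership of ["tt.divisibility", 16] instead of consulting index sets on an attrs descriptor. A small self-contained sketch with made-up data in the same shape (not real kernel arguments):

# Made-up structures mirroring the shapes used in the hunk above.
specialization_attr = {"x_ptr": [["tt.divisibility", 16]], "n": []}
scalar_args = [(1, "n", 1)]  # (index, name, value) triples

equal_to_1 = {i for i, _, v in scalar_args if v == 1}

# Array argument: pass a 16-byte divisibility hint if the attribute is present.
divisibility = 16 if ["tt.divisibility", 16] in specialization_attr["x_ptr"] else 0
print(divisibility)     # 16
# A scalar equal to 1 was already folded in as a constexpr, so it is skipped
# as a runtime parameter.
print(1 in equal_to_1)  # True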