@@ -49,43 +49,43 @@ def __get__(self, owner_self, owner_cls):
49
49
class block_idx:
    """Class-level accessors that forward to ``_block_id`` for one dimension.

    Reading ``block_idx.x`` (or ``.y`` / ``.z``) calls ``_block_id`` with the
    matching dimension name and a location taken from
    ``get_user_code_loc()``. No instance is needed; every accessor is a
    ``classproperty``.
    """

    @classproperty
    def x(cls):
        """Return ``_block_id("x", ...)`` tagged with the user-code location."""
        user_loc = get_user_code_loc()
        return _block_id("x", loc=user_loc)

    @classproperty
    def y(cls):
        """Return ``_block_id("y", ...)`` tagged with the user-code location."""
        user_loc = get_user_code_loc()
        return _block_id("y", loc=user_loc)

    @classproperty
    def z(cls):
        """Return ``_block_id("z", ...)`` tagged with the user-code location."""
        user_loc = get_user_code_loc()
        return _block_id("z", loc=user_loc)
61
61
62
62
63
63
class block_dim:
    """Class-level accessors that forward to ``_block_dim`` for one dimension.

    Reading ``block_dim.x`` (or ``.y`` / ``.z``) calls ``_block_dim`` with the
    matching dimension name and a location taken from
    ``get_user_code_loc()``. No instance is needed; every accessor is a
    ``classproperty``.
    """

    @classproperty
    def x(cls):
        """Return ``_block_dim("x", ...)`` tagged with the user-code location."""
        user_loc = get_user_code_loc()
        return _block_dim("x", loc=user_loc)

    @classproperty
    def y(cls):
        """Return ``_block_dim("y", ...)`` tagged with the user-code location."""
        user_loc = get_user_code_loc()
        return _block_dim("y", loc=user_loc)

    @classproperty
    def z(cls):
        """Return ``_block_dim("z", ...)`` tagged with the user-code location."""
        user_loc = get_user_code_loc()
        return _block_dim("z", loc=user_loc)
75
75
76
76
77
77
class thread_idx:
    """Class-level accessors that forward to ``_thread_id`` for one dimension.

    Reading ``thread_idx.x`` (or ``.y`` / ``.z``) calls ``_thread_id`` with the
    matching dimension name and a location taken from
    ``get_user_code_loc()``. No instance is needed; every accessor is a
    ``classproperty``.
    """

    @classproperty
    def x(cls):
        """Return ``_thread_id("x", ...)`` tagged with the user-code location."""
        user_loc = get_user_code_loc()
        return _thread_id("x", loc=user_loc)

    @classproperty
    def y(cls):
        """Return ``_thread_id("y", ...)`` tagged with the user-code location."""
        user_loc = get_user_code_loc()
        return _thread_id("y", loc=user_loc)

    @classproperty
    def z(cls):
        """Return ``_thread_id("z", ...)`` tagged with the user-code location."""
        user_loc = get_user_code_loc()
        return _thread_id("z", loc=user_loc)
89
89
90
90
91
91
def thread_id ():
@@ -222,6 +222,8 @@ def __init__(
222
222
loc = None ,
223
223
ip = None ,
224
224
):
225
+ if loc is None :
226
+ loc = get_user_code_loc ()
225
227
super ().__init__ (
226
228
function_type = function_type ,
227
229
arg_attrs = arg_attrs ,
@@ -301,10 +303,10 @@ def launch_(
301
303
):
302
304
if loc is None :
303
305
loc = get_user_code_loc ()
304
- for size in [grid_size , block_size ]:
305
- for i , s in enumerate (size ):
306
- if isinstance (s , int ):
307
- size [i ] = constant (s , index = True )
306
+ for size in [grid_size , block_size ]:
307
+ for i , s in enumerate (size ):
308
+ if isinstance (s , int ):
309
+ size [i ] = constant (s , index = True )
308
310
launch_op = LaunchOp (
309
311
grid_size ,
310
312
block_size ,
@@ -371,13 +373,16 @@ def __call__(
371
373
async_dependencies = None ,
372
374
dynamic_shared_memory_size : Optional [Value ] = None ,
373
375
stream = None ,
376
+ loc = None ,
377
+ ip = None ,
374
378
):
375
379
for size in [grid_size , block_size ]:
376
380
for i , s in enumerate (size ):
377
381
if isinstance (s , int ):
378
382
size [i ] = constant (s , index = True )
379
383
380
- loc = get_user_code_loc ()
384
+ if loc is None :
385
+ loc = get_user_code_loc ()
381
386
return get_op_result_or_op_results (
382
387
LaunchFuncOp (
383
388
(
@@ -469,6 +474,8 @@ def all_reduce__(value: Value, *, op=None, uniform=None, loc=None, ip=None):
469
474
470
475
471
476
def all_reduce_(value: Value, *, op=None, uniform=None, loc=None, ip=None):
    """Call ``all_reduce__`` and unwrap it via ``get_op_result_or_op_results``.

    Args:
        value: Operand forwarded unchanged to ``all_reduce__``.
        op: Forwarded unchanged to ``all_reduce__``.
        uniform: Forwarded unchanged to ``all_reduce__``.
        loc: Location to attach; when ``None`` it is taken from
            ``get_user_code_loc()``.
        ip: Insertion point, forwarded unchanged.

    Returns:
        Whatever ``get_op_result_or_op_results`` yields for the created op.
    """
    # Resolve the location in this frame so an explicit ``loc`` always wins.
    resolved_loc = get_user_code_loc() if loc is None else loc
    reduce_op = all_reduce__(
        value, op=op, uniform=uniform, loc=resolved_loc, ip=ip
    )
    return get_op_result_or_op_results(reduce_op)
@@ -577,15 +584,18 @@ def get_compile_object_bytes(compiled_module):
577
584
_printf = printf
578
585
579
586
580
def printf(format, *args, loc=None, ip=None):
    """Forward to ``_printf`` with the positional arguments packed as ``args``.

    ``_printf`` is the module-level alias bound to the previous ``printf``
    just before this definition.

    Args:
        format: Format argument, passed through as ``format=``.
        *args: Remaining positional values, passed as a tuple via ``args=``.
        loc: Location to attach; when ``None`` it is taken from
            ``get_user_code_loc()``.
        ip: Insertion point, forwarded unchanged.

    Returns:
        The value returned by ``_printf``.
    """
    call_loc = loc if loc is not None else get_user_code_loc()
    return _printf(format=format, args=args, loc=call_loc, ip=ip)
583
591
584
592
585
593
_dynamic_shared_memory = dynamic_shared_memory
586
594
587
595
588
596
def dynamic_shared_memory (* , int = False , loc = None , ip = None ):
597
+ if loc is None :
598
+ loc = get_user_code_loc ()
589
599
return _dynamic_shared_memory (
590
600
T .memref (
591
601
ShapedType .get_dynamic_size (),
@@ -611,3 +621,10 @@ def memset(dst, value, async_dependencies=None, *, loc=None, ip=None):
611
621
if isinstance (value , (int , float , bool )):
612
622
value = constant (value , type = dst .type .element_type )
613
623
return _memset (async_token , async_dependencies , dst , value , loc = loc , ip = ip )
624
+
625
+
626
def barrier(*, loc=None, ip=None):
    """Create and return a ``BarrierOp``.

    Args:
        loc: Location to attach; when ``None`` it is taken from
            ``get_user_code_loc()``.
        ip: Insertion point, forwarded unchanged to ``BarrierOp``.

    Returns:
        The constructed ``BarrierOp`` itself (not unwrapped into results).
    """
    op_loc = get_user_code_loc() if loc is None else loc
    return BarrierOp(loc=op_loc, ip=ip)
0 commit comments