 from ..utils import libentry
 from ..utils import triton_lang_extension as tle

-
 @libentry()
 @triton.jit
 def dot_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
@@ -22,9 +21,8 @@ def dot_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
     x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
     y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

-    partial_sum = tl.sum(x * y)
-
-    tl.atomic_add(out_ptr, partial_sum)
+    sum = tl.sum(x * y)
+    tl.store(out_ptr, sum)


 @libentry()
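Note on the hunk above: the prologue of dot_kernel is not part of this diff, but from the offsets/mask usage it presumably follows the standard tiled pattern sketched below (a hypothetical reconstruction, not the committed code). Under the new single-program launch in the dispatch change below, grid = (1, 1, 1), the program id is always zero, so one BLOCK_SIZE-wide tile covers all N elements and a plain tl.store is race-free where the old multi-program version needed tl.atomic_add.

```python
# Hypothetical sketch of the elided dot_kernel prologue.
pid = tl.program_id(0)                                 # always 0 under grid (1, 1, 1)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)  # one tile covers all N elements
mask = offsets < N                                     # guard the padded tail beyond N
```

A side effect of dropping the atomic add is that the output no longer needs zero-initialization, which is why the dispatch code below allocates out with torch.empty rather than torch.zeros.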
@@ -61,33 +59,33 @@ def dot(x, y):
     assert x.dim() == 1, "Input must be 1D tensors"

     N = x.shape[0]
-    block_size = triton.next_power_of_2(math.ceil(math.sqrt(N)))
-
-    # if N <= 2560000:
-    mid_size = triton.cdiv(N, block_size)
-    block_mid = triton.next_power_of_2(mid_size)
-
-    grid_1 = (mid_size, 1, 1)
-    grid_2 = (1, 1, 1)
-
-    mid = torch.empty((mid_size,), dtype=torch.float32, device=x.device)
-    out = torch.empty([], dtype=x.dtype, device=x.device)
-
-    with torch_device_fn.device(x.device):
-        dot_kernel_1[grid_1](x, y, mid, N, block_size)
-        dot_kernel_2[grid_2](mid, out, mid_size, block_mid)
-
-    # else:
-    #     grid_size = triton.cdiv(N, block_size)
-    #     grid = (grid_size, 1, 1)
-
-    #     with torch_device_fn.device(x.device):
-    #         if x.dtype != torch.float32:
-    #             out = torch.zeros([], dtype=torch.float32, device=x.device)
-    #             dot_kernel[grid](x, y, out, N, block_size)
-    #             out = out.to(x.dtype)
-    #         else:
-    #             out = torch.zeros([], dtype=x.dtype, device=x.device)
-    #             dot_kernel[grid](x, y, out, N, block_size)
+
+    # A single kernel can only process the input when N is below TRITON_MAX_TENSOR_NUMEL, and it is the faster path when N < 4096.
+    if N >= 4096:
+        block_size = triton.next_power_of_2(math.ceil(math.sqrt(N)))
+
+        mid_size = triton.cdiv(N, block_size)
+        block_mid = triton.next_power_of_2(mid_size)
+
+        grid_1 = (mid_size, 1, 1)
+        grid_2 = (1, 1, 1)
+
+        mid = torch.empty((mid_size,), dtype=torch.float32, device=x.device)
+        out = torch.empty([], dtype=x.dtype, device=x.device)
+
+        with torch_device_fn.device(x.device):
+            dot_kernel_1[grid_1](x, y, mid, N, block_size)
+            dot_kernel_2[grid_2](mid, out, mid_size, block_mid)
+
+    else:
+        block_size = triton.next_power_of_2(math.ceil(N))
+
+        grid = (1, 1, 1)
+
+        out = torch.empty([], dtype=torch.float32, device=x.device)
+
+        with torch_device_fn.device(x.device):
+            dot_kernel[grid](x, y, out, N, block_size)
+        out = out.to(x.dtype)

     return out
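dot_kernel_1 and dot_kernel_2 are not shown in this diff; the sketch below is an assumption inferred from their call sites, dot_kernel_1[grid_1](x, y, mid, N, block_size) and dot_kernel_2[grid_2](mid, out, mid_size, block_mid), illustrating the two-stage reduction the N >= 4096 branch relies on.

```python
# Hypothetical two-stage reduction, inferred from the call sites above.
@triton.jit
def dot_kernel_1(x_ptr, y_ptr, mid_ptr, N, BLOCK_SIZE: tl.constexpr):
    # Stage 1: each of the mid_size programs reduces one tile to a partial sum.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    tl.store(mid_ptr + pid, tl.sum(x * y))


@triton.jit
def dot_kernel_2(mid_ptr, out_ptr, mid_size, BLOCK_MID: tl.constexpr):
    # Stage 2: a single program folds all partial sums into the scalar output.
    offsets = tl.arange(0, BLOCK_MID)
    mask = offsets < mid_size
    partial = tl.load(mid_ptr + offsets, mask=mask, other=0.0)
    tl.store(out_ptr, tl.sum(partial))
```

Choosing block_size = next_power_of_2(ceil(sqrt(N))) balances the two stages: both the per-program tile and the mid buffer hold roughly sqrt(N) elements, so each stage fits in a single block.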