
Commit 7d9f603

fix: use a single kernel for small inputs

1 parent d452bc3

File tree

1 file changed: src/flag_gems/ops/dot.py (+30 −32)
@@ -9,7 +9,6 @@
 from ..utils import libentry
 from ..utils import triton_lang_extension as tle
 
-
 @libentry()
 @triton.jit
 def dot_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
@@ -22,9 +21,8 @@ def dot_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
     x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
     y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
 
-    partial_sum = tl.sum(x * y)
-
-    tl.atomic_add(out_ptr, partial_sum)
+    sum = tl.sum(x * y)
+    tl.store(out_ptr, sum)
 
 
 @libentry()
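Since `dot_kernel` is now launched on a (1, 1, 1) grid (see the hunk below), exactly one program writes the output, so the cross-block `tl.atomic_add` is no longer needed and a plain `tl.store` suffices. Below is a minimal standalone sketch of the kernel after this change; the `offsets`/`mask` lines are not shown in the hunk and are assumed here, on the reasoning that a single program must cover all of [0, N):

import triton
import triton.language as tl

@triton.jit
def dot_kernel_sketch(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    # One program covers the whole vector, so BLOCK_SIZE must be >= N
    # (the host side guarantees this via triton.next_power_of_2).
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    s = tl.sum(x * y)     # block-wide reduction to a scalar
    tl.store(out_ptr, s)  # no race: no other program writes out_ptr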
@@ -61,33 +59,33 @@ def dot(x, y):
     assert x.dim() == 1, "Input must be 1D tensors"
 
     N = x.shape[0]
-    block_size = triton.next_power_of_2(math.ceil(math.sqrt(N)))
-
-    # if N <= 2560000:
-    mid_size = triton.cdiv(N, block_size)
-    block_mid = triton.next_power_of_2(mid_size)
-
-    grid_1 = (mid_size, 1, 1)
-    grid_2 = (1, 1, 1)
-
-    mid = torch.empty((mid_size,), dtype=torch.float32, device=x.device)
-    out = torch.empty([], dtype=x.dtype, device=x.device)
-
-    with torch_device_fn.device(x.device):
-        dot_kernel_1[grid_1](x, y, mid, N, block_size)
-        dot_kernel_2[grid_2](mid, out, mid_size, block_mid)
-
-    # else:
-    #     grid_size = triton.cdiv(N, block_size)
-    #     grid = (grid_size,1,1)
-
-    #     with torch_device_fn.device(x.device):
-    #         if x.dtype != torch.float32:
-    #             out = torch.zeros([], dtype=torch.float32, device=x.device)
-    #             dot_kernel[grid](x, y, out, N, block_size)
-    #             out = out.to(x.dtype)
-    #         else:
-    #             out = torch.zeros([], dtype=x.dtype, device=x.device)
-    #             dot_kernel[grid](x, y, out, N, block_size)
+
+    # Only when N is less than TRITON_MAX_TENSOR_NUMEL can it be processed with a single kernel, and performance is better when N < 4096
+    if N >= 4096:
+        block_size = triton.next_power_of_2(math.ceil(math.sqrt(N)))
+
+        mid_size = triton.cdiv(N, block_size)
+        block_mid = triton.next_power_of_2(mid_size)
+
+        grid_1 = (mid_size, 1, 1)
+        grid_2 = (1, 1, 1)
+
+        mid = torch.empty((mid_size,), dtype=torch.float32, device=x.device)
+        out = torch.empty([], dtype=x.dtype, device=x.device)
+
+        with torch_device_fn.device(x.device):
+            dot_kernel_1[grid_1](x, y, mid, N, block_size)
+            dot_kernel_2[grid_2](mid, out, mid_size, block_mid)
+
+    else:
+        block_size = triton.next_power_of_2(math.ceil(N))
+
+        grid = (1, 1, 1)
+
+        out = torch.empty([], dtype=torch.float32, device=x.device)
+
+        with torch_device_fn.device(x.device):
+            dot_kernel[grid](x, y, out, N, block_size)
+        out = out.to(x.dtype)
 
     return out
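For large N the two-stage path keeps block_size near sqrt(N), which balances the width of each partial reduction in dot_kernel_1 against the number of partials dot_kernel_2 must fold; for N < 4096 a single kernel avoids the intermediate buffer and second launch entirely. A quick way to exercise both dispatch branches against torch.dot (a sketch, assuming a CUDA device and that dot is importable from the module this commit touches, flag_gems.ops.dot):

import torch
from flag_gems.ops.dot import dot

for n in (1000, 100_000):  # n < 4096 takes the single-kernel path
    x = torch.randn(n, dtype=torch.float32, device="cuda")
    y = torch.randn(n, dtype=torch.float32, device="cuda")
    torch.testing.assert_close(dot(x, y), torch.dot(x, y), rtol=1e-3, atol=1e-3)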
