Add Op dot #430
Conversation
src/flag_gems/ops/dot.py
Outdated

    with torch_device_fn.device(x.device):
        dot_kernel_1[grid_1](x, y, mid, N, block_size)
        dot_kernel_2[grid_2](mid, out, mid_size, block_mid)
I think it's better to take tensor stride into consideration, but it's a good implementation!
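To illustrate the reviewer's point, here is a minimal NumPy sketch (not the actual kernel code) of what goes wrong when a kernel reads `numel` consecutive elements from the underlying buffer and ignores the input's stride:

```python
import numpy as np

# A strided (non-contiguous) view over a buffer: a stride-oblivious kernel
# that reads consecutive elements from the base storage gets the wrong data.
base = np.arange(8, dtype=np.float32)   # buffer: [0, 1, 2, 3, 4, 5, 6, 7]
x = base[::2]                           # view [0, 2, 4, 6] with stride 2
y = np.ones(4, dtype=np.float32)

correct = np.dot(x, y)        # stride-aware read: 0 + 2 + 4 + 6 = 12
naive = np.dot(base[:4], y)   # stride-oblivious read: 0 + 1 + 2 + 3 = 6
```

Passing strides to the kernel (or calling `.contiguous()` on the host side first) avoids this mismatch.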
src/flag_gems/ops/dot.py
Outdated

    dot_kernel_1[grid_1](x, y, mid, N, block_size)
    dot_kernel_2[grid_2](mid, out, mid_size, block_mid)
Can we resort to a single persistent kernel when the input numel is small enough?
I probably didn't make myself clear. What I suggested is adding a one-pass branch to handle small inputs. We don't have to use atomic_add in either branch; the two-pass branch still exists.
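The suggested dispatch can be sketched on the host side as follows. This is a NumPy simulation of the two branches, not the actual Triton kernels; the cutoff and block size are hypothetical values that would need tuning:

```python
import numpy as np

ONE_PASS_MAX_NUMEL = 4096   # hypothetical cutoff; the real value would be tuned
BLOCK_SIZE = 1024           # hypothetical block size

def dot_with_dispatch(x: np.ndarray, y: np.ndarray) -> np.float32:
    """Sketch of the suggested dispatch: a single one-pass reduction for
    small inputs, the existing two-pass scheme otherwise. Neither branch
    needs atomic_add."""
    n = x.size
    prod = x.astype(np.float32) * y.astype(np.float32)  # fp32 intermediates
    if n <= ONE_PASS_MAX_NUMEL:
        # one-pass branch: a single persistent kernel reduces everything
        return prod.sum(dtype=np.float32)
    # two-pass branch: kernel 1 writes one partial sum per block into `mid`,
    # kernel 2 reduces `mid` into the final scalar
    mid_size = (n + BLOCK_SIZE - 1) // BLOCK_SIZE
    mid = np.empty(mid_size, dtype=np.float32)
    for blk in range(mid_size):
        mid[blk] = prod[blk * BLOCK_SIZE:(blk + 1) * BLOCK_SIZE].sum(dtype=np.float32)
    return mid.sum(dtype=np.float32)

x = np.ones(10000, dtype=np.float16)
y = np.ones(10000, dtype=np.float16)
large_result = dot_with_dispatch(x, y)          # numel > cutoff: two-pass branch
small_result = dot_with_dispatch(x[:8], y[:8])  # numel <= cutoff: one-pass branch
```

The one-pass branch saves a kernel launch and the `mid` round trip for small inputs, while the existing two-pass path handles large ones.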
@wlxjhyf, thanks for contributing to FlagGems. Please resolve the conversations and complete this PR at your earliest convenience.
I'm sorry, I just saw it. I'll do it right now.
Don't be sorry. We are very grateful for your volunteering!
StrongSpoon
left a comment
LGTM! Please resolve the conflicts and this pull request will be merged soon.
src/flag_gems/ops/dot.py
Outdated

    dot_kernel_2[grid_2](mid, out, mid_size, block_mid)
    ...
    else:
        block_size = triton.next_power_of_2(math.ceil(N))
math.ceil is useless here: N is already an integer, so passing it to triton.next_power_of_2 directly is enough.
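A quick demonstration of why the `math.ceil` wrapper is a no-op. The `next_power_of_2` below is a pure-Python stand-in used for illustration, not the Triton function itself:

```python
import math

def next_power_of_2(n: int) -> int:
    """Pure-Python stand-in for triton.next_power_of_2, for illustration."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

# N comes from an element count, so it is already an int:
# math.ceil(N) == N, and the wrapped call equals the direct call.
N = 1000
assert math.ceil(N) == N
assert next_power_of_2(math.ceil(N)) == next_power_of_2(N) == 1024
```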
tests/test_reduction_ops.py
Outdated

    inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
    inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
    ref_inp1 = to_reference(inp1, False)
    ref_inp2 = to_reference(inp2, False)
It's recommended to set the parameter upcast to True, which makes the reference computation run at higher precision.
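A sketch of what an upcasting reference helper typically does, simulated with NumPy. The `to_reference` below is a hypothetical stand-in with assumed behavior, not the actual FlagGems test helper:

```python
import numpy as np

def to_reference(inp: np.ndarray, upcast: bool = False) -> np.ndarray:
    """Hypothetical stand-in for the test helper (assumed behavior):
    with upcast=True the reference input is promoted to float64, so the
    reference dot runs at high precision and rounding in the reference
    cannot mask errors in the kernel under test."""
    return inp.astype(np.float64) if upcast else inp.copy()

rng = np.random.default_rng(0)
inp1 = rng.standard_normal(1024).astype(np.float16)
inp2 = rng.standard_normal(1024).astype(np.float16)

ref_inp1 = to_reference(inp1, True)   # float64 reference input
ref_inp2 = to_reference(inp2, True)
ref_out = np.dot(ref_inp1, ref_inp2)  # high-precision reference result
```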
PR Category
Operator
Type of Change
Add new operator
Description
Implement the dot operator, supporting Float32, Float16, and BFloat16.
The implementation splits the dot operator into two steps: the first performs elementwise multiplication, and the second performs the summation.
At present, to meet accuracy requirements, the intermediate results of the first step are stored as float32.
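The two-step scheme described above can be sketched as a NumPy simulation of the kernels (not the actual Triton code; the block size of 1024 is an assumption):

```python
import numpy as np

def dot_two_pass(x: np.ndarray, y: np.ndarray, block_size: int = 1024) -> np.float32:
    """NumPy sketch of the described two-step scheme: step 1 multiplies
    elementwise and reduces each block into a float32 partial sum stored
    in `mid`; step 2 sums `mid` into the final scalar."""
    n = x.size
    mid_size = (n + block_size - 1) // block_size
    mid = np.zeros(mid_size, dtype=np.float32)   # fp32 intermediates
    for blk in range(mid_size):
        xb = x[blk * block_size:(blk + 1) * block_size].astype(np.float32)
        yb = y[blk * block_size:(blk + 1) * block_size].astype(np.float32)
        mid[blk] = (xb * yb).sum(dtype=np.float32)
    return mid.sum(dtype=np.float32)

# fp16 inputs with fp32 intermediates: 20000 * 0.5 * 0.5 is recovered exactly,
# whereas accumulating in fp16 would lose precision long before 5000.
a = np.full(20000, 0.5, dtype=np.float16)
b = np.full(20000, 0.5, dtype=np.float16)
out = dot_two_pass(a, b)
```

Keeping `mid` in float32 is what lets the fp16/bf16 paths meet the accuracy requirements mentioned above.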
Issue
#394
Progress
Performance
Correctness

Performance