@@ -16,13 +16,15 @@ using StaticArrays
16
16
using Adapt
17
17
18
18
"""
19
- @kernel function f(args) end
19
+ @kernel [N] function f(args) end
20
20
21
21
Takes a function definition and generates a [`Kernel`](@ref) constructor from it.
22
22
The enclosed function is allowed to contain kernel language constructs.
23
23
In order to call it, the kernel first has to be specialized on the backend
24
24
and then invoked on the arguments.
25
25
26
+ The optional `N` parameter can be used to fix the number of dimensions used for the ndrange.
27
+
26
28
# Kernel language
27
29
28
30
- [`@Const`](@ref)
@@ -54,7 +56,7 @@ macro kernel(expr)
54
56
end
55
57
56
58
"""
57
- @kernel config function f(args) end
59
+ @kernel [N] config function f(args) end
58
60
59
61
This allows for two different configurations:
60
62
@@ -584,17 +586,17 @@ in a workgroup.
584
586
```
585
587
As well as the on-device functionality.
586
588
"""
587
- struct Kernel{Backend, WorkgroupSize <: _Size , NDRange <: _Size , Fun}
589
# `N`, `WorkgroupSize` and `NDRange` appear only as type parameters; no field is
# typed by them. `WorkgroupSize` and `NDRange` must be subtypes of `_Size`.
+ struct Kernel{Backend, N, WorkgroupSize <: _Size , NDRange <: _Size , Fun}
588
590
# Backend instance the kernel was constructed with.
backend:: Backend
589
591
# The wrapped function object.
f:: Fun
590
592
end
591
593
592
- function Base. similar (kernel:: Kernel{D, WS, ND} , f:: F ) where {D, WS, ND, F}
593
- Kernel {D, WS, ND, F} (kernel. backend, f)
594
# Build a new `Kernel` that keeps the type parameters `D`, `N`, `WS`, `ND` and the
# backend of `kernel`, but wraps the function `f` instead of the original one.
+ function Base. similar (kernel:: Kernel{D, N, WS, ND} , f:: F ) where {D, N , WS, ND, F}
595
+ Kernel {D, N, WS, ND, F} (kernel. backend, f)
594
596
end
595
597
596
- workgroupsize (:: Kernel{D, WorkgroupSize} ) where {D, WorkgroupSize} = WorkgroupSize
597
- ndrange (:: Kernel{D, WorkgroupSize, NDRange} ) where {D, WorkgroupSize, NDRange} = NDRange
598
# Return the `WorkgroupSize` type parameter of the kernel. `N` must be bound in the
# `where` clause; otherwise it is looked up as a global when the method is defined.
+ workgroupsize(::Kernel{D, N, WorkgroupSize}) where {D, N, WorkgroupSize} = WorkgroupSize
599
# Return the `NDRange` type parameter of the kernel. `N` must be bound in the
# `where` clause; otherwise it is looked up as a global when the method is defined.
+ ndrange(::Kernel{D, N, WorkgroupSize, NDRange}) where {D, N, WorkgroupSize, NDRange} = NDRange
598
600
# Return the backend instance stored in `kernel`.
function backend(kernel::Kernel)
    return kernel.backend
end
599
601
600
602
"""
@@ -657,8 +659,8 @@ Partition a kernel for the given ndrange and workgroupsize.
657
659
return iterspace, dynamic
658
660
end
659
661
660
- function construct (backend:: Backend , :: S , :: NDRange , xpu_name:: XPUName ) where {Backend <: Union{CPU, GPU} , S <: _Size , NDRange <: _Size , XPUName}
661
- return Kernel {Backend, S, NDRange, XPUName} (backend, xpu_name)
662
# Create a `Kernel` for `backend` (a `CPU` or `GPU`) wrapping `xpu_name`.
# The `Val{N}`, `S` and `NDRange` arguments are unnamed and used only for their
# types, which become the `N`, workgroup-size and ndrange parameters of the result.
+ function construct (backend:: Backend , :: Val{N} , :: S , :: NDRange , xpu_name:: XPUName ) where {Backend <: Union{CPU, GPU} , N , S <: _Size , NDRange <: _Size , XPUName}
663
+ return Kernel {Backend, N, S, NDRange, XPUName} (backend, xpu_name)
662
664
end
663
665
664
666
# ##
0 commit comments