Skip to content

Commit 1ac546f

Browse files
authored
avoid boxing when @kernel is used as a closure (#627)
This allows `@kernel` to be used inside functions without running into the boxing issue described in JuliaLang/julia#53295.
1 parent 31dc584 commit 1ac546f

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

src/macros.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,16 @@ function __kernel(expr, force_inbounds = false, unsafe_indices = false)
3434
gpu_function = combinedef(def_gpu)
3535

3636
# create constructor functions
37+
_name = Symbol(:_, name)
3738
constructors = quote
3839
if $(name isa Symbol ? :(!@isdefined($name)) : true)
39-
Core.@__doc__ $name(dev) = $name(dev, $DynamicSize(), $DynamicSize())
40-
$name(dev, size) = $name(dev, $StaticSize(size), $DynamicSize())
41-
$name(dev, size, range) = $name(dev, $StaticSize(size), $StaticSize(range))
42-
function $name(dev::Dev, sz::S, range::NDRange) where {Dev, S <: $_Size, NDRange <: $_Size}
40+
function $_name(dev::Dev, sz::S, range::NDRange) where {Dev, S <: $_Size, NDRange <: $_Size}
4341
return $construct(dev, sz, range, $gpu_name)
4442
end
43+
Core.@__doc__ $name(dev) = $_name(dev, $DynamicSize(), $DynamicSize())
44+
$name(dev, size) = $_name(dev, $StaticSize(size), $DynamicSize())
45+
$name(dev, size, range) = $_name(dev, $StaticSize(size), $StaticSize(range))
46+
$name(dev, size::$_Size, range::$_Size) = $_name(dev, size, range)
4547
end
4648
end
4749

test/test.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,5 +307,20 @@ function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; sk
307307
@test all(a -> a == 0, @view(Ah[(length(A) ÷ 2 + 1):end]))
308308
end
309309

310+
@testset "`@kernel` as a closure" begin
311+
function foo()
312+
@kernel function kernel(A)
313+
i = @index(Global)
314+
A[i] = 1
315+
end
316+
return kernel
317+
end
318+
ftypes = fieldtypes(typeof(foo()))
319+
@test !any(T -> T <: Core.Box, ftypes)
320+
@test all(ftypes) do T
321+
!any(T -> T <: Core.Box, fieldtypes(T))
322+
end
323+
end
324+
310325
return
311326
end

0 commit comments

Comments
 (0)