
Commit 00bcec3

Optimize expand by using SignedMultiplicativeInverse

1 parent 419481c · commit 00bcec3
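The optimization targets the linear-to-Cartesian index conversion inside `expand`: every lookup divides by the block or workgroup extents, and integer division is far slower than multiplication on most hardware. `Base.MultiplicativeInverses.SignedMultiplicativeInverse` precomputes a magic-number inverse so that `div` by a known divisor lowers to a multiply plus shifts. A minimal sketch of the standard-library primitive this commit builds on (values chosen for illustration):

    using Base.MultiplicativeInverses: SignedMultiplicativeInverse

    inv8 = SignedMultiplicativeInverse{Int32}(Int32(8))
    div(Int32(27), inv8)  # 3 -- multiply + shift, no hardware divide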

File tree (5 files changed, +104 −14 lines):

src/nditeration.jl
test/compiler.jl
test/localmem.jl
test/private.jl
test/test.jl

src/nditeration.jl (+99 −9)
@@ -1,5 +1,65 @@
 module NDIteration

+import Base.MultiplicativeInverses: SignedMultiplicativeInverse
+
+# CartesianIndex uses Int instead of Int32
+
+@eval EmptySMI() = $(Expr(:new, SignedMultiplicativeInverse{Int32}, Int32(0), typemax(Int32), 0 % Int8, 0 % UInt8))
+SMI(i) = i == 0 ? EmptySMI() : SignedMultiplicativeInverse{Int32}(i)
+
+struct FastCartesianIndices{N} <: AbstractArray{CartesianIndex{N}, N}
+    inverses::NTuple{N, SignedMultiplicativeInverse{Int32}}
+end
+
+function FastCartesianIndices(indices::NTuple{N}) where {N}
+    inverses = map(i -> SMI(Int32(i)), indices)
+    FastCartesianIndices(inverses)
+end
+
+function Base.size(FCI::FastCartesianIndices{N}) where {N}
+    ntuple(Val(N)) do I
+        FCI.inverses[I].divisor
+    end
+end
+
+@inline function Base.getindex(::FastCartesianIndices{0})
+    return CartesianIndex()
+end
+
+@inline function Base.getindex(iter::FastCartesianIndices{N}, I::Vararg{Int, N}) where {N}
+    @boundscheck checkbounds(iter, I...)
+    index = map(iter.inverses, I) do inv, i
+        @inbounds getindex(Base.OneTo(inv.divisor), i % Int32)
+    end
+    CartesianIndex(index)
+end
+
+_ind2sub_recurse(::Tuple{}, ind) = (ind + 1,)
+function _ind2sub_recurse(indslast::NTuple{1}, ind)
+    Base.@_inline_meta
+    (_lookup(ind, indslast[1]),)
+end
+
+function _ind2sub_recurse(inds, ind)
+    Base.@_inline_meta
+    assume(ind >= 0)
+    inv = inds[1]
+    indnext, f, l = _div(ind, inv)
+    (ind - l * indnext + f, _ind2sub_recurse(Base.tail(inds), indnext)...)
+end
+
+_lookup(ind, inv::SignedMultiplicativeInverse) = ind + 1
+function _div(ind, inv::SignedMultiplicativeInverse)
+    # inv.divisor == 0 && throw(DivideError())
+    assume(ind >= 0)
+    div(ind % Int32, inv), 1, inv.divisor
+end
+
+function Base._ind2sub(inv::FastCartesianIndices, ind)
+    Base.@_inline_meta
+    _ind2sub_recurse(inv.inverses, ind - 1)
+end
+
 export _Size, StaticSize, DynamicSize, get
 export NDRange, blocks, workitems, expand
 export DynamicCheck, NoDynamicCheck
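The new `FastCartesianIndices` stores one `SignedMultiplicativeInverse{Int32}` per dimension and otherwise behaves like `CartesianIndices` over one-based axes. A sanity-check sketch, assuming the definitions above are loaded:

    fci = FastCartesianIndices((3, 4))
    size(fci)              # (3, 4)
    fci[2, 3]              # CartesianIndex(2, 3)
    Base._ind2sub(fci, 5)  # (2, 2), same as Base._ind2sub((3, 4), 5)

The recursion peels off one dimension per step using only the cheap division: for linear index 5 the zero-based index is 4, `_div` returns `indnext = div(4, 3) = 1` together with `f = 1` and `l = 3`, so the first coordinate is `4 - 3*1 + 1 = 2`, and the tail resolves to 2 as well.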
@@ -50,18 +110,32 @@ struct NDRange{N, StaticBlocks, StaticWorkitems, DynamicBlock, DynamicWorkitems}
     blocks::DynamicBlock
     workitems::DynamicWorkitems

-    function NDRange{N, B, W}() where {N, B, W}
-        new{N, B, W, Nothing, Nothing}(nothing, nothing)
-    end
-
-    function NDRange{N, B, W}(blocks, workitems) where {N, B, W}
+    function NDRange{N, B, W}(blocks::Union{Nothing, FastCartesianIndices{N}}, workitems::Union{Nothing, FastCartesianIndices{N}}) where {N, B, W}
+        @assert B <: _Size
+        @assert W <: _Size
         new{N, B, W, typeof(blocks), typeof(workitems)}(blocks, workitems)
     end
 end

-@inline workitems(range::NDRange{N, B, W}) where {N, B, W <: DynamicSize} = range.workitems::CartesianIndices{N}
+function NDRange{N, B, W}() where {N, B, W}
+    NDRange{N, B, W}(nothing, nothing)
+end
+
+function NDRange{N, B, W}(blocks::CartesianIndices, workitems::CartesianIndices) where {N, B, W}
+    return NDRange{N, B, W}(FastCartesianIndices(size(blocks)), FastCartesianIndices(size(workitems)))
+end
+
+function NDRange{N, B, W}(blocks::Nothing, workitems::CartesianIndices) where {N, B, W}
+    return NDRange{N, B, W}(blocks, FastCartesianIndices(size(workitems)))
+end
+
+function NDRange{N, B, W}(blocks::CartesianIndices, workitems::Nothing) where {N, B, W}
+    return NDRange{N, B, W}(FastCartesianIndices(size(blocks)), workitems)
+end
+
+@inline workitems(range::NDRange{N, B, W}) where {N, B, W <: DynamicSize} = range.workitems::FastCartesianIndices{N}
 @inline workitems(range::NDRange{N, B, W}) where {N, B, W <: StaticSize} = CartesianIndices(get(W))::CartesianIndices{N}
-@inline blocks(range::NDRange{N, B}) where {N, B <: DynamicSize} = range.blocks::CartesianIndices{N}
+@inline blocks(range::NDRange{N, B}) where {N, B <: DynamicSize} = range.blocks::FastCartesianIndices{N}
 @inline blocks(range::NDRange{N, B}) where {N, B <: StaticSize} = CartesianIndices(get(B))::CartesianIndices{N}

 import Base.iterate
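Callers continue to construct `NDRange` from `CartesianIndices`; the new outer constructors convert at the boundary, and the inner constructor now accepts only `Nothing` or `FastCartesianIndices{N}`. A hypothetical use, assuming the submodule path below:

    using KernelAbstractions.NDIteration

    ndrange = NDRange{2, DynamicSize, DynamicSize}(
        CartesianIndices((4, 4)),  # blocks, converted to FastCartesianIndices{2}
        CartesianIndices((8, 8)))  # workitems, converted likewise
    blocks(ndrange) isa NDIteration.FastCartesianIndices  # true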
@@ -80,8 +154,8 @@ Base.length(range::NDRange) = length(blocks(range))
     CartesianIndex(nI)
 end

-Base.@propagate_inbounds function expand(ndrange::NDRange, groupidx::Integer, idx::Integer)
-    expand(ndrange, blocks(ndrange)[groupidx], workitems(ndrange)[idx])
+Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::Integer, idx::Integer) where {N}
+    return expand(ndrange, blocks(ndrange)[groupidx], workitems(ndrange)[idx])
 end

 Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::CartesianIndex{N}, idx::Integer) where {N}
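`expand` still composes the group index and the in-group index into a global `CartesianIndex`; only the index types feeding it changed. With 1D workgroups of 8 items, item 3 of group 2 is global item (2 - 1) * 8 + 3 = 11. Continuing the sketch above:

    ndr = NDRange{1, DynamicSize, DynamicSize}(CartesianIndices((4,)), CartesianIndices((8,)))
    expand(ndr, 2, 3)  # CartesianIndex(11,)

Both lookups, `blocks(ndrange)[groupidx]` and `workitems(ndrange)[idx]`, now resolve their linear index through `Base._ind2sub` on `FastCartesianIndices`, i.e. through the precomputed inverses rather than a native `div`.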
@@ -126,4 +200,20 @@ needs to perform dynamic bounds-checking.
     end
 end

+
+
+"""
+    assume(cond::Bool)
+Assume that the condition `cond` is true. This is a hint to the compiler, possibly enabling
+it to optimize more aggressively.
+"""
+@inline assume(cond::Bool) = Base.llvmcall(("""
+    declare void @llvm.assume(i1)
+    define void @entry(i8) #0 {
+        %cond = icmp eq i8 %0, 1
+        call void @llvm.assume(i1 %cond)
+        ret void
+    }
+    attributes #0 = { alwaysinline }""", "entry"),
+    Nothing, Tuple{Bool}, cond)
 end #module
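`assume` injects `llvm.assume`, an invariant the optimizer may rely on at zero runtime cost; here it promises `_div` and `_ind2sub_recurse` a non-negative index, sparing the signed-division lowering its negative-operand fixup. The same idiom in isolation (illustrative only, not part of the commit):

    function rem8(x::Int32)
        assume(x >= 0)           # hint only; nothing is checked at runtime
        return rem(x, Int32(8))  # LLVM may now lower this to a bitmask, x & 7
    end

The flip side is that violating the assumption is undefined behavior, which is why the commit asserts it only on zero-based indices that are non-negative by construction.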

test/compiler.jl (+1 −1)

@@ -39,7 +39,7 @@ end

 function compiler_testsuite(backend, ArrayT)
     kernel = index(CPU(), DynamicSize(), DynamicSize())
-    iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}();
+    iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}()
     ctx = KernelAbstractions.mkcontext(kernel, 1, nothing, iterspace, Val(KernelAbstractions.NoDynamicCheck()))
     @test KernelAbstractions.__index_Global_NTuple(ctx, CartesianIndex(1)) == (1,)

test/localmem.jl (+2 −2)

@@ -8,7 +8,7 @@ using Test
     end
     I = @index(Global, Linear)
     i = @index(Local, Linear)
-    lmem = @localmem Int (N,) # Ok iff groupsize is static
+    lmem = @localmem Int (N,)  # Ok iff groupsize is static
     @inbounds begin
         lmem[i] = i
         @synchronize
@@ -23,7 +23,7 @@ end
     end
     I = @index(Global, Linear)
     i = @index(Local, Linear)
-    lmem = @localmem Int (N,) # Ok iff groupsize is static
+    lmem = @localmem Int (N,)  # Ok iff groupsize is static
     @inbounds begin
         lmem[i] = i + 3
         for j in 1:2

test/private.jl (+1 −1)

@@ -73,7 +73,7 @@ function private_testsuite(backend, ArrayT)

     A = ArrayT{Int}(undef, 64, 64)
     A .= 1
-    forloop(backend())(A, Val(size(A, 2)), ndrange = size(A, 1), workgroupsize = size(A, 1))
+    forloop(backend(), size(A, 1))(A, Val(size(A, 2)), ndrange = size(A, 1), workgroupsize = size(A, 1))
     synchronize(backend())
     @test all(Array(A)[:, 1] .== 64)
     @test all(Array(A)[:, 2:end] .== 1)

test/test.jl (+1 −1)

@@ -154,7 +154,7 @@ function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; sk
     @conditional_testset "Const" skip_tests begin
         let kernel = constarg(Backend(), 8, (1024,))
             # this is poking at internals
-            iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}();
+            iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}()
             ctx = if Backend == CPU
                 KernelAbstractions.mkcontext(kernel, 1, nothing, iterspace, Val(NoDynamicCheck()))
             else
