Skip to content

Commit 282a4e1

Browse files
authored
Updates for StaticArrayInterface (#469)
1 parent 57c7ffa commit 282a4e1

File tree

11 files changed

+57
-62
lines changed

11 files changed

+57
-62
lines changed

Project.toml

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.150"
4+
version = "0.12.151"
55

66
[weakdeps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -15,8 +15,6 @@ SpecialFunctionsExt = "SpecialFunctions"
1515
[deps]
1616
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
1717
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
18-
ArrayInterfaceOffsetArrays = "015c0d05-e682-4f19-8f0a-679ce4c54826"
19-
ArrayInterfaceStaticArrays = "b0d46f97-bff5-4637-a19a-dd75974142cd"
2018
CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
2119
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2220
CloseOpenIntervals = "fb6a15b2-703c-40df-9091-08a04967cfa9"
@@ -33,15 +31,14 @@ SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
3331
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
3432
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3533
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
34+
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
3635
ThreadingUtilities = "8290d209-cae3-49c0-8002-c8c24d57dab5"
3736
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3837
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
3938

4039
[compat]
41-
ArrayInterface = "6"
40+
ArrayInterface = "7"
4241
ArrayInterfaceCore = "0.1.5"
43-
ArrayInterfaceOffsetArrays = "0.1.2"
44-
ArrayInterfaceStaticArrays = "0.1.2"
4542
CPUSummary = "0.1.3 - 0.1.8, 0.1.11, 0.2.1"
4643
ChainRulesCore = "1"
4744
CloseOpenIntervals = "0.1.10"
@@ -56,7 +53,8 @@ SIMDTypes = "0.1"
5653
SLEEFPirates = "0.6.23"
5754
SnoopPrecompile = "1"
5855
SpecialFunctions = "1, 2"
59-
Static = "0.7, 0.8"
56+
Static = "0.8.4"
57+
StaticArrayInterface = "1"
6058
ThreadingUtilities = "0.5"
6159
UnPack = "1"
6260
VectorizationBase = "0.21.53"

src/LoopVectorization.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,8 @@ end
88
using ArrayInterfaceCore: UpTri, LoTri
99
using Static: StaticInt, gt, static, Zero, One, reduce_tup
1010
using VectorizationBase,
11-
SLEEFPirates,
12-
UnPack,
13-
OffsetArrays,
14-
ArrayInterfaceOffsetArrays,
15-
ArrayInterfaceStaticArrays
11+
SLEEFPirates, UnPack, OffsetArrays, StaticArrayInterface
12+
const ArrayInterface = StaticArrayInterface
1613
using LayoutPointers:
1714
AbstractStridedPointer,
1815
StridedPointer,
@@ -155,18 +152,17 @@ using SLEEFPirates:
155152
sincos_fast,
156153
tan_fast
157154

158-
using ArrayInterface
159-
using ArrayInterface:
155+
using StaticArrayInterface:
160156
OptionallyStaticUnitRange,
161157
OptionallyStaticRange,
162158
StaticBool,
163159
True,
164160
False,
165161
indices,
166-
strides,
162+
static_strides,
167163
offsets,
168-
size,
169-
axes,
164+
static_size,
165+
static_axes,
170166
StrideIndex
171167
using CloseOpenIntervals: AbstractCloseOpen, CloseOpen#, SafeCloseOpen
172168
# @static if VERSION ≥ v"1.6.0-rc1" #TODO: delete `else` when dropping 1.5 support

src/broadcast.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,11 @@ end
4444
@inline ArrayInterface.parent_type(
4545
::Type{LowDimArray{D,T,N,A}}
4646
) where {T,D,N,A} = A
47-
@inline Base.strides(A::LowDimArray) = map(Int, strides(A))
47+
@inline Base.strides(A::LowDimArray) = map(Int, static_strides(A))
4848
@inline ArrayInterface.device(::LowDimArray) = ArrayInterface.CPUPointer()
49-
@generated function ArrayInterface.size(A::LowDimArray{D,T,N}) where {D,T,N}
49+
@generated function ArrayInterface.static_size(
50+
A::LowDimArray{D,T,N}
51+
) where {D,T,N}
5052
t = Expr(:tuple)
5153
for n ∈ 1:N
5254
if n > length(D) || D[n]
@@ -105,11 +107,13 @@ end
105107
@inline forbroadcast(A) = A
106108
# @inline forbroadcast(A::Adjoint) = forbroadcast(parent(A))
107109
# @inline forbroadcast(A::Transpose) = forbroadcast(parent(A))
108-
@inline function ArrayInterface.strides(A::Union{LowDimArray,ForBroadcast})
110+
@inline function ArrayInterface.static_strides(
111+
A::Union{LowDimArray,ForBroadcast}
112+
)
109113
B = parent(A)
110114
_strides(
111-
size(A),
112-
strides(B),
115+
static_size(A),
116+
static_strides(B),
113117
VectorizationBase.val_stride_rank(B),
114118
VectorizationBase.val_dense_dims(B)
115119
)
@@ -145,10 +149,10 @@ end
145149
) where {D,T,N,A}
146150
_lowdimfilter(Val(D), ArrayInterface.dense_dims(A))
147151
end
148-
@inline function ArrayInterface.strides(
152+
@inline function ArrayInterface.static_strides(
149153
fb::LowDimArrayForBroadcast{D}
150154
) where {D}
151-
_lowdimfilter(Val(D), strides(parent(fb)))
155+
_lowdimfilter(Val(D), static_strides(parent(fb)))
152156
end
153157
@inline function ArrayInterface.offsets(
154158
fb::LowDimArrayForBroadcast{D}
@@ -225,11 +229,9 @@ function _strides_expr(
225229
sₙ_value::Int = 0
226230
for n ∈ Nrange
227231
xₙ_type = x[n]
228-
# xₙ_type = typeof(x).parameters[n]
229232
xₙ_static = xₙ_type <: StaticInt
230233
xₙ_value::Int = xₙ_static ? (xₙ_type.parameters[1])::Int : 0
231234
s_type = s[n]
232-
# s_type = typeof(s).parameters[n]
233235
sₙ_static = s_type <: StaticInt
234236
if sₙ_static
235237
sₙ_value = s_type.parameters[1]
@@ -365,7 +367,7 @@ function add_broadcast!(
365367
pushprepreamble!(ls, Expr(:(=), mB, Expr(:(.), bcname, QuoteNode(:b))))
366368
pushprepreamble!(
367369
ls,
368-
Expr(:(=), Klen, Expr(:call, getfield, Expr(:call, :size, mB), 1))
370+
Expr(:(=), Klen, Expr(:call, getfield, Expr(:call, :static_size, mB), 1))
369371
)
370372
pushpreamble!(ls, Expr(:(=), Krange, Expr(:call, :(:), staticexpr(1), Klen)))
371373
k = gensym!(ls, "k")
@@ -587,7 +589,7 @@ function add_broadcast_loops!(
587589
destsym::Symbol
588590
)
589591
axes_tuple = Expr(:tuple)
590-
pushpreamble!(ls, Expr(:(=), axes_tuple, Expr(:call, :axes, destsym)))
592+
pushpreamble!(ls, Expr(:(=), axes_tuple, Expr(:call, :static_axes, destsym)))
591593
for itersym ∈ loopsyms
592594
Nrange = gensym!(ls, "N")
593595
Nlower = gensym!(ls, "N")

src/condense_loopset.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,11 +397,11 @@ val(x) = Expr(:call, Expr(:curly, :Val, x))
397397
p, li = VectorizationBase.tdot(
398398
x,
399399
(vsub_nsw(getfield(i, 1), one($I)),),
400-
strides(x)
400+
static_strides(x)
401401
)
402402
ptr = gep(p, li)
403403
si = ArrayInterface.StrideIndex{1,$(R[ri],),$(C === 1 ? 1 : 0)}(
404-
(getfield(strides(x), $ri),),
404+
(getfield(static_strides(x), $ri),),
405405
(Zero(),)
406406
)
407407
stridedpointer(ptr, si, StaticInt{$(B === 1 ? 1 : 0)}())
@@ -415,7 +415,7 @@ end
415415
quote
416416
$(Expr(:meta, :inline))
417417
si = ArrayInterface.StrideIndex{1,$(R[ri],),$(C === 1 ? 1 : 0)}(
418-
(getfield(strides(x), $ri),),
418+
(getfield(static_strides(x), $ri),),
419419
(getfield(offsets(x), $ri),)
420420
)
421421
stridedpointer(pointer(x), si, StaticInt{$(B == 1 ? 1 : 0)}())

src/modeling/graphs.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -995,10 +995,12 @@ function makestatic!(expr)
995995
if ex isa Int
996996
expr.args[i] = staticexpr(ex)
997997
elseif ex isa Symbol
998-
if ex === :length
999-
expr.args[i] = GlobalRef(ArrayInterface, :static_length)
1000-
elseif Base.sym_in(ex, (:axes, :size))
1001-
expr.args[i] = GlobalRef(ArrayInterface, ex)
998+
j = findfirst(==(ex), (:axes, :size, :length))
999+
if j !== nothing
1000+
expr.args[i] = GlobalRef(
1001+
ArrayInterface,
1002+
(:static_axes, :static_size, :static_length)[j]
1003+
)
10021004
end
10031005
elseif ex isa Expr
10041006
makestatic!(ex)
@@ -1215,7 +1217,7 @@ function indices_loop!(ls::LoopSet, r::Expr, itersym::Symbol)::Loop
12151217
axsym,
12161218
Expr(
12171219
:call,
1218-
GlobalRef(ArrayInterface, :axes),
1220+
GlobalRef(ArrayInterface, :static_axes),
12191221
a_s,
12201222
staticexpr(dims::Int)
12211223
)
@@ -1280,7 +1282,7 @@ function indices_loop!(ls::LoopSet, r::Expr, itersym::Symbol)::Loop
12801282
axsym,
12811283
Expr(
12821284
:call,
1283-
GlobalRef(ArrayInterface, :axes),
1285+
GlobalRef(ArrayInterface, :static_axes),
12841286
a_s,
12851287
staticexpr(mdim)
12861288
)
@@ -1351,7 +1353,7 @@ function register_single_loop!(ls::LoopSet, looprange::Expr)
13511353
)
13521354
indices_loop!(ls, r, itersym)
13531355
else
1354-
(f === :axes) && (r.args[1] = lv(:axes))
1356+
(f === :axes) && (r.args[1] = lv(:static_axes))
13551357
misc_loop!(ls, r, itersym, (f === :eachindex) | (f === :axes))
13561358
end
13571359
elseif isa(r, Symbol)

src/reconstruct_loopset.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,10 @@ function _add_mref!(
359359
offsets = gensym(:offsets)
360360
strides = gensym(:strides)
361361
pushpreamble!(ls, Expr(:(=), offsets, Expr(:call, lv(:offsets), tmpsp)))
362-
pushpreamble!(ls, Expr(:(=), strides, Expr(:call, lv(:strides), tmpsp)))
362+
pushpreamble!(
363+
ls,
364+
Expr(:(=), strides, Expr(:call, lv(:static_strides), tmpsp))
365+
)
363366
for (i, p) ∈ enumerate(sp)
364367
push!(strd_tup.args, Expr(:call, gf, strides, p, false))
365368
push!(offsets_tup.args, Expr(:call, gf, offsets, p, false))

src/simdfunctionals/mapreduce.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,11 @@ for (op, init) in zip((:+, :max, :min), (:zero, :typemin, :typemax))
134134
Base.Cartesian.@nif 5 d -> (d <= ndims(arg) && dims == d) d -> begin
135135
Rpre = CartesianIndices(ntuple(i -> axes_arg[i], d - 1))
136136
Rpost = CartesianIndices(ntuple(i -> axes_arg[i+d], ndims(arg) - d))
137-
_vreduce_dims!(out, $op, Rpre, 1:size(arg, dims), Rpost, arg)
137+
_vreduce_dims!(out, $op, Rpre, static_axes(arg, dims), Rpost, arg)
138138
end d -> begin
139139
Rpre = CartesianIndices(axes_arg[1:dims-1])
140140
Rpost = CartesianIndices(axes_arg[dims+1:end])
141-
_vreduce_dims!(out, $op, Rpre, 1:size(arg, dims), Rpost, arg)
141+
_vreduce_dims!(out, $op, Rpre, static_axes(arg, dims), Rpost, arg)
142142
end
143143
end
144144

test/broadcast.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function test_broadcast(::Type{T}) where {T}
1313
b = rand(R, 99, 99, 1)
1414
bl = LowDimArray{(true, true, false)}(b)
1515
@test size(bl) == size(b)
16-
@test LoopVectorization.ArrayInterface.size(bl) ===
16+
@test LoopVectorization.static_size(bl) ===
1717
(size(b, 1), size(b, 2), LoopVectorization.StaticInt(1))
1818

1919
br = reshape(b, (99, 99))
@@ -29,7 +29,7 @@ function test_broadcast(::Type{T}) where {T}
2929
br = reshape(b, (99, 1, 99))
3030
bl = LowDimArray{(true, false, true)}(br)
3131
@test size(bl) == size(br)
32-
@test LoopVectorization.ArrayInterface.size(bl) ===
32+
@test LoopVectorization.static_size(bl) ===
3333
(size(br, 1), LoopVectorization.StaticInt(1), size(br, 3))
3434
@. c1 = a + br
3535
fill!(c2, 99999)
@@ -41,7 +41,7 @@ function test_broadcast(::Type{T}) where {T}
4141
br = reshape(b, (1, 99, 99))
4242
bl = LowDimArray{(false,)}(br)
4343
@test size(bl) == size(br)
44-
@test LoopVectorization.ArrayInterface.size(bl) ===
44+
@test LoopVectorization.static_size(bl) ===
4545
(LoopVectorization.StaticInt(1), size(br, 2), size(br, 3))
4646
@. c1 = a + br
4747
fill!(c2, 99999)

test/gemm.jl

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@
609609
return C
610610
end
611611
function dense!(f::F, C, A, B) where {F}
612-
Kp1 = LoopVectorization.size(A, LoopVectorization.StaticInt(2))
612+
Kp1 = LoopVectorization.static_size(A, LoopVectorization.StaticInt(2))
613613
K = Kp1 - LoopVectorization.StaticInt(1)
614614
@turbo for n ∈ indices((B, C), 2), m ∈ indices((A, C), 1)
615615
Cmn = zero(eltype(C))
@@ -733,7 +733,7 @@
733733
Base.@propagate_inbounds Base.setindex!(A::TestSizedMatrix, v, i::Int, j::Int) =
734734
setindex!(parent(A), v, i + 1, j + 1)
735735
Base.size(::TestSizedMatrix{M,N}) where {M,N} = (M, N)
736-
LoopVectorization.ArrayInterface.size(::TestSizedMatrix{M,N}) where {M,N} =
736+
LoopVectorization.static_size(::TestSizedMatrix{M,N}) where {M,N} =
737737
(LoopVectorization.StaticInt{M}(), LoopVectorization.StaticInt{N}())
738738
function Base.axes(::TestSizedMatrix{M,N}) where {M,N}
739739
(
@@ -757,7 +757,7 @@
757757
end
758758
Base.unsafe_convert(::Type{Ptr{T}}, A::TestSizedMatrix{M,N,T}) where {M,N,T} =
759759
pointer(A.data)
760-
LoopVectorization.ArrayInterface.strides(::TestSizedMatrix{M}) where {M} =
760+
LoopVectorization.static_strides(::TestSizedMatrix{M}) where {M} =
761761
(LoopVectorization.StaticInt{1}(), LoopVectorization.StaticInt{M}())
762762
LoopVectorization.ArrayInterface.contiguous_axis(::Type{<:TestSizedMatrix}) =
763763
LoopVectorization.One()
@@ -771,17 +771,7 @@
771771
LoopVectorization.ArrayInterface.dense_dims(
772772
::Type{TestSizedMatrix{M,N,T}},
773773
) where {M,N,T} = LoopVectorization.ArrayInterface.dense_dims(Matrix{T})
774-
# struct ZeroInitializedArray{T,N,A<:DenseArray{T,N}} <: DenseArray{T,N}
775-
# data::A
776-
# end
777-
# Base.size(A::ZeroInitializedArray) = size(A.data)
778-
# Base.length(A::ZeroInitializedArray) = length(A.data)
779-
# Base.axes(A::ZeroInitializedArray, i) = axes(A.data, i)
780-
# @inline Base.getindex(A::ZeroInitializedArray{T}) where {T} = zero(T)
781-
# Base.@propagate_inbounds Base.setindex!(A::ZeroInitializedArray, v, i...) = setindex!(A.data, v, i...)
782-
# function LoopVectorization.VectorizationBase.stridedpointer(A::ZeroInitializedArray)
783-
# LoopVectorization.VectorizationBase.ZeroInitializedStridedPointer(LoopVectorization.VectorizationBase.stridedpointer(A.data))
784-
# end
774+
785775

786776
@testset "Matmuls" begin
787777
for T ∈ (Float32, Float64, Int32, Int64)

test/offsetarrays.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using LoopVectorization, ArrayInterface, OffsetArrays, Test
1+
using LoopVectorization, OffsetArrays, Test
2+
using LoopVectorization: ArrayInterface
23
using LoopVectorization: StaticInt
34
# T = Float64; r = -1:1;
45
# T = Float32; r = -1:1;
@@ -109,10 +110,12 @@ using LoopVectorization: StaticInt
109110
ArrayInterface.contiguous_batch_size(::Type{<:SizedOffsetMatrix}) = ArrayInterface.Zero()
110111
ArrayInterface.stride_rank(::Type{<:SizedOffsetMatrix}) =
111112
(ArrayInterface.StaticInt(1), ArrayInterface.StaticInt(2))
112-
function ArrayInterface.strides(A::SizedOffsetMatrix{T,LR,UR,LC,UC}) where {T,LR,UR,LC,UC}
113+
function LoopVectorization.static_strides(
114+
::SizedOffsetMatrix{T,LR,UR,LC,UC},
115+
) where {T,LR,UR,LC,UC}
113116
(StaticInt{1}(), (StaticInt{UR}() - StaticInt{LR}() + StaticInt{1}()))
114117
end
115-
ArrayInterface.offsets(A::SizedOffsetMatrix{T,LR,UR,LC,UC}) where {T,LR,UR,LC,UC} =
118+
ArrayInterface.offsets(::SizedOffsetMatrix{T,LR,UR,LC,UC}) where {T,LR,UR,LC,UC} =
116119
(StaticInt{LR}(), StaticInt{LC}())
117120
ArrayInterface.dense_dims(::Type{<:SizedOffsetMatrix{T}}) where {T} =
118121
ArrayInterface.dense_dims(Matrix{T})

test/parsing_inputs.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using LoopVectorization, Test, ArrayInterface
1+
using LoopVectorization, Test
2+
using LoopVectorization: ArrayInterface
23
using LoopVectorization: check_inputs!
34

45
# macros for generate loops whose body is not a block

0 commit comments

Comments
 (0)