Skip to content

Commit 807675d

Browse files
authored
Use extensions for weakdeps (#464)
* Use extensions for weakdeps
* LSP added "using Base: get_extension", which is probably a bad idea
* Fix extension syntax
* Fixes
* module
* Fix zygote tests
1 parent 21b86bb commit 807675d

File tree

5 files changed

+88
-54
lines changed

5 files changed

+88
-54
lines changed

Project.toml

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.148"
4+
version = "0.12.149"
5+
6+
[weakdeps]
7+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
9+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
10+
11+
[extensions]
12+
ForwardDiffExt = ["ChainRulesCore", "ForwardDiff"]
13+
SpecialFunctionsExt = "SpecialFunctions"
514

615
[deps]
716
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/simdfunctionals/vmap_grad_rrule.jl ext/ForwardDiffExt.jl

+68-11
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,60 @@
1+
module ForwardDiffExt
2+
import ForwardDiff, ChainRulesCore
3+
using SIMDDualNumbers, LoopVectorization
4+
using LoopVectorization:
5+
AbstractSIMD,
6+
AbstractStridedPointer,
7+
relu,
8+
vmap,
9+
VectorizationBase,
10+
vmapt,
11+
vmapnt,
12+
vmapntt,
13+
MM,
14+
StaticInt,
15+
vadd_nw,
16+
vsub_nsw,
17+
vload,
18+
mask,
19+
vfnmadd_fast,
20+
mul_fast
21+
using VectorizationBase: zero_offsets
122

2-
import .ChainRulesCore
23+
@generated function init_dual(v::Tuple{Vararg{AbstractSIMD,A}}) where {A}
24+
res = Expr(:tuple)
25+
q = Expr(:block, Expr(:meta, :inline))
26+
for a ∈ 1:A
27+
v_a = Symbol(:v_, a)
28+
push!(q.args, Expr(:(=), v_a, Expr(:ref, :v, a)))
29+
partials = Expr(:tuple)
30+
for i ∈ 1:A
31+
push!(partials.args, Expr(:call, i == a ? :one : :zero, v_a))
32+
end
33+
push!(res.args, :(ForwardDiff.Dual($v_a, ForwardDiff.Partials($partials))))
34+
end
35+
push!(q.args, res)
36+
q
37+
end
38+
@generated function dual_store!(
39+
∂p::Tuple{Vararg{AbstractStridedPointer,A}},
40+
p::AbstractStridedPointer,
41+
∂v,
42+
im::Vararg{Any,N}
43+
) where {A,N}
44+
quote
45+
$(Expr(:meta, :inline))
46+
v = ∂v.value
47+
∂ = ∂v.partials
48+
Base.Cartesian.@nextract $N im im
49+
Base.Cartesian.@ncall $N VectorizationBase.vnoaliasstore! p v im # store
50+
Base.Cartesian.@nexprs $A a -> begin # for each of `A` partials
51+
∂p_a = ∂p[a]
52+
∂_a = ∂[a]
53+
Base.Cartesian.@ncall $N VectorizationBase.vnoaliasstore! ∂p_a ∂_a im # store
54+
end
55+
nothing
56+
end
57+
end
358

459
if isdefined(ChainRulesCore, :ZeroTangent)
560
const ChainRulesZero = ChainRulesCore.ZeroTangent
@@ -38,32 +93,33 @@ function ∂vmap_singlethread!(
3893
args::Vararg{DenseArray{<:Base.HWReal},A}
3994
) where {F,T<:Base.HWReal,A}
4095
N = length(y)
41-
ptry = VectorizationBase.zero_offsets(stridedpointer(y))
42-
ptrargs = VectorizationBase.zero_offsets.(stridedpointer.(args))
43-
ptr∂y = VectorizationBase.zero_offsets.(stridedpointer.(∂y))
44-
96+
ptry = zero_offsets(stridedpointer(y))
97+
ptrargs = map(zero_offsets, map(stridedpointer, args))
98+
ptr∂y = map(zero_offsets, map(stridedpointer, ∂y))
4599
i = 0
46100
V = VectorizationBase.pick_vector_width(T)
47101
W = Int(V)
48-
st = VectorizationBase.static_sizeof(T)
49-
zero_index = MM{W}(StaticInt(0), st)
50102
while i < vsub_nsw(N, ((W << 2) - 1))
51103
index = VectorizationBase.Unroll{1,W,4,1,W,zero(UInt)}((i,))
52-
v = f(init_dual(vload.(ptrargs, index))...)
104+
v = f(init_dual(map(Base.Fix2(vload, index), ptrargs))...)
53105
dual_store!(ptr∂y, ptry, v, index)
54106
i = vadd_nw(i, 4W)
55107
end
56108
while i < vsub_nsw(N, (W - 1))
57-
vᵣ = f(init_dual(vload.(ptrargs, ((MM{W}(i),),)))...)
109+
loader = Base.Fix2(vload, (MM{W}(i),))
110+
vᵣ = f(init_dual(map(loader, ptrargs))...)
58111
dual_store!(ptr∂y, ptry, vᵣ, (MM{W}(i),))
59112
i = vadd_nw(i, W)
60113
end
61114
if i < N
62115
m = mask(T, N & (W - 1))
116+
mloader = let i = i, m = m
117+
p -> vload(p, (MM{W}(i),), m)
118+
end
63119
dual_store!(
64120
ptr∂y,
65121
ptry,
66-
f(init_dual(vload.(ptrargs, ((MM{W}(i),),), m))...),
122+
f(init_dual(map(mloader, ptrargs))...),
67123
(MM{W}(i),),
68124
m
69125
)
@@ -109,6 +165,7 @@ for f in (:vmapt, :vmapnt, :vmapntt)
109165
f::F,
110166
args::Vararg{Any,K}
111167
) where {F,K}
112-
ChainRulesCore.rrule(typeof(vmap), f, args...)
168+
ChainRulesCore.rrule(typeof($vmap), f, args...)
113169
end
114170
end
171+
end

ext/SpecialFunctionsExt.jl

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
module SpecialFunctionsExt
2+
using SpecialFunctions
3+
using LoopVectorization: VectorizationBase
4+
using LoopVectorization: AbstractSIMD
5+
@inline SpecialFunctions.erf(x::AbstractSIMD) = VectorizationBase.verf(float(x))
6+
end

src/LoopVectorization.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,9 @@ include("precompile.jl")
262262

263263
# import ChainRulesCore, ForwardDiff
264264
# include("vmap_grad.jl")
265-
using ChainRulesCore, ForwardDiff, SpecialFunctions
266-
include("simdfunctionals/vmap_grad_rrule.jl")
267-
include("simdfunctionals/vmap_grad_forwarddiff.jl")
268-
@inline SpecialFunctions.erf(x::AbstractSIMD) = VectorizationBase.verf(float(x))
265+
if !isdefined(Base, :get_extension)
266+
include("../ext/ForwardDiffExt.jl")
267+
include("../ext/SpecialFunctionsExt.jl")
268+
end
269269

270270
end # module

src/simdfunctionals/vmap_grad_forwarddiff.jl

-38
This file was deleted.

0 commit comments

Comments
 (0)