|
| 1 | +module ForwardDiffExt |
| 2 | +import ForwardDiff, ChainRulesCore |
| 3 | +using SIMDDualNumbers, LoopVectorization |
| 4 | +using LoopVectorization: |
| 5 | + AbstractSIMD, |
| 6 | + AbstractStridedPointer, |
| 7 | + relu, |
| 8 | + vmap, |
| 9 | + VectorizationBase, |
| 10 | + vmapt, |
| 11 | + vmapnt, |
| 12 | + vmapntt, |
| 13 | + MM, |
| 14 | + StaticInt, |
| 15 | + vadd_nw, |
| 16 | + vsub_nsw, |
| 17 | + vload, |
| 18 | + mask, |
| 19 | + vfnmadd_fast, |
| 20 | + mul_fast |
| 21 | +using VectorizationBase: zero_offsets |
1 | 22 |
|
2 |
| -import .ChainRulesCore |
"""
    init_dual(v::Tuple{Vararg{AbstractSIMD,A}})

Wrap every element of `v` in a `ForwardDiff.Dual`.

The `a`-th returned dual has `v[a]` as its value and an `A`-long partials tuple
holding `one(v[a])` in slot `a` and `zero(v[a])` in every other slot. The
result is a tuple with `A` entries, in the same order as `v`.
"""
@generated function init_dual(v::Tuple{Vararg{AbstractSIMD,A}}) where {A}
    # NOTE: plain loops only; a generated-function body may not contain
    # closures, comprehensions or generators.
    body = Expr(:block, Expr(:meta, :inline))
    duals = Expr(:tuple)
    for k = 1:A
        lane = Symbol(:v_, k)
        # v_k = v[k]
        push!(body.args, :($lane = v[$k]))
        # Seed tuple: one(v_k) on the diagonal, zero(v_k) elsewhere.
        seed = Expr(:tuple)
        for j = 1:A
            if j == k
                push!(seed.args, :(one($lane)))
            else
                push!(seed.args, :(zero($lane)))
            end
        end
        wrapped = Expr(:call, :(ForwardDiff.Partials), seed)
        push!(duals.args, Expr(:call, :(ForwardDiff.Dual), lane, wrapped))
    end
    push!(body.args, duals)
    return body
end
"""
    dual_store!(∂p, p, ∂v, im...)

Store the contents of the dual `∂v` through strided pointers.

`∂v.value` is written through `p`, and for each `a` in `1:A` the partial
`∂v.partials[a]` is written through `∂p[a]`. Every write is a call to
`VectorizationBase.vnoaliasstore!` that receives the trailing arguments `im...`
unchanged (for example an index, optionally followed by a mask). Returns
`nothing`.
"""
@generated function dual_store!(
    ∂p::Tuple{Vararg{AbstractStridedPointer,A}},
    p::AbstractStridedPointer,
    ∂v,
    im::Vararg{Any,N}
) where {A,N}
    # NOTE: plain loops only; a generated-function body may not contain
    # closures, comprehensions or generators.
    body = Expr(:block, Expr(:meta, :inline))
    push!(body.args, :(primal = ∂v.value))
    push!(body.args, :(tangents = ∂v.partials))

    # Unpack the trailing arguments into im_1, …, im_N and pass them on to the
    # store of the primal value.
    primal_store = Expr(:call, :(VectorizationBase.vnoaliasstore!), :p, :primal)
    for n = 1:N
        im_n = Symbol(:im_, n)
        push!(body.args, :($im_n = im[$n]))
        push!(primal_store.args, im_n)
    end
    push!(body.args, primal_store)

    # One store per partial, each through its own pointer, reusing im_1, …, im_N.
    for a = 1:A
        ptr_a = Symbol(:∂p_, a)
        tan_a = Symbol(:∂_, a)
        push!(body.args, :($ptr_a = ∂p[$a]))
        push!(body.args, :($tan_a = tangents[$a]))
        store_a = Expr(:call, :(VectorizationBase.vnoaliasstore!), ptr_a, tan_a)
        for n = 1:N
            push!(store_a.args, Symbol(:im_, n))
        end
        push!(body.args, store_a)
    end

    push!(body.args, :nothing)
    return body
end
3 | 58 |
|
4 | 59 | if isdefined(ChainRulesCore, :ZeroTangent)
|
5 | 60 | const ChainRulesZero = ChainRulesCore.ZeroTangent
|
@@ -38,32 +93,33 @@ function ∂vmap_singlethread!(
|
38 | 93 | args::Vararg{DenseArray{<:Base.HWReal},A}
|
39 | 94 | ) where {F,T<:Base.HWReal,A}
|
40 | 95 | N = length(y)
|
41 |
| - ptry = VectorizationBase.zero_offsets(stridedpointer(y)) |
42 |
| - ptrargs = VectorizationBase.zero_offsets.(stridedpointer.(args)) |
43 |
| - ptr∂y = VectorizationBase.zero_offsets.(stridedpointer.(∂y)) |
44 |
| - |
| 96 | + ptry = zero_offsets(stridedpointer(y)) |
| 97 | + ptrargs = map(zero_offsets, map(stridedpointer, args)) |
| 98 | + ptr∂y = map(zero_offsets, map(stridedpointer, ∂y)) |
45 | 99 | i = 0
|
46 | 100 | V = VectorizationBase.pick_vector_width(T)
|
47 | 101 | W = Int(V)
|
48 |
| - st = VectorizationBase.static_sizeof(T) |
49 |
| - zero_index = MM{W}(StaticInt(0), st) |
50 | 102 | while i < vsub_nsw(N, ((W << 2) - 1))
|
51 | 103 | index = VectorizationBase.Unroll{1,W,4,1,W,zero(UInt)}((i,))
|
52 |
| - v = f(init_dual(vload.(ptrargs, index))...) |
| 104 | + v = f(init_dual(map(Base.Fix2(vload, index), ptrargs))...) |
53 | 105 | dual_store!(ptr∂y, ptry, v, index)
|
54 | 106 | i = vadd_nw(i, 4W)
|
55 | 107 | end
|
56 | 108 | while i < vsub_nsw(N, (W - 1))
|
57 |
| - vᵣ = f(init_dual(vload.(ptrargs, ((MM{W}(i),),)))...) |
| 109 | + loader = Base.Fix2(vload, (MM{W}(i),)) |
| 110 | + vᵣ = f(init_dual(map(loader, ptrargs))...) |
58 | 111 | dual_store!(ptr∂y, ptry, vᵣ, (MM{W}(i),))
|
59 | 112 | i = vadd_nw(i, W)
|
60 | 113 | end
|
61 | 114 | if i < N
|
62 | 115 | m = mask(T, N & (W - 1))
|
| 116 | + mloader = let i = i, m = m |
| 117 | + p -> vload(p, (MM{W}(i),), m) |
| 118 | + end |
63 | 119 | dual_store!(
|
64 | 120 | ptr∂y,
|
65 | 121 | ptry,
|
66 |
| - f(init_dual(vload.(ptrargs, ((MM{W}(i),),), m))...), |
| 122 | + f(init_dual(map(mloader, ptrargs))...), |
67 | 123 | (MM{W}(i),),
|
68 | 124 | m
|
69 | 125 | )
|
@@ -109,6 +165,7 @@ for f in (:vmapt, :vmapnt, :vmapntt)
|
109 | 165 | f::F,
|
110 | 166 | args::Vararg{Any,K}
|
111 | 167 | ) where {F,K}
|
112 |
| - ChainRulesCore.rrule(typeof(vmap), f, args...) |
| 168 | + ChainRulesCore.rrule(typeof($vmap), f, args...) |
113 | 169 | end
|
114 | 170 | end
|
| 171 | +end |
0 commit comments