-
Notifications
You must be signed in to change notification settings - Fork 149
/
Copy pathapiutils.jl
109 lines (91 loc) · 3.18 KB
/
apiutils.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
####################
# value extraction #
####################
# Hand the duals in `ydual` to `DiffResults.value!`, together with a mapper that
# applies `value(T, ⋅)` to each of them (presumably extracting the tag-`T`
# primal part), so the extracted values are stored in `out`.
@inline function extract_value!(::Type{T}, out::DiffResult, ydual) where {T}
    to_value = dual -> value(T, dual)
    return DiffResults.value!(to_value, out, ydual)
end

# Fallback for an `out` that is not a `DiffResult`: `ydual` is ignored and `out`
# is returned unchanged. NOTE(review): presumably such outputs (e.g. a plain
# derivative array filled elsewhere) have no slot for the primal value — confirm.
@inline function extract_value!(::Type{T}, out, ydual) where {T}
    return out
end

# Buffered variant: overwrite the caller-supplied `y` with `value(T, ⋅)` of each
# element of `ydual`, then pass `y` on to `copy_value!` and return its result.
@inline function extract_value!(::Type{T}, out, y, ydual) where {T}
    map!(dual -> value(T, dual), y, ydual)
    return copy_value!(out, y)
end

# Store already-extracted values `y` into a `DiffResult` via `DiffResults.value!`.
@inline function copy_value!(out::DiffResult, y)
    return DiffResults.value!(out, y)
end

# Non-`DiffResult` outputs receive nothing; `out` is returned unchanged.
@inline function copy_value!(out, y)
    return out
end
###################################
# vector mode function evaluation #
###################################
# Out-of-place vector-mode evaluation: load `x` into the preallocated dual
# buffer `cfg.duals`, using the per-partial seeds stored in `cfg.seeds`, and
# return the result of calling `f` on that buffer.
function vector_mode_dual_eval!(f::F, cfg::Union{JacobianConfig,GradientConfig},
                                x) where {F}
    input_duals = cfg.duals
    seed!(input_duals, x, cfg.seeds)
    return f(input_duals)
end

# In-place vector-mode evaluation. `cfg.duals` unpacks into an (output, input)
# pair of dual buffers. The input buffer is seeded from `x` with `cfg.seeds`.
# The output buffer is re-initialised from `y` using `seed!`'s default partials.
# `f!` then writes into the output buffer, which is returned.
function vector_mode_dual_eval!(f!::F, cfg::JacobianConfig, y, x) where {F}
    out_duals, in_duals = cfg.duals
    seed!(in_duals, x, cfg.seeds)
    seed!(out_duals, y)
    f!(out_duals, in_duals)
    return out_duals
end
##################################
# seed construction/manipulation #
##################################
# Generate, at compile time, an N-element tuple expression whose i-th entry is
# `single_seed(Partials{N,V}, Val{i}())` (presumably the seed that perturbs only
# partial i). For N == 0 this produces the empty tuple `()`.
@generated function construct_seeds(::Type{Partials{N,V}}) where {N,V}
    tuple_expr = Expr(:tuple)
    for i in 1:N
        # `$i` is spliced in as a literal, so each `Val{i}` is a compile-time constant.
        push!(tuple_expr.args, :(single_seed(Partials{N,V}, Val{$i}())))
    end
    return tuple_expr
end
# Generic method: rebuild every element of `duals` from the matching element of
# `x` with the same partials `seed`. The default `seed` is zero, so calling it
# without one resets every element to zero partials.
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
               seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
    make_dual = Dual{T,V,N}
    # `Ref` keeps `seed` from being broadcast over; every element shares it.
    duals .= make_dual.(x, Ref(seed))
    return duals
end
# `Array` fast path of the single-seed method: every element `duals[i]` becomes
# `Dual{T,V,N}(x[i], seed)`. The default `seed` is zero, so calling it without
# one resets every element to zero partials. A plain loop is used instead of
# broadcasting.
function seed!(duals::Array{Dual{T,V,N}}, x,
               seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
    inds = eachindex(duals)
    # Prove the loop's reads of `x` are in range once, up front, so a shorter `x`
    # raises a BoundsError instead of an unchecked read under `@inbounds`
    # (the generic AbstractArray method errors in that case as well).
    checkbounds(x, inds)
    @inbounds for i in inds
        duals[i] = Dual{T,V,N}(x[i], seed)
    end
    return duals
end
# Generic method for a full tuple of seeds: only slots 1:N of `duals` are
# rewritten, slot k receiving value `x[k]` and partials `seeds[k]`. Any slots
# beyond N are left as they are.
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
               seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
    head = 1:N
    duals[head] .= Dual{T,V,N}.(view(x, head), seeds)
    return duals
end
# `Array` fast path for a full tuple of seeds: for k in 1:N, slot k of `duals`
# becomes `Dual{T,V,N}(x[k], seeds[k])`. Slots beyond N are untouched.
function seed!(duals::Array{Dual{T,V,N}}, x,
               seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
    inds = 1:N
    # Validate the window once so that the `@inbounds` loop cannot write past
    # the end of `duals` or read past the end of `x`. The generic method throws
    # a BoundsError in those cases too.
    checkbounds(duals, inds)
    checkbounds(x, inds)
    @inbounds for i in inds
        duals[i] = Dual{T,V,N}(x[i], seeds[i])
    end
    return duals
end
# Generic method for one window: the N consecutive slots starting at position
# `index` get their value from `x` at the same position and the shared partials
# `seed`. The default `seed` is zero, which clears the window.
function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
               seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
    window = (1:N) .+ (index - 1)
    duals[window] .= Dual{T,V,N}.(view(x, window), Ref(seed))
    return duals
end
# `Array` fast path for one window: for j in index:(index + N - 1),
# `duals[j]` becomes `Dual{T,V,N}(x[j], seed)`. The default `seed` is zero,
# which clears the window. Slots outside the window are untouched.
function seed!(duals::Array{Dual{T,V,N}}, x, index,
               seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
    offset = index - 1
    window = (1:N) .+ offset
    # A bad `index` must not become an unchecked out-of-bounds write under
    # `@inbounds`. Check the whole window once, matching the BoundsError the
    # generic AbstractArray method raises.
    checkbounds(duals, window)
    checkbounds(x, window)
    @inbounds for j in window
        duals[j] = Dual{T,V,N}(x[j], seed)
    end
    return duals
end
# Generic chunked method: for k in 1:chunksize, slot `index + k - 1` of `duals`
# gets value `x[index + k - 1]` and partials `seeds[k]`. `chunksize` defaults to
# N; a smaller value presumably serves a final, partial chunk. Slots outside the
# window are left as they are.
function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
               seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N}
    local_inds = 1:chunksize
    window = local_inds .+ (index - 1)
    # The nested `getindex.` fuses into the outer broadcast, so no temporary
    # vector of seeds is materialised.
    duals[window] .= Dual{T,V,N}.(view(x, window), getindex.(Ref(seeds), local_inds))
    return duals
end
# `Array` fast path of the chunked method: for k in 1:chunksize, slot
# `index + k - 1` of `duals` becomes `Dual{T,V,N}(x[index + k - 1], seeds[k])`.
# `chunksize` defaults to N; a smaller value presumably serves a final, partial
# chunk. Slots outside the window are untouched.
function seed!(duals::Array{Dual{T,V,N}}, x, index,
               seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N}
    offset = index - 1
    # Everything the `@inbounds` loop touches is validated here, once.
    # `seeds` holds only N entries, and the index window must exist in both
    # `duals` and `x`. The generic method raises BoundsError for the same
    # inputs rather than accessing memory unchecked.
    chunksize <= N || throw(BoundsError(seeds, chunksize))
    window = (1:chunksize) .+ offset
    checkbounds(duals, window)
    checkbounds(x, window)
    @inbounds for i in 1:chunksize
        j = i + offset
        duals[j] = Dual{T,V,N}(x[j], seeds[i])
    end
    return duals
end