Skip to content

Commit 6a19554

Browse files
authored
Support specializing on functions (#615)
* specialize on function in jacobian * specilize on function parameters for derivative, gradient, hessian
1 parent 76335e6 commit 6a19554

File tree

4 files changed

+32
-33
lines changed

4 files changed

+32
-33
lines changed

src/derivative.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ stored in `y`.
2222
2323
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
2424
"""
25-
@inline function derivative(f!, y::AbstractArray, x::Real,
26-
cfg::DerivativeConfig{T} = DerivativeConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {T, CHK}
25+
@inline function derivative(f!::F, y::AbstractArray, x::Real,
26+
cfg::DerivativeConfig{T} = DerivativeConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {F, T, CHK}
2727
require_one_based_indexing(y)
2828
CHK && checktag(T, f!, x)
2929
ydual = cfg.duals
@@ -60,8 +60,8 @@ called as `f!(y, x)` where the result is stored in `y`.
6060
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
6161
"""
6262
@inline function derivative!(result::Union{AbstractArray,DiffResult},
63-
f!, y::AbstractArray, x::Real,
64-
cfg::DerivativeConfig{T} = DerivativeConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {T, CHK}
63+
f!::F, y::AbstractArray, x::Real,
64+
cfg::DerivativeConfig{T} = DerivativeConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {F, T, CHK}
6565
result isa DiffResult ? require_one_based_indexing(y) : require_one_based_indexing(result, y)
6666
CHK && checktag(T, f!, x)
6767
ydual = cfg.duals

src/gradient.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ This method assumes that `isa(f(x), Real)`.
1313
1414
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
1515
"""
16-
function gradient(f, x::AbstractArray, cfg::GradientConfig{T} = GradientConfig(f, x), ::Val{CHK}=Val{true}()) where {T, CHK}
16+
function gradient(f::F, x::AbstractArray, cfg::GradientConfig{T} = GradientConfig(f, x), ::Val{CHK}=Val{true}()) where {F, T, CHK}
1717
require_one_based_indexing(x)
1818
CHK && checktag(T, f, x)
1919
if chunksize(cfg) == length(x)
@@ -43,13 +43,13 @@ function gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::AbstractArr
4343
return result
4444
end
4545

46-
@inline gradient(f, x::StaticArray) = vector_mode_gradient(f, x)
47-
@inline gradient(f, x::StaticArray, cfg::GradientConfig) = gradient(f, x)
48-
@inline gradient(f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient(f, x)
46+
@inline gradient(f::F, x::StaticArray) where F = vector_mode_gradient(f, x)
47+
@inline gradient(f::F, x::StaticArray, cfg::GradientConfig) where F = gradient(f, x)
48+
@inline gradient(f::F, x::StaticArray, cfg::GradientConfig, ::Val) where F = gradient(f, x)
4949

50-
@inline gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_gradient!(result, f, x)
51-
@inline gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig) = gradient!(result, f, x)
52-
@inline gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient!(result, f, x)
50+
@inline gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray) where F = vector_mode_gradient!(result, f, x)
51+
@inline gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::GradientConfig) where F = gradient!(result, f, x)
52+
@inline gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::GradientConfig, ::Val) where F = gradient!(result, f, x)
5353

5454
gradient(f, x::Real) = throw(DimensionMismatch("gradient(f, x) expects that x is an array. Perhaps you meant derivative(f, x)?"))
5555

src/hessian.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ This method assumes that `isa(f(x), Real)`.
1111
1212
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
1313
"""
14-
function hessian(f, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, x), ::Val{CHK}=Val{true}()) where {T,CHK}
14+
function hessian(f::F, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, x), ::Val{CHK}=Val{true}()) where {F, T,CHK}
1515
require_one_based_indexing(x)
1616
CHK && checktag(T, f, x)
1717
∇f = y -> gradient(f, y, cfg.gradient_config, Val{false}())
@@ -28,7 +28,7 @@ This method assumes that `isa(f(x), Real)`.
2828
2929
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
3030
"""
31-
function hessian!(result::AbstractArray, f, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, x), ::Val{CHK}=Val{true}()) where {T,CHK}
31+
function hessian!(result::AbstractArray, f::F, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, x), ::Val{CHK}=Val{true}()) where {F,T,CHK}
3232
require_one_based_indexing(result, x)
3333
CHK && checktag(T, f, x)
3434
∇f = y -> gradient(f, y, cfg.gradient_config, Val{false}())
@@ -63,26 +63,25 @@ because `isa(result, DiffResult)`, `cfg` is constructed as `HessianConfig(f, res
6363
6464
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
6565
"""
66-
function hessian!(result::DiffResult, f, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, result, x), ::Val{CHK}=Val{true}()) where {T,CHK}
67-
require_one_based_indexing(x)
66+
function hessian!(result::DiffResult, f::F, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, result, x), ::Val{CHK}=Val{true}()) where {F,T,CHK}
6867
CHK && checktag(T, f, x)
6968
∇f! = InnerGradientForHess(result, cfg, f)
7069
jacobian!(DiffResults.hessian(result), ∇f!, DiffResults.gradient(result), x, cfg.jacobian_config, Val{false}())
7170
return ∇f!.result
7271
end
7372

74-
hessian(f, x::StaticArray) = jacobian(y -> gradient(f, y), x)
75-
hessian(f, x::StaticArray, cfg::HessianConfig) = hessian(f, x)
76-
hessian(f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian(f, x)
73+
hessian(f::F, x::StaticArray) where F = jacobian(y -> gradient(f, y), x)
74+
hessian(f::F, x::StaticArray, cfg::HessianConfig) where F = hessian(f, x)
75+
hessian(f::F, x::StaticArray, cfg::HessianConfig, ::Val) where F = hessian(f, x)
7776

78-
hessian!(result::AbstractArray, f, x::StaticArray) = jacobian!(result, y -> gradient(f, y), x)
77+
hessian!(result::AbstractArray, f::F, x::StaticArray) where F = jacobian!(result, y -> gradient(f, y), x)
7978

80-
hessian!(result::MutableDiffResult, f, x::StaticArray) = hessian!(result, f, x, HessianConfig(f, result, x))
79+
hessian!(result::MutableDiffResult, f::F, x::StaticArray) where F = hessian!(result, f, x, HessianConfig(f, result, x))
8180

82-
hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig) = hessian!(result, f, x)
83-
hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian!(result, f, x)
81+
hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::HessianConfig) where F = hessian!(result, f, x)
82+
hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::HessianConfig, ::Val) where F = hessian!(result, f, x)
8483

85-
function hessian!(result::ImmutableDiffResult, f, x::StaticArray)
84+
function hessian!(result::ImmutableDiffResult, f::F, x::StaticArray) where F
8685
T = typeof(Tag(f, eltype(x)))
8786
d1 = dualize(T, x)
8887
d2 = dualize(T, d1)

src/jacobian.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ This method assumes that `isa(f(x), AbstractArray)`.
1515
1616
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
1717
"""
18-
function jacobian(f, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f, x), ::Val{CHK}=Val{true}()) where {T,CHK}
18+
function jacobian(f::F, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f, x), ::Val{CHK}=Val{true}()) where {F,T,CHK}
1919
require_one_based_indexing(x)
2020
CHK && checktag(T, f, x)
2121
if chunksize(cfg) == length(x)
@@ -33,7 +33,7 @@ stored in `y`.
3333
3434
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
3535
"""
36-
function jacobian(f!, y::AbstractArray, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {T, CHK}
36+
function jacobian(f!::F, y::AbstractArray, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {F,T, CHK}
3737
require_one_based_indexing(y, x)
3838
CHK && checktag(T, f!, x)
3939
if chunksize(cfg) == length(x)
@@ -54,7 +54,7 @@ This method assumes that `isa(f(x), AbstractArray)`.
5454
5555
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
5656
"""
57-
function jacobian!(result::Union{AbstractArray,DiffResult}, f, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f, x), ::Val{CHK}=Val{true}()) where {T, CHK}
57+
function jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f, x), ::Val{CHK}=Val{true}()) where {F,T, CHK}
5858
result isa DiffResult ? require_one_based_indexing(x) : require_one_based_indexing(result, x)
5959
CHK && checktag(T, f, x)
6060
if chunksize(cfg) == length(x)
@@ -75,7 +75,7 @@ This method assumes that `isa(f(x), AbstractArray)`.
7575
7676
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
7777
"""
78-
function jacobian!(result::Union{AbstractArray,DiffResult}, f!, y::AbstractArray, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {T,CHK}
78+
function jacobian!(result::Union{AbstractArray,DiffResult}, f!::F, y::AbstractArray, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {F,T,CHK}
7979
result isa DiffResult ? require_one_based_indexing(y, x) : require_one_based_indexing(result, y, x)
8080
CHK && checktag(T, f!, x)
8181
if chunksize(cfg) == length(x)
@@ -86,13 +86,13 @@ function jacobian!(result::Union{AbstractArray,DiffResult}, f!, y::AbstractArray
8686
return result
8787
end
8888

89-
@inline jacobian(f, x::StaticArray) = vector_mode_jacobian(f, x)
90-
@inline jacobian(f, x::StaticArray, cfg::JacobianConfig) = jacobian(f, x)
91-
@inline jacobian(f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian(f, x)
89+
@inline jacobian(f::F, x::StaticArray) where F = vector_mode_jacobian(f, x)
90+
@inline jacobian(f::F, x::StaticArray, cfg::JacobianConfig) where F = jacobian(f, x)
91+
@inline jacobian(f::F, x::StaticArray, cfg::JacobianConfig, ::Val) where F = jacobian(f, x)
9292

93-
@inline jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_jacobian!(result, f, x)
94-
@inline jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig) = jacobian!(result, f, x)
95-
@inline jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian!(result, f, x)
93+
@inline jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray) where F = vector_mode_jacobian!(result, f, x)
94+
@inline jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig) where F = jacobian!(result, f, x)
95+
@inline jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig, ::Val) where F = jacobian!(result, f, x)
9696

9797
jacobian(f, x::Real) = throw(DimensionMismatch("jacobian(f, x) expects that x is an array. Perhaps you meant derivative(f, x)?"))
9898

0 commit comments

Comments
 (0)