Skip to content

Commit 4fb8d88

Browse files
authored
6 fix zero order partial problem (#8)
* avoid problem * add order back in * compiler friendly fullderivative * testing * empty tuple * remove dead code * lower case as constructor is no longer varargs
1 parent 8b27dfa commit 4fb8d88

File tree

4 files changed

+40
-22
lines changed

4 files changed

+40
-22
lines changed

src/diffKernel.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,6 @@ then julia would not know whether to use
2828
=#
2929
for T in [SimpleKernel, Kernel] #subtypes(Kernel)
3030
(k::T)(x::DiffPt, y::DiffPt) = _evaluate(k, x, y)
31-
(k::T)(x::DiffPt, y) = _evaluate(k, x,(y, Partial()))
32-
(k::T)(x, y::DiffPt) = _evaluate(k, (x, Partial()), y)
31+
(k::T)(x::DiffPt, y) = _evaluate(k, x,(y, partial()))
32+
(k::T)(x, y::DiffPt) = _evaluate(k, (x, partial()), y)
3333
end

src/partial.jl

+26-19
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
const IndexType = Union{Int,Base.AbstractCartesianIndex}
2-
3-
struct Partial{Order,T<:IndexType}
4-
indices::NTuple{Order,T}
2+
struct Partial{Order,T<:Tuple{Vararg{IndexType,Order}}}
3+
indices::T
54
end
65

7-
# TODO: this is not ideal... how does NTuple{0,Int} <: Tuple{} work??
8-
Partial() = Partial{0,Int}(())
9-
function Partial(indices::Integer...)
10-
return Partial{length(indices),Int}(indices)
11-
end
12-
function Partial(indices::Base.AbstractCartesianIndex...)
13-
return Partial{length(indices),Base.AbstractCartesianIndex}(indices)
6+
Partial(::Tuple{}) = Partial{0,Tuple{}}(())
7+
function Partial(indices::Tuple{Vararg{T}}) where {T<:IndexType}
8+
Ord = length(indices)
9+
return Partial{Ord,NTuple{Ord,T}}(indices)
1410
end
15-
partial(indices...) = Partial(indices...)
11+
partial(indices::Tuple{Vararg{T}}) where {T<:IndexType} = Partial(indices)
12+
partial(indices::IndexType...) = Partial(indices)
1613

1714
## show helpers
1815

@@ -22,18 +19,18 @@ lower_digits(idx::Base.AbstractCartesianIndex) = join(map(lower_digits, Tuple(id
2219
### Fallbacks
2320
compact_representation(p::Partial) = compact_representation(MIME"text/plain"(), p)
2421
compact_representation(::MIME, p::Partial) = compact_representation(p)
25-
detailed_representation(p::Partial) = """: Partial($(join(p.indices,",")))"""
26-
detailed_representation(p::Partial{0}) = """: Partial() a zero order derivative"""
22+
detailed_representation(p::Partial) = """: partial($(join(p.indices,",")))"""
23+
detailed_representation(p::Partial{0,Tuple{}}) = """: partial() a zero order derivative"""
2724

2825
### text/plain
29-
compact_representation(::MIME"text/plain", ::Partial{0}) = "id"
26+
compact_representation(::MIME"text/plain", ::Partial{0,Tuple{}}) = "id"
3027
function compact_representation(::MIME"text/plain", p::Partial)
3128
lower_numbers = map(lower_digits, p.indices)
3229
return join(["$(x)" for x in lower_numbers])
3330
end
3431

3532
### text/html
36-
compact_representation(::MIME"text/html", ::Partial{0}) = """<span class="text-muted" title="a zero order derivative">id</span>"""
33+
compact_representation(::MIME"text/html", ::Partial{0,Tuple{}}) = """<span class="text-muted" title="a zero order derivative">id</span>"""
3734
function compact_representation(::MIME"text/html", p::Partial)
3835
return join(map(n -> "∂<sub>$(n)</sub>", Tuple(p.indices)), "")
3936
end
@@ -54,12 +51,22 @@ end
5451

5552
const DiffPt{T} = Tuple{T,Partial}
5653

57-
gradient(dim::Integer) = mappedarray(partial, Base.OneTo(dim))
58-
hessian(dim::Integer) = mappedarray(partial, productArray(Base.OneTo(dim), Base.OneTo(dim)))
59-
fullderivative(order::Integer,dim::Integer) = mappedarray(partial, productArray(ntuple(_->Base.OneTo(dim), order)...))
54+
function fullderivative(::Val{order}, input_indices::AbstractVector{Int}) where {order}
55+
return mappedarray(partial, productArray(ntuple(_ -> input_indices, Val{order}())...))
56+
end
57+
fullderivative(::Val{order}, dim::Integer) where {order} = fullderivative(Val{order}(), Base.OneTo(dim))
58+
function fullderivative(::Val{order}, input_indices::AbstractArray{T,N}) where {order,N,T<:Base.AbstractCartesianIndex{N}}
59+
return mappedarray(partial, productArray(ntuple(_ -> input_indices, Val{order}())...))
60+
end
61+
62+
gradient(input_indices::AbstractArray) = fullderivative(Val(1), input_indices)
63+
gradient(dim::Integer) = fullderivative(Val(1), dim)
64+
65+
hessian(input_indices::AbstractArray) = fullderivative(Val(2), input_indices)
66+
hessian(dim::Integer) = fullderivative(Val(2), dim)
6067

6168
# idea: lazy mappings can be undone (extract original range -> towards a specialization speedup of broadcasting over multiple derivatives using backwardsdiff)
62-
const MappedPartialVec{T} = ReadonlyMappedArray{Partial{1,Int},1,T,typeof(partial)}
69+
const MappedPartialVec{T} = ReadonlyMappedArray{Partial{1,Tuple{Int}},1,T,typeof(partial)}
6370
function extract_range(p_map::MappedPartialVec{T}) where {T<:AbstractUnitRange{Int}}
6471
return p_map.data::T
6572
end

test/partial.jl

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
@testset "Partial" begin
2+
@test partial() isa DKF.Partial{0,Tuple{}}
3+
@test partial.(1:10) == DKF.gradient(10)
4+
@test partial.(productArray(1:10, 1:10)) == DKF.hessian(10)
5+
@test DKF.gradient(10) isa AbstractArray{DKF.Partial{1,Tuple{Int}},1}
6+
@test DKF.gradient(2) == DKF.fullderivative(Val(1), 2)
7+
@test size(DKF.gradient(CartesianIndices((4,3)))) == (4,3)
8+
@test ndims(DKF.hessian(CartesianIndices((2,2,2)))) == 6
9+
end

test/runtests.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
using KernelFunctions: KernelFunctions as KF, MaternKernel, SEKernel
2-
using DifferentiableKernelFunctions: DifferentiableKernelFunctions as DKF, DiffPt, partial, Partial
2+
using DifferentiableKernelFunctions: DifferentiableKernelFunctions as DKF, DiffPt, partial
3+
using ProductArrays: productArray
34
using Test
45

56
"""
67
List of Testfiles without extension. `\$(test).jl"` should be a file for every test in AVAILABLE_TESTS
78
"""
89
const AVAILABLE_TESTS = [
910
"diffKernel",
11+
"partial"
1012
]
1113

1214
function test_selection()

0 commit comments

Comments
 (0)