Skip to content

Commit ef6d459

Browse files
Fix method ambiguities (JuliaGaussianProcesses#483)
* Fix method ambiguities * Update test/kernels/kernelsum.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Try to fix `map` issues * Define `map` for `ColVecs`/`RowVecs` * Fix ambiguity issues * Better fix * Add back some definitions --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 9da7bfd commit ef6d459

11 files changed

+62
-11
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.54"
3+
version = "0.10.55"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/chainrules.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
# Note that this is type piracy as the derivative should be NaN for x == y.
44
function ChainRulesCore.frule(
5-
(_, Δx, Δy), d::Distances.Euclidean, x::AbstractVector, y::AbstractVector
5+
(_, Δx, Δy)::Tuple{<:Any,<:Any,<:Any},
6+
d::Distances.Euclidean,
7+
x::AbstractVector,
8+
y::AbstractVector,
69
)
710
Δ = x - y
811
D = sqrt(sum(abs2, Δ))

src/kernels/overloads.jl

+8
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,13 @@ for (M, op, T) in (
1818

1919
$M.$op(ks::$T, k::Kernel) = $T(ks.kernels..., k)
2020
$M.$op(ks::$T{<:AbstractVector{<:Kernel}}, k::Kernel) = $T(vcat(ks.kernels, k))
21+
22+
# Fix method ambiguity issues
23+
function $M.$op(ks1::$T, ks2::$T{<:AbstractVector{<:Kernel}})
24+
return $T(vcat(collect(ks1.kernels), ks2.kernels))
25+
end
26+
function $M.$op(ks1::$T{<:AbstractVector{<:Kernel}}, ks2::$T)
27+
return $T(vcat(ks1.kernels, collect(ks2.kernels)))
28+
end
2129
end
2230
end

src/transform/chaintransform.jl

+5-2
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,11 @@ function ChainTransform(v, θ::AbstractVector)
3434
end
3535

3636
Base.:(t₁::Transform, t₂::Transform) = ChainTransform((t₂, t₁))
37-
Base.:(t::Transform, tc::ChainTransform) = ChainTransform(tuple(tc.transforms..., t))
38-
Base.:(tc::ChainTransform, t::Transform) = ChainTransform(tuple(t, tc.transforms...))
37+
Base.:(t::Transform, tc::ChainTransform) = ChainTransform((tc.transforms..., t))
38+
Base.:(tc::ChainTransform, t::Transform) = ChainTransform((t, tc.transforms...))
39+
function Base.:(tc1::ChainTransform, tc2::ChainTransform)
40+
return ChainTransform((tc2.transforms..., tc1.transforms...))
41+
end
3942

4043
(t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x)
4144

src/transform/transform.jl

+18-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,21 @@ abstract type Transform end
88
# We introduce our own _map for Transform so that we can work around
99
# https://github.com/FluxML/Zygote.jl/issues/646 and define our own pullback
1010
# (see zygoterules.jl)
11-
Base.map(t::Transform, x::AbstractVector) = _map(t, x)
12-
_map(t::Transform, x::AbstractVector) = t.(x)
11+
Base.map(t::Transform, x::ColVecs) = _map(t, x)
12+
Base.map(t::Transform, x::RowVecs) = _map(t, x)
13+
14+
# Fallback
15+
# No separate methods for `x::ColVecs` and `x::RowVecs` to avoid method ambiguities
16+
function _map(t::Transform, x::AbstractVector)
17+
# Avoid stackoverflow
18+
if x isa RowVecs
19+
return map(t, eachrow(x.X))
20+
elseif x isa ColVecs
21+
return map(t, eachcol(x.X))
22+
else
23+
return map(t, x)
24+
end
25+
end
1326

1427
"""
1528
IdentityTransform()
@@ -19,6 +32,9 @@ Transformation that returns exactly the input.
1932
struct IdentityTransform <: Transform end
2033

2134
(t::IdentityTransform)(x) = x
35+
36+
# More efficient implementation than `map(IdentityTransform(), x)`
37+
# Introduces, however, discrepancy between `map` and `_map`
2238
_map(::IdentityTransform, x::AbstractVector) = x
2339

2440
### TODO Maybe defining adjoints could help but so far it's not working

src/utils.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,10 @@ _to_colvecs(x::AbstractVector{<:Real}) = ColVecs(reshape(x, 1, :))
101101

102102
pairwise(d::PreMetric, x::ColVecs) = Distances_pairwise(d, x.X; dims=2)
103103
pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances_pairwise(d, x.X, y.X; dims=2)
104-
function pairwise(d::PreMetric, x::AbstractVector, y::ColVecs)
104+
function pairwise(d::PreMetric, x::AbstractVector{<:AbstractVector{<:Real}}, y::ColVecs)
105105
return Distances_pairwise(d, reduce(hcat, x), y.X; dims=2)
106106
end
107-
function pairwise(d::PreMetric, x::ColVecs, y::AbstractVector)
107+
function pairwise(d::PreMetric, x::ColVecs, y::AbstractVector{<:AbstractVector{<:Real}})
108108
return Distances_pairwise(d, x.X, reduce(hcat, y); dims=2)
109109
end
110110
function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs)
@@ -172,10 +172,10 @@ dim(x::RowVecs) = size(x.X, 2)
172172

173173
pairwise(d::PreMetric, x::RowVecs) = Distances_pairwise(d, x.X; dims=1)
174174
pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = Distances_pairwise(d, x.X, y.X; dims=1)
175-
function pairwise(d::PreMetric, x::AbstractVector, y::RowVecs)
175+
function pairwise(d::PreMetric, x::AbstractVector{<:AbstractVector{<:Real}}, y::RowVecs)
176176
return Distances_pairwise(d, permutedims(reduce(hcat, x)), y.X; dims=1)
177177
end
178-
function pairwise(d::PreMetric, x::RowVecs, y::AbstractVector)
178+
function pairwise(d::PreMetric, x::RowVecs, y::AbstractVector{<:AbstractVector{<:Real}})
179179
return Distances_pairwise(d, x.X, permutedims(reduce(hcat, y)); dims=1)
180180
end
181181
function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs)

test/kernels/kernelproduct.jl

+6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
k2 = SqExponentialKernel()
44
k = KernelProduct(k1, k2)
55
@test k == KernelProduct([k1, k2]) == KernelProduct((k1, k2))
6+
for (_k1, _k2) in Iterators.product(
7+
(k1, KernelProduct((k1,)), KernelProduct([k1])),
8+
(k2, KernelProduct((k2,)), KernelProduct([k2])),
9+
)
10+
@test k == _k1 * _k2
11+
end
612
@test length(k) == 2
713
@test string(k) == (
814
"Product of 2 kernels:\n\tLinear Kernel (c = 0.0)\n\tSquared " *

test/kernels/kernelsum.jl

+5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33
k2 = SqExponentialKernel()
44
k = KernelSum(k1, k2)
55
@test k == KernelSum([k1, k2]) == KernelSum((k1, k2))
6+
for (_k1, _k2) in Iterators.product(
7+
(k1, KernelSum((k1,)), KernelSum([k1])), (k2, KernelSum((k2,)), KernelSum([k2]))
8+
)
9+
@test k == _k1 + _k2
10+
end
611
@test length(k) == 2
712
@test repr(k) == (
813
"Sum of 2 kernels:\n" *

test/kernels/kerneltensorproduct.jl

+6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313

1414
@test kernel1 == kernel2
1515
@test kernel1.kernels === (k1, k2) === KernelTensorProduct((k1, k2)).kernels
16+
for (_k1, _k2) in Iterators.product(
17+
(k1, KernelTensorProduct((k1,)), KernelTensorProduct([k1])),
18+
(k2, KernelTensorProduct((k2,)), KernelTensorProduct([k2])),
19+
)
20+
@test kernel1 == _k1 _k2
21+
end
1622
@test length(kernel1) == length(kernel2) == 2
1723
@test string(kernel1) == (
1824
"Tensor product of 2 kernels:\n" *

test/runtests.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,10 @@ include("test_utils.jl")
149149
if GROUP == "" || GROUP == "Others"
150150
include("utils.jl")
151151

152-
@test isempty(detect_unbound_args(KernelFunctions))
152+
@testset "general" begin
153+
@test isempty(detect_unbound_args(KernelFunctions))
154+
@test isempty(detect_ambiguities(KernelFunctions))
155+
end
153156

154157
@testset "distances" begin
155158
include("distances/pairwise.jl")

test/transform/chaintransform.jl

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# Check composition constructors.
1313
@test (tf ChainTransform([tp])).transforms == (tp, tf)
1414
@test (ChainTransform([tf]) tp).transforms == (tp, tf)
15+
@test (ChainTransform([tf]) ChainTransform([tp])).transforms == (tp, tf)
1516

1617
# Verify correctness.
1718
x = ColVecs(randn(rng, 2, 3))

0 commit comments

Comments
 (0)