Skip to content

Commit e42d89d

Browse files
simsuracewilltebbuttWill Tebbutt
authored
Resolve type stability of evaluation of KernelSum (JuliaGaussianProcesses#459)
* Test type stability of evaluation of `KernelSum` * Add hack * Implement @willtebbutt's suggestion * Revert strange hack * Implement reviewdog's suggestion * Apply reviewdog's spacing preferences * Add comment Co-authored-by: Will Tebbutt <[email protected]> * Fix tests * Bump patch Co-authored-by: Will Tebbutt <[email protected]> Co-authored-by: Will Tebbutt <[email protected]> Co-authored-by: Will Tebbutt <[email protected]>
1 parent 4ce3e87 commit e42d89d

File tree

3 files changed

+29
-6
lines changed

3 files changed

+29
-6
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.44"
3+
version = "0.10.45"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/kernels/kernelsum.jl

+8-5
Original file line numberDiff line numberDiff line change
@@ -43,22 +43,25 @@ end
4343

4444
Base.length(k::KernelSum) = length(k.kernels)
4545

46-
::KernelSum)(x, y) = sum(k(x, y) for k in κ.kernels)
46+
_sum(f::Tf, x::Tuple) where {Tf} = f(x[1]) + _sum(f, Base.tail(x))
47+
_sum(f::Tf, x::Tuple{Tx}) where {Tf,Tx} = f(x[1])
48+
49+
::KernelSum)(x, y) = _sum(k -> k(x, y), κ.kernels)
4750

4851
function kernelmatrix::KernelSum, x::AbstractVector)
49-
return sum(kernelmatrix(k, x) for k in κ.kernels)
52+
return _sum(Base.Fix2(kernelmatrix, x), κ.kernels)
5053
end
5154

5255
function kernelmatrix::KernelSum, x::AbstractVector, y::AbstractVector)
53-
return sum(kernelmatrix(k, x, y) for k in κ.kernels)
56+
return _sum(k -> kernelmatrix(k, x, y), κ.kernels)
5457
end
5558

5659
function kernelmatrix_diag::KernelSum, x::AbstractVector)
57-
return sum(kernelmatrix_diag(k, x) for k in κ.kernels)
60+
return _sum(Base.Fix2(kernelmatrix_diag, x), κ.kernels)
5861
end
5962

6063
function kernelmatrix_diag::KernelSum, x::AbstractVector, y::AbstractVector)
61-
return sum(kernelmatrix_diag(k, x, y) for k in κ.kernels)
64+
return _sum(k -> kernelmatrix_diag(k, x, y), κ.kernels)
6265
end
6366

6467
function Base.show(io::IO, κ::KernelSum)

test/kernels/kernelsum.jl

+20
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,24 @@
1818
end
1919

2020
test_params(k1 + k2, (k1, k2))
21+
22+
# Regression tests for https://github.com//issues/458
23+
@testset "Type stability" begin
24+
function check_type_stability(k)
25+
@test (@inferred k(0.1, 0.2)) isa Real
26+
x = rand(10)
27+
y = rand(10)
28+
@test (@inferred kernelmatrix(k, x)) isa Matrix{<:Real}
29+
@test (@inferred kernelmatrix(k, x, y)) isa Matrix{<:Real}
30+
@test (@inferred kernelmatrix_diag(k, x)) isa Vector{<:Real}
31+
@test (@inferred kernelmatrix_diag(k, x, y)) isa Vector{<:Real}
32+
end
33+
@testset for k in (
34+
RBFKernel() + RBFKernel() * LinearKernel(),
35+
RBFKernel() + RBFKernel() * ExponentialKernel(),
36+
RBFKernel() * (LinearKernel() + ExponentialKernel()),
37+
)
38+
check_type_stability(k)
39+
end
40+
end
2141
end

0 commit comments

Comments
 (0)