Skip to content

Commit 9da7bfd

Browse files
authored
Make spectral_mixture_kernel type stable for StaticArrays (JuliaGaussianProcesses#501)
* Cherry pick from branch * Format and patch bump * Remove StaticArrays from main Project.toml
1 parent 8746034 commit 9da7bfd

File tree

5 files changed

+19
-2
lines changed

5 files changed

+19
-2
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.53"
3+
version = "0.10.54"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/basekernels/sm.jl

+7-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ Here, D is input dimension and A is the number of spectral components.
1212
1313
`h` is the kernel, which defaults to [`SqExponentialKernel`](@ref) if not specified.
1414
15+
!!! warning
16+
If you want to make sure that the constructor is type-stable, you should
17+
provide [`StaticArrays`](https://github.com/JuliaArrays/StaticArrays.jl) arguments:
18+
`αs` as a `StaticVector`, `γs` and `ωs` as `StaticMatrix`.
19+
1520
Generalised Spectral Mixture kernel function. This family of functions is dense
1621
in the family of stationary real-valued kernels with respect to the pointwise convergence.[1]
1722
@@ -42,11 +47,12 @@ function spectral_mixture_kernel(
4247
throw(DimensionMismatch("The dimensions of γs ans ωs do not match"))
4348
end
4449

45-
return sum(zip(αs, eachcol(γs), eachcol(ωs))) do (α, γ, ω)
50+
kernels = map(zip(αs, eachcol(γs), eachcol(ωs))) do (α, γ, ω)
4651
a = TransformedKernel(h, LinearTransform'))
4752
b = TransformedKernel(CosineKernel(), LinearTransform'))
4853
return α * a * b
4954
end
55+
return sum(kernels)
5056
end
5157

5258
function spectral_mixture_kernel(

test/Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1515
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1616
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
17+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1718
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1819
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1920
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -32,4 +33,5 @@ PDMats = "0.9, 0.10, 0.11"
3233
ReverseDiff = "1.2"
3334
SpecialFunctions = "0.10, 1, 2"
3435
StableRNGs = "1"
36+
StaticArrays = "1"
3537
Zygote = "0.6.38"

test/basekernels/sm.jl

+8
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,14 @@
4848
TestUtils.test_interface(k1, x0, x1, x2)
4949
TestUtils.test_interface(k2, x0, x1, x2)
5050
end
51+
52+
@testset "Type stability given static arrays" begin
53+
αs = @SVector rand(3)
54+
γs = @SMatrix rand(D_in, 3)
55+
ωs = @SMatrix rand(D_in, 3)
56+
@inferred spectral_mixture_kernel(αs, γs, ωs)
57+
end
58+
5159
# test_ADs(x->spectral_mixture_kernel(exp.(x[1:3]), reshape(x[4:18], 5, 3), reshape(x[19:end], 5, 3)), vcat(log.(αs₁), γs[:], ωs[:]), dims = [5,5])
5260
@test_broken "No tests passing (BaseKernel)"
5361
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using PDMats
1010
using Random
1111
using SpecialFunctions
1212
using StableRNGs
13+
using StaticArrays
1314
using Statistics
1415
using Test
1516
using Zygote: Zygote

0 commit comments

Comments
 (0)