Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
d5a5a92
Restore additions
Crown421 Aug 30, 2021
a7ad44e
Improvements for lazy kron
Crown421 Aug 30, 2021
dde0c08
Remove unneeded lines
Crown421 Aug 30, 2021
a2a345f
Merge branch 'JuliaGaussianProcesses:master' into mo-lazy
Crown421 Aug 31, 2021
88cb606
Small experiment with overwriting
Crown421 Sep 4, 2021
af90bf1
Reorder and overwrite
Crown421 Sep 4, 2021
c6e9f06
Format and kernelmatrix!
Crown421 Sep 4, 2021
5c45aed
Merge branch 'mo-lazy' of github.com:Crown421/KernelFunctions.jl into…
Crown421 Sep 4, 2021
62e4b5e
Reinstate separate method
Crown421 Sep 15, 2021
d9287d7
Adding tests
Crown421 Sep 15, 2021
a732763
Merge branch 'master' into mo-lazy
Crown421 Sep 15, 2021
7f23921
Duplicate code for readability
Crown421 Sep 22, 2021
1e898b7
Merge branch 'mo-lazy' of github.com:Crown421/KernelFunctions.jl into…
Crown421 Sep 22, 2021
ceab3cf
Format
Crown421 Sep 22, 2021
8ed6be3
Remove comment and patch bump
Crown421 Sep 22, 2021
f46bd61
Change kernelmatrix!
Crown421 Sep 22, 2021
9540fce
Change to output covariance type
Crown421 Sep 22, 2021
09cd20e
Change to output covariance type - revert
Crown421 Sep 22, 2021
e0bce9d
Merge branch 'mo-lazy' of github.com:Crown421/KernelFunctions.jl into…
Crown421 Sep 22, 2021
204d99c
Revert "Change to output covariance type - revert"
Crown421 Sep 22, 2021
8b5757a
Revert "Change kernelmatrix!"
Crown421 Sep 22, 2021
dfee74f
Add kernelmatrix! changes again
Crown421 Sep 22, 2021
24f6b49
Change input types for pairwise pullback
Crown421 Sep 22, 2021
a00ba6d
Missing changes to Any
Crown421 Sep 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.17"
version = "0.10.18"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
2 changes: 1 addition & 1 deletion src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ include("kernels/neuralkernelnetwork.jl")
include("approximations/nystrom.jl")
include("generic.jl")

include("mokernels/mokernel.jl")
include("mokernels/moinput.jl")
include("mokernels/mokernel.jl")
include("mokernels/independent.jl")
include("mokernels/slfm.jl")
include("mokernels/intrinsiccoregion.jl")
Expand Down
12 changes: 6 additions & 6 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function ChainRulesCore.rrule(
::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2
)
P = Distances.pairwise(d, X, Y; dims=dims)
function pairwise_pullback(::AbstractMatrix)
function pairwise_pullback(::Any)
return NoTangent(), NoTangent(), ZeroTangent(), ZeroTangent()
end
return P, pairwise_pullback
Expand All @@ -36,7 +36,7 @@ function ChainRulesCore.rrule(
::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2
)
P = Distances.pairwise(d, X; dims=dims)
function pairwise_pullback(::AbstractMatrix)
function pairwise_pullback(::Any)
return NoTangent(), NoTangent(), ZeroTangent()
end
return P, pairwise_pullback
Expand All @@ -46,7 +46,7 @@ function ChainRulesCore.rrule(
::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix
)
C = Distances.colwise(d, X, Y)
function colwise_pullback(::AbstractVector)
function colwise_pullback(::Any)
return NoTangent(), NoTangent(), ZeroTangent(), ZeroTangent()
end
return C, colwise_pullback
Expand All @@ -70,7 +70,7 @@ function ChainRulesCore.rrule(
dims=2,
)
P = Distances.pairwise(d, X, Y; dims=dims)
function pairwise_pullback_cols(Δ::AbstractMatrix)
function pairwise_pullback_cols(Δ::Any)
if dims == 1
return NoTangent(), NoTangent(), Δ * Y, Δ' * X
else
Expand All @@ -84,7 +84,7 @@ function ChainRulesCore.rrule(
::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2
)
P = Distances.pairwise(d, X; dims=dims)
function pairwise_pullback_cols(Δ::AbstractMatrix)
function pairwise_pullback_cols(Δ::Any)
if dims == 1
return NoTangent(), NoTangent(), 2 * Δ * X
else
Expand All @@ -98,7 +98,7 @@ function ChainRulesCore.rrule(
::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix
)
C = Distances.colwise(d, X, Y)
function colwise_pullback(Δ::AbstractVector)
function colwise_pullback(Δ::Any)
return NoTangent(), NoTangent(), Δ' .* Y, Δ' .* X
end
return C, colwise_pullback
Expand Down
40 changes: 40 additions & 0 deletions src/matrix/kernelkroneckermat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using .Kronecker: Kronecker

export kernelkronmat
export kronecker_kernelmatrix

function kernelkronmat(κ::Kernel, X::AbstractVector, dims::Int)
@assert iskroncompatible(κ) "The chosen kernel is not compatible for kroenecker matrices (see [`iskroncompatible`](@ref))"
Expand All @@ -25,3 +26,42 @@ end
k(x,x') = ∏ᵢᴰ k(xᵢ,x'ᵢ)
"""
@inline iskroncompatible(κ::Kernel) = false # Default return for kernels

function _kernelmatrix_kroneckerjl_helper(::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
return Kronecker.kronecker(Kfeatures, Koutputs)
end

function _kernelmatrix_kroneckerjl_helper(::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
return Kronecker.kronecker(Koutputs, Kfeatures)
end

function kronecker_kernelmatrix(
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel},
x::IsotopicMOInputsUnion,
y::IsotopicMOInputsUnion,
)
@assert x.out_dim == y.out_dim
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kroneckerjl_helper(x, Kfeatures, Koutputs)
end

function kronecker_kernelmatrix(
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel}, x::IsotopicMOInputsUnion
)
Kfeatures = kernelmatrix(k.kernel, x.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kroneckerjl_helper(x, Kfeatures, Koutputs)
end

function kronecker_kernelmatrix(
k::MOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion
)
return throw(
ArgumentError("This kernel does not support a lazy kronecker kernelmatrix.")
)
end

function kronecker_kernelmatrix(k::MOKernel, x::IsotopicMOInputsUnion)
return kronecker_kernelmatrix(k, x, x)
end
39 changes: 13 additions & 26 deletions src/mokernels/independent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,41 +27,28 @@ function (κ::IndependentMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{Any,I
return κ.kernel(x, y) * (px == py)
end

function _kernelmatrix_kron_helper(::MOInputIsotopicByFeatures, Kfeatures, B)
return kron(Kfeatures, B)
end

function _kernelmatrix_kron_helper(::MOInputIsotopicByOutputs, Kfeatures, B)
return kron(B, Kfeatures)
end
_mo_output_covariance(k::IndependentMOKernel, out_dim) = Eye{Bool}(out_dim)

function kernelmatrix(
k::IndependentMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
k::IndependentMOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion
)
@assert x.out_dim == y.out_dim
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
mtype = eltype(Kfeatures)
return _kernelmatrix_kron_helper(x, Kfeatures, Eye{mtype}(x.out_dim))
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper(x, Kfeatures, Koutputs)
end

if VERSION >= v"1.6"
function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByFeatures, Kfeatures, B)
return kron!(K, Kfeatures, B)
end

function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByOutputs, Kfeatures, B)
return kron!(K, B, Kfeatures)
end

function kernelmatrix!(
K::AbstractMatrix, k::IndependentMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
K::AbstractMatrix,
k::IndependentMOKernel,
x::IsotopicMOInputsUnion,
y::IsotopicMOInputsUnion,
)
@assert x.out_dim == y.out_dim
Ktmp = kernelmatrix(k.kernel, x.x, y.x)
mtype = eltype(Ktmp)
return _kernelmatrix_kron_helper!(
K, x, Ktmp, Matrix{mtype}(I, x.out_dim, x.out_dim)
)
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper!(K, x, Kfeatures, Koutputs)
end
end

Expand Down
22 changes: 16 additions & 6 deletions src/mokernels/intrinsiccoregion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,31 @@ function (k::IntrinsicCoregionMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{
return k.B[px, py] * k.kernel(x, y)
end

function _mo_output_covariance(k::IntrinsicCoregionMOKernel, out_dim)
@assert size(k.B) == (out_dim, out_dim)
return k.B
end

function kernelmatrix(
k::IntrinsicCoregionMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
k::IntrinsicCoregionMOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion
)
@assert x.out_dim == y.out_dim
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
return _kernelmatrix_kron_helper(x, Kfeatures, k.B)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper(x, Kfeatures, Koutputs)
end

if VERSION >= v"1.6"
function kernelmatrix!(
K::AbstractMatrix, k::IntrinsicCoregionMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
K::AbstractMatrix,
k::IntrinsicCoregionMOKernel,
x::IsotopicMOInputsUnion,
y::IsotopicMOInputsUnion,
)
@assert x.out_dim == y.out_dim
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
return _kernelmatrix_kron_helper!(K, x, Kfeatures, k.B)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper!(K, x, Kfeatures, Koutputs)
end
end

Expand Down
18 changes: 18 additions & 0 deletions src/mokernels/mokernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,21 @@
Abstract type for kernels with multiple outpus.
"""
abstract type MOKernel <: Kernel end

function _kernelmatrix_kron_helper(::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
return kron(Kfeatures, Koutputs)
end

function _kernelmatrix_kron_helper(::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
return kron(Koutputs, Kfeatures)
end

if VERSION >= v"1.6"
function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
return kron!(K, Kfeatures, Koutputs)
end

function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
return kron!(K, Koutputs, Kfeatures)
end
end
46 changes: 46 additions & 0 deletions test/matrix/kernelkroneckermat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,50 @@
@test all(collect(kernelkronmat(k, collect(x), 2)) .≈ kernelmatrix(k, X; obsdim=1))
@test all(collect(kernelkronmat(k, [x, x])) .≈ kernelmatrix(k, X; obsdim=1))
@test_throws AssertionError kernelkronmat(LinearKernel(), collect(x), 2)

@testset "lazy kernelmatrix" begin
rng = MersenneTwister(123)

dims = (in=3, out=2, obs=3)
r = 1

A = randn(dims.out, r)
B = A * transpose(A) + Diagonal(rand(dims.out))

# XIF = [(rand(dims.in), rand(1:(dims.out))) for i in 1:(dims.obs)]
x = [rand(dims.in) for _ in 1:2]
XIF = KernelFunctions.MOInputIsotopicByFeatures(x, dims.out)
XIO = KernelFunctions.MOInputIsotopicByOutputs(x, dims.out)
y = [rand(dims.in) for _ in 1:2]
YIF = KernelFunctions.MOInputIsotopicByFeatures(y, dims.out)
YIO = KernelFunctions.MOInputIsotopicByOutputs(y, dims.out)

skernel = GaussianKernel()
kIndMO = IndependentMOKernel(skernel)

A = randn(dims.out, r)
B = A * transpose(A) + Diagonal(rand(dims.out))
icoregionkernel = IntrinsicCoregionMOKernel(skernel, B)

function test_kronecker_kernelmatrix(k, x)
res = kronecker_kernelmatrix(k, x)
@test typeof(res) <: Kronecker.KroneckerProduct
@test res == kernelmatrix(k, x)
end
function test_kronecker_kernelmatrix(k, x, y)
res = kronecker_kernelmatrix(k, x, y)
@test typeof(res) <: Kronecker.KroneckerProduct
@test res == kernelmatrix(k, x, y)
end

for k in [kIndMO, icoregionkernel], x in [XIF, XIO]
test_kronecker_kernelmatrix(k, x)
end
for k in [kIndMO, icoregionkernel], (x, y) in ([XIF, YIF], [XIO, YIO])
test_kronecker_kernelmatrix(k, x, y)
end

struct TestMOKernel <: MOKernel end
@test_throws ArgumentError kronecker_kernelmatrix(TestMOKernel(), XIF)
end
end