# Patch summary: KernelFunctions 0.10.17 -> 0.10.18 (lazy Kronecker kernel matrices).
#
# Project.toml
#   * Bumps `version` from "0.10.17" to "0.10.18".
# src/KernelFunctions.jl
#   * Includes "mokernels/moinput.jl" before "mokernels/mokernel.jl" (presumably because
#     mokernel.jl now defines methods that dispatch on `MOInputIsotopicByFeatures` and
#     `MOInputIsotopicByOutputs`).
# src/chainrules.jl
#   * Changes the argument annotation of the `Delta` / `DotProduct` pullback closures
#     from `::AbstractMatrix` / `::AbstractVector` to `::Any`.
# src/matrix/kernelkroneckermat.jl
#   * Exports and adds `kronecker_kernelmatrix`: for `IndependentMOKernel` and
#     `IntrinsicCoregionMOKernel` it passes the feature and output covariances to
#     `Kronecker.kronecker` (argument order chosen by the type of `x`); for any other
#     `MOKernel` the two-argument method throws `ArgumentError`, and the one-argument
#     method forwards to it with `(k, x, x)`.
# src/mokernels/independent.jl, src/mokernels/intrinsiccoregion.jl
#   * Add `_mo_output_covariance` (`Eye{Bool}(out_dim)` for the independent kernel;
#     `k.B` after asserting `size(k.B) == (out_dim, out_dim)` for the coregion kernel)
#     and route `kernelmatrix` / `kernelmatrix!` through it.
#   * Replace the `x::MOI, y::MOI ... where {MOI<:IsotopicMOInputsUnion}` signatures
#     with separately annotated `x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion`.
# src/mokernels/mokernel.jl
#   * Now hosts `_kernelmatrix_kron_helper` and, under `VERSION >= v"1.6"`,
#     `_kernelmatrix_kron_helper!` (removed from independent.jl).
# test/matrix/kernelkroneckermat.jl
#   * Adds a "lazy kernelmatrix" testset comparing `kronecker_kernelmatrix` against
#     `kernelmatrix` and checking the `ArgumentError` fallback.
#
# NOTE(review): the new two-argument signatures accept `x` and `y` of different isotopic
#   layouts, while `_kernelmatrix_kron_helper`, `_kernelmatrix_kron_helper!` and
#   `_kernelmatrix_kroneckerjl_helper` pick the Kronecker ordering from `x` only.
#   The removed `MOI` type parameter forced both inputs to share one type -- confirm
#   that mixed layouts are rejected or handled elsewhere.
# NOTE(review): in the added testset `rng` is assigned but none of the visible
#   `rand` / `randn` calls receives it, and `A` / `B` are computed twice with the
#   first pair overwritten before use -- confirm whether this is intended.
# NOTE(review): the added tests compare with `==`, whereas the context-line tests in
#   the same file use `.≈` -- confirm exact floating-point equality is intended.
# NOTE(review): whole hunks sit on single physical lines here; presumably the original
#   line breaks must be restored before this patch can be applied -- TODO confirm.
diff --git a/Project.toml b/Project.toml index 3c0045cb4..5ded7e0da 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.17" +version = "0.10.18" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 37241abc0..c0290b390 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -105,8 +105,8 @@ include("kernels/neuralkernelnetwork.jl") include("approximations/nystrom.jl") include("generic.jl") -include("mokernels/mokernel.jl") include("mokernels/moinput.jl") +include("mokernels/mokernel.jl") include("mokernels/independent.jl") include("mokernels/slfm.jl") include("mokernels/intrinsiccoregion.jl") diff --git a/src/chainrules.jl b/src/chainrules.jl index 55de7ddf9..4b69a827f 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -26,7 +26,7 @@ function ChainRulesCore.rrule( ::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2 ) P = Distances.pairwise(d, X, Y; dims=dims) - function pairwise_pullback(::AbstractMatrix) + function pairwise_pullback(::Any) return NoTangent(), NoTangent(), ZeroTangent(), ZeroTangent() end return P, pairwise_pullback @@ -36,7 +36,7 @@ function ChainRulesCore.rrule( ::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2 ) P = Distances.pairwise(d, X; dims=dims) - function pairwise_pullback(::AbstractMatrix) + function pairwise_pullback(::Any) return NoTangent(), NoTangent(), ZeroTangent() end return P, pairwise_pullback @@ -46,7 +46,7 @@ function ChainRulesCore.rrule( ::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix ) C = Distances.colwise(d, X, Y) - function colwise_pullback(::AbstractVector) + function colwise_pullback(::Any) return NoTangent(), NoTangent(), ZeroTangent(), ZeroTangent() end return C, colwise_pullback @@ -70,7 +70,7 @@ function ChainRulesCore.rrule( dims=2, ) P = 
Distances.pairwise(d, X, Y; dims=dims) - function pairwise_pullback_cols(Δ::AbstractMatrix) + function pairwise_pullback_cols(Δ::Any) if dims == 1 return NoTangent(), NoTangent(), Δ * Y, Δ' * X else @@ -84,7 +84,7 @@ function ChainRulesCore.rrule( ::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2 ) P = Distances.pairwise(d, X; dims=dims) - function pairwise_pullback_cols(Δ::AbstractMatrix) + function pairwise_pullback_cols(Δ::Any) if dims == 1 return NoTangent(), NoTangent(), 2 * Δ * X else @@ -98,7 +98,7 @@ function ChainRulesCore.rrule( ::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix ) C = Distances.colwise(d, X, Y) - function colwise_pullback(Δ::AbstractVector) + function colwise_pullback(Δ::Any) return NoTangent(), NoTangent(), Δ' .* Y, Δ' .* X end return C, colwise_pullback diff --git a/src/matrix/kernelkroneckermat.jl b/src/matrix/kernelkroneckermat.jl index 26f7d65a9..1a25c7618 100644 --- a/src/matrix/kernelkroneckermat.jl +++ b/src/matrix/kernelkroneckermat.jl @@ -4,6 +4,7 @@ using .Kronecker: Kronecker export kernelkronmat +export kronecker_kernelmatrix function kernelkronmat(κ::Kernel, X::AbstractVector, dims::Int) @assert iskroncompatible(κ) "The chosen kernel is not compatible for kroenecker matrices (see [`iskroncompatible`](@ref))" @@ -25,3 +26,42 @@ end k(x,x') = ∏ᵢᴰ k(xᵢ,x'ᵢ) """ @inline iskroncompatible(κ::Kernel) = false # Default return for kernels + +function _kernelmatrix_kroneckerjl_helper(::MOInputIsotopicByFeatures, Kfeatures, Koutputs) + return Kronecker.kronecker(Kfeatures, Koutputs) +end + +function _kernelmatrix_kroneckerjl_helper(::MOInputIsotopicByOutputs, Kfeatures, Koutputs) + return Kronecker.kronecker(Koutputs, Kfeatures) +end + +function kronecker_kernelmatrix( + k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel}, + x::IsotopicMOInputsUnion, + y::IsotopicMOInputsUnion, +) + @assert x.out_dim == y.out_dim + Kfeatures = kernelmatrix(k.kernel, x.x, y.x) + Koutputs = 
_mo_output_covariance(k, x.out_dim) + return _kernelmatrix_kroneckerjl_helper(x, Kfeatures, Koutputs) +end + +function kronecker_kernelmatrix( + k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel}, x::IsotopicMOInputsUnion +) + Kfeatures = kernelmatrix(k.kernel, x.x) + Koutputs = _mo_output_covariance(k, x.out_dim) + return _kernelmatrix_kroneckerjl_helper(x, Kfeatures, Koutputs) +end + +function kronecker_kernelmatrix( + k::MOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion +) + return throw( + ArgumentError("This kernel does not support a lazy kronecker kernelmatrix.") + ) +end + +function kronecker_kernelmatrix(k::MOKernel, x::IsotopicMOInputsUnion) + return kronecker_kernelmatrix(k, x, x) +end diff --git a/src/mokernels/independent.jl b/src/mokernels/independent.jl index d7dde0cba..ec830d395 100644 --- a/src/mokernels/independent.jl +++ b/src/mokernels/independent.jl @@ -27,41 +27,28 @@ function (κ::IndependentMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{Any,I return κ.kernel(x, y) * (px == py) end -function _kernelmatrix_kron_helper(::MOInputIsotopicByFeatures, Kfeatures, B) - return kron(Kfeatures, B) -end - -function _kernelmatrix_kron_helper(::MOInputIsotopicByOutputs, Kfeatures, B) - return kron(B, Kfeatures) -end +_mo_output_covariance(k::IndependentMOKernel, out_dim) = Eye{Bool}(out_dim) function kernelmatrix( - k::IndependentMOKernel, x::MOI, y::MOI -) where {MOI<:IsotopicMOInputsUnion} + k::IndependentMOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion +) @assert x.out_dim == y.out_dim Kfeatures = kernelmatrix(k.kernel, x.x, y.x) - mtype = eltype(Kfeatures) - return _kernelmatrix_kron_helper(x, Kfeatures, Eye{mtype}(x.out_dim)) + Koutputs = _mo_output_covariance(k, x.out_dim) + return _kernelmatrix_kron_helper(x, Kfeatures, Koutputs) end if VERSION >= v"1.6" - function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByFeatures, Kfeatures, B) - return kron!(K, Kfeatures, B) - end - - function _kernelmatrix_kron_helper!(K, 
::MOInputIsotopicByOutputs, Kfeatures, B) - return kron!(K, B, Kfeatures) - end - function kernelmatrix!( - K::AbstractMatrix, k::IndependentMOKernel, x::MOI, y::MOI - ) where {MOI<:IsotopicMOInputsUnion} + K::AbstractMatrix, + k::IndependentMOKernel, + x::IsotopicMOInputsUnion, + y::IsotopicMOInputsUnion, + ) @assert x.out_dim == y.out_dim - Ktmp = kernelmatrix(k.kernel, x.x, y.x) - mtype = eltype(Ktmp) - return _kernelmatrix_kron_helper!( - K, x, Ktmp, Matrix{mtype}(I, x.out_dim, x.out_dim) - ) + Kfeatures = kernelmatrix(k.kernel, x.x, y.x) + Koutputs = _mo_output_covariance(k, x.out_dim) + return _kernelmatrix_kron_helper!(K, x, Kfeatures, Koutputs) end end diff --git a/src/mokernels/intrinsiccoregion.jl b/src/mokernels/intrinsiccoregion.jl index db1050fe4..7e5b884cd 100644 --- a/src/mokernels/intrinsiccoregion.jl +++ b/src/mokernels/intrinsiccoregion.jl @@ -42,21 +42,31 @@ function (k::IntrinsicCoregionMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{ return k.B[px, py] * k.kernel(x, y) end +function _mo_output_covariance(k::IntrinsicCoregionMOKernel, out_dim) + @assert size(k.B) == (out_dim, out_dim) + return k.B +end + function kernelmatrix( - k::IntrinsicCoregionMOKernel, x::MOI, y::MOI -) where {MOI<:IsotopicMOInputsUnion} + k::IntrinsicCoregionMOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion +) @assert x.out_dim == y.out_dim Kfeatures = kernelmatrix(k.kernel, x.x, y.x) - return _kernelmatrix_kron_helper(x, Kfeatures, k.B) + Koutputs = _mo_output_covariance(k, x.out_dim) + return _kernelmatrix_kron_helper(x, Kfeatures, Koutputs) end if VERSION >= v"1.6" function kernelmatrix!( - K::AbstractMatrix, k::IntrinsicCoregionMOKernel, x::MOI, y::MOI - ) where {MOI<:IsotopicMOInputsUnion} + K::AbstractMatrix, + k::IntrinsicCoregionMOKernel, + x::IsotopicMOInputsUnion, + y::IsotopicMOInputsUnion, + ) @assert x.out_dim == y.out_dim Kfeatures = kernelmatrix(k.kernel, x.x, y.x) - return _kernelmatrix_kron_helper!(K, x, Kfeatures, k.B) + Koutputs = 
_mo_output_covariance(k, x.out_dim) + return _kernelmatrix_kron_helper!(K, x, Kfeatures, Koutputs) end end diff --git a/src/mokernels/mokernel.jl b/src/mokernels/mokernel.jl index 2ebb07aad..acf0b78d5 100644 --- a/src/mokernels/mokernel.jl +++ b/src/mokernels/mokernel.jl @@ -4,3 +4,21 @@ Abstract type for kernels with multiple outpus. """ abstract type MOKernel <: Kernel end + +function _kernelmatrix_kron_helper(::MOInputIsotopicByFeatures, Kfeatures, Koutputs) + return kron(Kfeatures, Koutputs) +end + +function _kernelmatrix_kron_helper(::MOInputIsotopicByOutputs, Kfeatures, Koutputs) + return kron(Koutputs, Kfeatures) +end + +if VERSION >= v"1.6" + function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByFeatures, Kfeatures, Koutputs) + return kron!(K, Kfeatures, Koutputs) + end + + function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByOutputs, Kfeatures, Koutputs) + return kron!(K, Koutputs, Kfeatures) + end +end diff --git a/test/matrix/kernelkroneckermat.jl b/test/matrix/kernelkroneckermat.jl index 8cb12ac80..9207e10cd 100644 --- a/test/matrix/kernelkroneckermat.jl +++ b/test/matrix/kernelkroneckermat.jl @@ -7,4 +7,50 @@ @test all(collect(kernelkronmat(k, collect(x), 2)) .≈ kernelmatrix(k, X; obsdim=1)) @test all(collect(kernelkronmat(k, [x, x])) .≈ kernelmatrix(k, X; obsdim=1)) @test_throws AssertionError kernelkronmat(LinearKernel(), collect(x), 2) + + @testset "lazy kernelmatrix" begin + rng = MersenneTwister(123) + + dims = (in=3, out=2, obs=3) + r = 1 + + A = randn(dims.out, r) + B = A * transpose(A) + Diagonal(rand(dims.out)) + + # XIF = [(rand(dims.in), rand(1:(dims.out))) for i in 1:(dims.obs)] + x = [rand(dims.in) for _ in 1:2] + XIF = KernelFunctions.MOInputIsotopicByFeatures(x, dims.out) + XIO = KernelFunctions.MOInputIsotopicByOutputs(x, dims.out) + y = [rand(dims.in) for _ in 1:2] + YIF = KernelFunctions.MOInputIsotopicByFeatures(y, dims.out) + YIO = KernelFunctions.MOInputIsotopicByOutputs(y, dims.out) + + skernel = GaussianKernel() + kIndMO 
= IndependentMOKernel(skernel) + + A = randn(dims.out, r) + B = A * transpose(A) + Diagonal(rand(dims.out)) + icoregionkernel = IntrinsicCoregionMOKernel(skernel, B) + + function test_kronecker_kernelmatrix(k, x) + res = kronecker_kernelmatrix(k, x) + @test typeof(res) <: Kronecker.KroneckerProduct + @test res == kernelmatrix(k, x) + end + function test_kronecker_kernelmatrix(k, x, y) + res = kronecker_kernelmatrix(k, x, y) + @test typeof(res) <: Kronecker.KroneckerProduct + @test res == kernelmatrix(k, x, y) + end + + for k in [kIndMO, icoregionkernel], x in [XIF, XIO] + test_kronecker_kernelmatrix(k, x) + end + for k in [kIndMO, icoregionkernel], (x, y) in ([XIF, YIF], [XIO, YIO]) + test_kronecker_kernelmatrix(k, x, y) + end + + struct TestMOKernel <: MOKernel end + @test_throws ArgumentError kronecker_kernelmatrix(TestMOKernel(), XIF) + end end