diff --git a/Project.toml b/Project.toml index ef8435a0e..9ac9ac813 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.53" +version = "0.10.54" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -35,4 +35,4 @@ SpecialFunctions = "0.8, 0.9, 0.10, 1, 2" StatsBase = "0.32, 0.33" TensorCore = "0.1" ZygoteRules = "0.2" -julia = "1.3" +julia = "1.6" diff --git a/src/basekernels/constant.jl b/src/basekernels/constant.jl index e3f5245fb..cfbd8abc6 100644 --- a/src/basekernels/constant.jl +++ b/src/basekernels/constant.jl @@ -98,8 +98,8 @@ See also: [`ZeroKernel`](@ref) struct ConstantKernel{Tc<:Real} <: SimpleKernel c::Vector{Tc} - function ConstantKernel(; c::Real=1.0) - @check_args(ConstantKernel, c, c >= zero(c), "c ≥ 0") + function ConstantKernel(; c::Real=1.0, check_args::Bool=true) + @check_args(ConstantKernel, (c, c >= zero(c), "c ≥ 0")) return new{typeof(c)}([c]) end end diff --git a/src/basekernels/exponential.jl b/src/basekernels/exponential.jl index 2061d40f9..0fc4b313f 100644 --- a/src/basekernels/exponential.jl +++ b/src/basekernels/exponential.jl @@ -125,14 +125,16 @@ struct GammaExponentialKernel{Tγ<:Real,M} <: SimpleKernel γ::Vector{Tγ} metric::M - function GammaExponentialKernel(γ::Real, metric) - @check_args(GammaExponentialKernel, γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]") + function GammaExponentialKernel(γ::Real, metric; check_args::Bool=true) + @check_args(GammaExponentialKernel, (γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]")) return new{typeof(γ),typeof(metric)}([γ], metric) end end -function GammaExponentialKernel(; gamma::Real=1.0, γ::Real=gamma, metric=Euclidean()) - return GammaExponentialKernel(γ, metric) +function GammaExponentialKernel(; + gamma::Real=1.0, γ::Real=gamma, metric=Euclidean(), check_args::Bool=true +) + return GammaExponentialKernel(γ, metric; check_args) end @functor GammaExponentialKernel diff --git a/src/basekernels/fbm.jl b/src/basekernels/fbm.jl index 422400ac4..b93933a9a 100644 --- a/src/basekernels/fbm.jl +++ b/src/basekernels/fbm.jl @@ -14,13 +14,13 @@ k(x, x'; h) = \\frac{\\|x\\|_2^{2h} + \\|x'\\|_2^{2h} - \\|x - x'\\|^{2h}}{2}. """ struct FBMKernel{T<:Real} <: Kernel h::Vector{T} - function FBMKernel(h::Real) - @check_args(FBMKernel, h, zero(h) ≤ h ≤ one(h), "h ∈ [0, 1]") + function FBMKernel(h::Real; check_args::Bool=true) + @check_args(FBMKernel, (h, zero(h) ≤ h ≤ one(h), "h ∈ [0, 1]")) return new{typeof(h)}([h]) end end -FBMKernel(; h::Real=0.5) = FBMKernel(h) +FBMKernel(; h::Real=0.5, check_args::Bool=true) = FBMKernel(h; check_args) @functor FBMKernel diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index cb051ee0f..e1e7b8486 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -27,13 +27,15 @@ struct MaternKernel{Tν<:Real,M} <: SimpleKernel ν::Vector{Tν} metric::M - function MaternKernel(ν::Real, metric) - @check_args(MaternKernel, ν, ν > zero(ν), "ν > 0") + function MaternKernel(ν::Real, metric; check_args::Bool=true) + @check_args(MaternKernel, (ν, ν > zero(ν), "ν > 0")) return new{typeof(ν),typeof(metric)}([ν], metric) end end -MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν, metric) +function MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean(), check_args::Bool=true) + return MaternKernel(ν, metric; check_args) +end @functor MaternKernel diff --git a/src/basekernels/periodic.jl b/src/basekernels/periodic.jl index 2758d7f94..8729e9606 100644 --- a/src/basekernels/periodic.jl +++ b/src/basekernels/periodic.jl @@ -15,8 +15,10 @@ k(x, x'; r) = \\exp\\bigg(- \\frac{1}{2} \\sum_{i=1}^d \\bigg(\\frac{\\sin\\big( """ struct PeriodicKernel{T} <: SimpleKernel r::Vector{T} - function PeriodicKernel(; r::AbstractVector{<:Real}=ones(Float64, 1)) - @check_args(PeriodicKernel, r, all(ri > zero(ri) for ri in r), "r > 0") + function PeriodicKernel(; + r::AbstractVector{<:Real}=ones(Float64, 1), check_args::Bool=true + ) + @check_args(PeriodicKernel, (r, all(ri > zero(ri) for ri in r), "r > 0")) return new{eltype(r)}(r) end end @@ -28,7 +30,9 @@ PeriodicKernel(dims::Int) = PeriodicKernel(Float64, dims) Create a [`PeriodicKernel`](@ref) with parameter `r=ones(T, dims)`. """ -PeriodicKernel(T::DataType, dims::Int=1) = PeriodicKernel(; r=ones(T, dims)) +function PeriodicKernel(T::DataType, dims::Int=1) + return PeriodicKernel(; r=ones(T, dims), check_args=false) +end @functor PeriodicKernel diff --git a/src/basekernels/polynomial.jl b/src/basekernels/polynomial.jl index 0dbd2d52b..f2385efa5 100644 --- a/src/basekernels/polynomial.jl +++ b/src/basekernels/polynomial.jl @@ -16,13 +16,13 @@ See also: [`PolynomialKernel`](@ref) struct LinearKernel{Tc<:Real} <: SimpleKernel c::Vector{Tc} - function LinearKernel(c::Real) - @check_args(LinearKernel, c, c >= zero(c), "c ≥ 0") + function LinearKernel(c::Real; check_args::Bool=true) + @check_args(LinearKernel, (c, c >= zero(c), "c ≥ 0")) return new{typeof(c)}([c]) end end -LinearKernel(; c::Real=0.0) = LinearKernel(c) +LinearKernel(; c::Real=0.0, check_args::Bool=true) = LinearKernel(c; check_args) @functor LinearKernel @@ -69,15 +69,20 @@ struct PolynomialKernel{Tc<:Real} <: SimpleKernel degree::Int c::Vector{Tc} - function PolynomialKernel{Tc}(degree::Int, c::Vector{Tc}) where {Tc} - @check_args(PolynomialKernel, degree, degree >= one(degree), "degree ≥ 1") - @check_args(PolynomialKernel, c, only(c) >= zero(Tc), "c ≥ 0") + function PolynomialKernel{Tc}( + degree::Int, c::Vector{Tc}; check_args::Bool=true + ) where {Tc} + @check_args( + PolynomialKernel, + (degree, degree >= one(degree), "degree ≥ 1"), + (c, only(c) >= zero(Tc), "c ≥ 0") + ) return new{Tc}(degree, c) end end -function PolynomialKernel(; degree::Int=2, c::Real=0.0) - return PolynomialKernel{typeof(c)}(degree, [c]) +function PolynomialKernel(; degree::Int=2, c::Real=0.0, check_args::Bool=true) + return PolynomialKernel{typeof(c)}(degree, [c]; check_args) end # The degree of the polynomial kernel is a fixed discrete parameter diff --git a/src/basekernels/rational.jl b/src/basekernels/rational.jl index 835fe92ff..4ba0db22d 100644 --- a/src/basekernels/rational.jl +++ b/src/basekernels/rational.jl @@ -19,14 +19,16 @@ struct RationalKernel{Tα<:Real,M} <: SimpleKernel α::Vector{Tα} metric::M - function RationalKernel(α::Real, metric) - @check_args(RationalKernel, α, α > zero(α), "α > 0") + function RationalKernel(α::Real, metric; check_args::Bool=true) + @check_args(RationalKernel, (α, α > zero(α), "α > 0")) return new{typeof(α),typeof(metric)}([α], metric) end end -function RationalKernel(; alpha::Real=2.0, α::Real=alpha, metric=Euclidean()) - return RationalKernel(α, metric) +function RationalKernel(; + alpha::Real=2.0, α::Real=alpha, metric=Euclidean(), check_args::Bool=true +) + return RationalKernel(α, metric; check_args) end @functor RationalKernel @@ -83,8 +85,10 @@ struct RationalQuadraticKernel{Tα<:Real,M} <: SimpleKernel α::Vector{Tα} metric::M - function RationalQuadraticKernel(; alpha::Real=2.0, α::Real=alpha, metric=Euclidean()) - @check_args(RationalQuadraticKernel, α, α > zero(α), "α > 0") + function RationalQuadraticKernel(; + alpha::Real=2.0, α::Real=alpha, metric=Euclidean(), check_args::Bool=true + ) + @check_args(RationalQuadraticKernel, (α, α > zero(α), "α > 0")) return new{typeof(α),typeof(metric)}([α], metric) end end @@ -172,10 +176,18 @@ struct GammaRationalKernel{Tα<:Real,Tγ<:Real,M} <: SimpleKernel metric::M function GammaRationalKernel(; - alpha::Real=2.0, gamma::Real=1.0, α::Real=alpha, γ::Real=gamma, metric=Euclidean() + alpha::Real=2.0, + gamma::Real=1.0, + α::Real=alpha, + γ::Real=gamma, + metric=Euclidean(), + check_args::Bool=true, ) - @check_args(GammaRationalKernel, α, α > zero(α), "α > 0") - @check_args(GammaRationalKernel, γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]") + @check_args( + GammaRationalKernel, + (α, α > zero(α), "α > 0"), + (γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]") + ) return new{typeof(α),typeof(γ),typeof(metric)}([α], [γ], metric) end end diff --git a/src/basekernels/wiener.jl b/src/basekernels/wiener.jl index 14d330850..980b243ab 100644 --- a/src/basekernels/wiener.jl +++ b/src/basekernels/wiener.jl @@ -39,8 +39,8 @@ The [`WhiteKernel`](@ref) is recovered for ``i = -1``. [^SDH]: Schober, Duvenaud & Hennig (2014). Probabilistic ODE Solvers with Runge-Kutta Means. """ struct WienerKernel{I} <: Kernel - function WienerKernel{I}() where {I} - @check_args(WienerKernel, I, I ∈ (-1, 0, 1, 2, 3), "I ∈ {-1, 0, 1, 2, 3}") + function WienerKernel{I}(; check_args::Bool=true) where {I} + @check_args(WienerKernel, (I, I ∈ (-1, 0, 1, 2, 3), "I ∈ {-1, 0, 1, 2, 3}")) if I == -1 return WhiteKernel() end diff --git a/src/kernels/scaledkernel.jl b/src/kernels/scaledkernel.jl index 0a17943ac..6b20c6588 100644 --- a/src/kernels/scaledkernel.jl +++ b/src/kernels/scaledkernel.jl @@ -16,8 +16,10 @@ struct ScaledKernel{Tk<:Kernel,Tσ²<:Real} <: Kernel σ²::Vector{Tσ²} end -function ScaledKernel(kernel::Tk, σ²::Tσ²=1.0) where {Tk<:Kernel,Tσ²<:Real} - @check_args(ScaledKernel, σ², σ² > zero(Tσ²), "σ² > 0") +function ScaledKernel( + kernel::Tk, σ²::Tσ²=1.0; check_args::Bool=true +) where {Tk<:Kernel,Tσ²<:Real} + @check_args(ScaledKernel, (σ², σ² > zero(Tσ²), "σ² > 0")) return ScaledKernel{Tk,Tσ²}(kernel, [σ²]) end diff --git a/src/mokernels/independent.jl b/src/mokernels/independent.jl index 1f7811b14..554fdb08f 100644 --- a/src/mokernels/independent.jl +++ b/src/mokernels/independent.jl @@ -39,16 +39,14 @@ function kernelmatrix( return _kernelmatrix_kron_helper(MOI, Kfeatures, Koutputs) end -if VERSION >= v"1.6" - function kernelmatrix!( - K::AbstractMatrix, k::IndependentMOKernel, x::MOI, y::MOI - ) where {MOI<:IsotopicMOInputsUnion} - x.out_dim == y.out_dim || - throw(DimensionMismatch("`x` and `y` must have the same `out_dim`")) - Kfeatures = kernelmatrix(k.kernel, x.x, y.x) - Koutputs = _mo_output_covariance(k, x.out_dim) - return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs) - end +function kernelmatrix!( + K::AbstractMatrix, k::IndependentMOKernel, x::MOI, y::MOI +) where {MOI<:IsotopicMOInputsUnion} + x.out_dim == y.out_dim || + throw(DimensionMismatch("`x` and `y` must have the same `out_dim`")) + Kfeatures = kernelmatrix(k.kernel, x.x, y.x) + Koutputs = _mo_output_covariance(k, x.out_dim) + return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs) end function Base.show(io::IO, k::IndependentMOKernel) diff --git a/src/mokernels/intrinsiccoregion.jl b/src/mokernels/intrinsiccoregion.jl index 0a940796b..cfa1118f0 100644 --- a/src/mokernels/intrinsiccoregion.jl +++ b/src/mokernels/intrinsiccoregion.jl @@ -19,23 +19,25 @@ struct IntrinsicCoregionMOKernel{K<:Kernel,T<:AbstractMatrix} <: MOKernel kernel::K B::T - function IntrinsicCoregionMOKernel{K,T}(kernel::K, B::T) where {K,T} + function IntrinsicCoregionMOKernel{K,T}( + kernel::K, B::T; check_args::Bool=true + ) where {K,T} @check_args( IntrinsicCoregionMOKernel, - B, - eigmin(B) >= 0, - "B has to be positive semi-definite" + (B, eigmin(B) >= 0, "B has to be positive semi-definite") ) return new{K,T}(kernel, B) end end -function IntrinsicCoregionMOKernel(; kernel::Kernel, B::AbstractMatrix) - return IntrinsicCoregionMOKernel{typeof(kernel),typeof(B)}(kernel, B) +function IntrinsicCoregionMOKernel(; + kernel::Kernel, B::AbstractMatrix, check_args::Bool=true +) + return IntrinsicCoregionMOKernel{typeof(kernel),typeof(B)}(kernel, B; check_args) end -function IntrinsicCoregionMOKernel(kernel::Kernel, B::AbstractMatrix) - return IntrinsicCoregionMOKernel{typeof(kernel),typeof(B)}(kernel, B) +function IntrinsicCoregionMOKernel(kernel::Kernel, B::AbstractMatrix; check_args::Bool=true) + return IntrinsicCoregionMOKernel{typeof(kernel),typeof(B)}(kernel, B; check_args) end function (k::IntrinsicCoregionMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{Any,Int}) @@ -57,16 +59,14 @@ function kernelmatrix( return _kernelmatrix_kron_helper(MOI, Kfeatures, Koutputs) end -if VERSION >= v"1.6" - function kernelmatrix!( - K::AbstractMatrix, k::IntrinsicCoregionMOKernel, x::MOI, y::MOI - ) where {MOI<:IsotopicMOInputsUnion} - x.out_dim == y.out_dim || - throw(DimensionMismatch("`x` and `y` must have the same `out_dim`")) - Kfeatures = kernelmatrix(k.kernel, x.x, y.x) - Koutputs = _mo_output_covariance(k, x.out_dim) - return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs) - end +function kernelmatrix!( + K::AbstractMatrix, k::IntrinsicCoregionMOKernel, x::MOI, y::MOI +) where {MOI<:IsotopicMOInputsUnion} + x.out_dim == y.out_dim || + throw(DimensionMismatch("`x` and `y` must have the same `out_dim`")) + Kfeatures = kernelmatrix(k.kernel, x.x, y.x) + Koutputs = _mo_output_covariance(k, x.out_dim) + return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs) end function Base.show(io::IO, k::IntrinsicCoregionMOKernel) diff --git a/src/mokernels/mokernel.jl b/src/mokernels/mokernel.jl index 30b6c2d9b..474e193ea 100644 --- a/src/mokernels/mokernel.jl +++ b/src/mokernels/mokernel.jl @@ -13,16 +13,14 @@ function _kernelmatrix_kron_helper(::Type{<:MOInputIsotopicByOutputs}, Kfeatures return kron(Koutputs, Kfeatures) end -if VERSION >= v"1.6" - function _kernelmatrix_kron_helper!( - K, ::Type{<:MOInputIsotopicByFeatures}, Kfeatures, Koutputs - ) - return kron!(K, Kfeatures, Koutputs) - end +function _kernelmatrix_kron_helper!( + K, ::Type{<:MOInputIsotopicByFeatures}, Kfeatures, Koutputs +) + return kron!(K, Kfeatures, Koutputs) +end - function _kernelmatrix_kron_helper!( - K, ::Type{<:MOInputIsotopicByOutputs}, Kfeatures, Koutputs - ) - return kron!(K, Koutputs, Kfeatures) - end +function _kernelmatrix_kron_helper!( + K, ::Type{<:MOInputIsotopicByOutputs}, Kfeatures, Koutputs +) + return kron!(K, Koutputs, Kfeatures) end diff --git a/src/utils.jl b/src/utils.jl index bb805b15f..7d01f668f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,24 +1,94 @@ -# Macro for checking arguments -macro check_args(K, param, cond, desc=string(cond)) - quote - if !($(esc(cond))) - throw( - ArgumentError( - string( - $(string(K)), - ": ", - $(string(param)), - " = ", - $(esc(param)), - " does not ", - "satisfy the constraint ", - $(string(desc)), - ".", - ), - ), - ) +# strip_linenos and check_args both come from Distributions.jl, +# see https://github.com/JuliaStats/Distributions.jl/blob/master/src/utils.jl + +macro strip_linenos(expr) + return esc(Base.remove_linenums!(expr)) +end + +""" + @check_args( + K, + @setup(statements...), + (arg₁, cond₁, message₁), + (cond₂, message₂), + ..., + ) + +A convenience macro that generates checks of arguments for a kernel of type `K`. + +The macro expects that a boolean variable of name `check_args` is defined and generates +the following Julia code: +```julia +KernelFunctions.check_args(check_args) do + \$(statements...) + cond₁ || throw(DomainError(arg₁, \$(string(D, ": ", message₁)))) + cond₂ || throw(ArgumentError(\$(string(D, ": ", message₂)))) + ... +end +``` + +The `@setup` argument can be elided if no setup code is needed. Moreover, error messages +can be omitted. In this case the message `"the condition \$(cond) is not satisfied."` is +used. +""" +macro check_args(D, setup_or_check, checks...) + # Extract setup statements + if Meta.isexpr(setup_or_check, :macrocall) && setup_or_check.args[1] == Symbol("@setup") + setup_stmts = Any[esc(ex) for ex in setup_or_check.args[3:end]] + else + setup_stmts = [] + checks = (setup_or_check, checks...) + end + + # Generate expressions for each condition + conds_exprs = map(checks) do check + if Meta.isexpr(check, :tuple, 3) + # argument, condition, and message specified + arg = check.args[1] + cond = check.args[2] + message = string(D, ": ", check.args[3]) + return :(($(esc(cond))) || throw(DomainError($(esc(arg)), $message))) + elseif Meta.isexpr(check, :tuple, 2) + cond_or_message = check.args[2] + if cond_or_message isa String + # only condition and message specified + cond = check.args[1] + message = string(D, ": ", cond_or_message) + return :(($(esc(cond))) || throw(ArgumentError($message))) + else + # only argument and condition specified + arg = check.args[1] + cond = cond_or_message + message = string(D, ": the condition ", cond, " is not satisfied.") + return :(($(esc(cond))) || throw(DomainError($(esc(arg)), $message))) + end + else + # only condition specified + cond = check + message = string(D, ": the condition ", cond, " is not satisfied.") + return :(($(esc(cond))) || throw(ArgumentError($message))) end end + + return @strip_linenos quote + KernelFunctions.check_args($(esc(:check_args))) do + $(__source__) + $(setup_stmts...) + $(conds_exprs...) + end + end +end + +""" + check_args(f, check::Bool) + +Perform check of arguments by calling a function `f`. + +If `check` is `false`, the checks are skipped. +""" +function check_args(f::F, check::Bool) where {F} + check && f() + return nothing end function deprecated_obsdim(obsdim::Union{Int,Nothing}) diff --git a/test/basekernels/polynomial.jl b/test/basekernels/polynomial.jl index 80a3b1bb8..4c9104288 100644 --- a/test/basekernels/polynomial.jl +++ b/test/basekernels/polynomial.jl @@ -14,7 +14,7 @@ @test repr(k) == "Linear Kernel (c = 0.0)" # Errors. - @test_throws ArgumentError LinearKernel(; c=-0.5) + @test_throws DomainError LinearKernel(; c=-0.5) # Standardised tests. TestUtils.test_interface(k, Float64) @@ -36,8 +36,8 @@ @test metric(PolynomialKernel(; degree=3, c=c)) == KernelFunctions.DotProduct() # Errors. - @test_throws ArgumentError PolynomialKernel(; degree=0) - @test_throws ArgumentError PolynomialKernel(; c=-0.5) + @test_throws DomainError PolynomialKernel(; degree=0) + @test_throws DomainError PolynomialKernel(; c=-0.5) # Standardised tests. TestUtils.test_interface(k, Float64) diff --git a/test/basekernels/wiener.jl b/test/basekernels/wiener.jl index 9dd60ba43..6d8c16c9a 100644 --- a/test/basekernels/wiener.jl +++ b/test/basekernels/wiener.jl @@ -14,8 +14,8 @@ k3 = WienerKernel(; i=3) @test k3 isa WienerKernel{3} - @test_throws ArgumentError WienerKernel(; i=4) - @test_throws ArgumentError WienerKernel(; i=-2) + @test_throws DomainError WienerKernel(; i=4) + @test_throws DomainError WienerKernel(; i=-2) v1 = rand(4) v2 = rand(4) diff --git a/test/chainrules.jl b/test/chainrules.jl index 03c2c3b1f..4edd532d0 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -21,11 +21,7 @@ compare_gradient(:Zygote, [x, y]) do xy KernelFunctions.Sinus(r)(xy[1], xy[2]) end - if VERSION < v"1.6" - @test_broken "Chain rule of SqMahalanobis is broken in Julia pre-1.6" - else - compare_gradient(:Zygote, [Q, x, y]) do Qxy - SqMahalanobis(Qxy[1])(Qxy[2], Qxy[3]) - end + compare_gradient(:Zygote, [Q, x, y]) do Qxy + SqMahalanobis(Qxy[1])(Qxy[2], Qxy[3]) end end