From c72f567b37e69303f81ed9e8e3cd3041d9129c18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 18 Apr 2023 11:41:38 +0200 Subject: [PATCH 1/9] Add check_args and drop 1.3 --- Project.toml | 4 ++-- src/basekernels/constant.jl | 4 ++-- src/basekernels/exponential.jl | 10 +++++---- src/basekernels/fbm.jl | 6 +++--- src/basekernels/matern.jl | 8 ++++--- src/basekernels/periodic.jl | 11 +++++++--- src/basekernels/polynomial.jl | 19 ++++++++++------- src/basekernels/rational.jl | 27 ++++++++++++++++-------- src/basekernels/wiener.jl | 5 +++-- src/kernels/scaledkernel.jl | 6 ++++-- src/mokernels/independent.jl | 18 +++++++--------- src/mokernels/intrinsiccoregion.jl | 34 ++++++++++++++++-------------- src/mokernels/mokernel.jl | 20 ++++++++---------- test/chainrules.jl | 8 ++----- 14 files changed, 99 insertions(+), 81 deletions(-) diff --git a/Project.toml b/Project.toml index ef8435a0e..9ac9ac813 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.53" +version = "0.10.54" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -35,4 +35,4 @@ SpecialFunctions = "0.8, 0.9, 0.10, 1, 2" StatsBase = "0.32, 0.33" TensorCore = "0.1" ZygoteRules = "0.2" -julia = "1.3" +julia = "1.6" diff --git a/src/basekernels/constant.jl b/src/basekernels/constant.jl index e3f5245fb..977c07a94 100644 --- a/src/basekernels/constant.jl +++ b/src/basekernels/constant.jl @@ -98,8 +98,8 @@ See also: [`ZeroKernel`](@ref) struct ConstantKernel{Tc<:Real} <: SimpleKernel c::Vector{Tc} - function ConstantKernel(; c::Real=1.0) - @check_args(ConstantKernel, c, c >= zero(c), "c ≥ 0") + function ConstantKernel(; c::Real=1.0, check_args::Bool=true) + check_args && @check_args(ConstantKernel, c, c >= zero(c), "c ≥ 0") return new{typeof(c)}([c]) end end diff --git a/src/basekernels/exponential.jl b/src/basekernels/exponential.jl index 2061d40f9..0173861bb 100644 --- a/src/basekernels/exponential.jl +++ b/src/basekernels/exponential.jl @@ -125,14 +125,16 @@ struct GammaExponentialKernel{Tγ<:Real,M} <: SimpleKernel γ::Vector{Tγ} metric::M - function GammaExponentialKernel(γ::Real, metric) - @check_args(GammaExponentialKernel, γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]") + function GammaExponentialKernel(γ::Real, metric; check_args::Bool=true) + check_args && @check_args(GammaExponentialKernel, γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]") return new{typeof(γ),typeof(metric)}([γ], metric) end end -function GammaExponentialKernel(; gamma::Real=1.0, γ::Real=gamma, metric=Euclidean()) - return GammaExponentialKernel(γ, metric) +function GammaExponentialKernel(; + gamma::Real=1.0, γ::Real=gamma, metric=Euclidean(), check_args::Bool=true +) + return GammaExponentialKernel(γ, metric; check_args) end @functor GammaExponentialKernel diff --git a/src/basekernels/fbm.jl b/src/basekernels/fbm.jl index 422400ac4..ef7dfe093 100644 --- a/src/basekernels/fbm.jl +++ b/src/basekernels/fbm.jl @@ -14,13 +14,13 @@ k(x, x'; h) = \\frac{\\|x\\|_2^{2h} + \\|x'\\|_2^{2h} - \\|x - x'\\|^{2h}}{2}. """ struct FBMKernel{T<:Real} <: Kernel h::Vector{T} - function FBMKernel(h::Real) - @check_args(FBMKernel, h, zero(h) ≤ h ≤ one(h), "h ∈ [0, 1]") + function FBMKernel(h::Real; check_args::Bool=true) + check_args && @check_args(FBMKernel, h, zero(h) ≤ h ≤ one(h), "h ∈ [0, 1]") return new{typeof(h)}([h]) end end -FBMKernel(; h::Real=0.5) = FBMKernel(h) +FBMKernel(; h::Real=0.5, check_args::Bool=true) = FBMKernel(h; check_args) @functor FBMKernel diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index cb051ee0f..7d74b9853 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -27,13 +27,15 @@ struct MaternKernel{Tν<:Real,M} <: SimpleKernel ν::Vector{Tν} metric::M - function MaternKernel(ν::Real, metric) - @check_args(MaternKernel, ν, ν > zero(ν), "ν > 0") + function MaternKernel(ν::Real, metric; check_args::Bool=true) + check_args && @check_args(MaternKernel, ν, ν > zero(ν), "ν > 0") return new{typeof(ν),typeof(metric)}([ν], metric) end end -MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν, metric) +function MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean(), check_args::Bool=true) + return MaternKernel(ν, metric; check_args) +end @functor MaternKernel diff --git a/src/basekernels/periodic.jl b/src/basekernels/periodic.jl index 2758d7f94..7b7dd5160 100644 --- a/src/basekernels/periodic.jl +++ b/src/basekernels/periodic.jl @@ -15,8 +15,11 @@ k(x, x'; r) = \\exp\\bigg(- \\frac{1}{2} \\sum_{i=1}^d \\bigg(\\frac{\\sin\\big( """ struct PeriodicKernel{T} <: SimpleKernel r::Vector{T} - function PeriodicKernel(; r::AbstractVector{<:Real}=ones(Float64, 1)) - @check_args(PeriodicKernel, r, all(ri > zero(ri) for ri in r), "r > 0") + function PeriodicKernel(; + r::AbstractVector{<:Real}=ones(Float64, 1), check_args::Bool=true + ) + check_args && + @check_args(PeriodicKernel, r, all(ri > zero(ri) for ri in r), "r > 0") return new{eltype(r)}(r) end end @@ -28,7 +31,9 @@ PeriodicKernel(dims::Int) = PeriodicKernel(Float64, dims) Create a [`PeriodicKernel`](@ref) with parameter `r=ones(T, dims)`. """ -PeriodicKernel(T::DataType, dims::Int=1) = PeriodicKernel(; r=ones(T, dims)) +function PeriodicKernel(T::DataType, dims::Int=1) + return PeriodicKernel(; r=ones(T, dims), check_args=false) +end @functor PeriodicKernel diff --git a/src/basekernels/polynomial.jl b/src/basekernels/polynomial.jl index 0dbd2d52b..938528638 100644 --- a/src/basekernels/polynomial.jl +++ b/src/basekernels/polynomial.jl @@ -16,13 +16,13 @@ See also: [`PolynomialKernel`](@ref) struct LinearKernel{Tc<:Real} <: SimpleKernel c::Vector{Tc} - function LinearKernel(c::Real) - @check_args(LinearKernel, c, c >= zero(c), "c ≥ 0") + function LinearKernel(c::Real; check_args::Bool=true) + check_args && @check_args(LinearKernel, c, c >= zero(c), "c ≥ 0") return new{typeof(c)}([c]) end end -LinearKernel(; c::Real=0.0) = LinearKernel(c) +LinearKernel(; c::Real=0.0, check_args::Bool=true) = LinearKernel(c; check_args) @functor LinearKernel @@ -69,15 +69,18 @@ struct PolynomialKernel{Tc<:Real} <: SimpleKernel degree::Int c::Vector{Tc} - function PolynomialKernel{Tc}(degree::Int, c::Vector{Tc}) where {Tc} - @check_args(PolynomialKernel, degree, degree >= one(degree), "degree ≥ 1") - @check_args(PolynomialKernel, c, only(c) >= zero(Tc), "c ≥ 0") + function PolynomialKernel{Tc}( + degree::Int, c::Vector{Tc}; check_args::Bool=true + ) where {Tc} + check_args && + @check_args(PolynomialKernel, degree, degree >= one(degree), "degree ≥ 1") + check_args && @check_args(PolynomialKernel, c, only(c) >= zero(Tc), "c ≥ 0") return new{Tc}(degree, c) end end -function PolynomialKernel(; degree::Int=2, c::Real=0.0) - return PolynomialKernel{typeof(c)}(degree, [c]) +function PolynomialKernel(; degree::Int=2, c::Real=0.0, check_args::Bool=true) + return PolynomialKernel{typeof(c)}(degree, [c]; check_args) end # The degree of the polynomial kernel is a fixed discrete parameter diff --git a/src/basekernels/rational.jl b/src/basekernels/rational.jl index 835fe92ff..c532072a7 100644 --- a/src/basekernels/rational.jl +++ b/src/basekernels/rational.jl @@ -19,14 +19,16 @@ struct RationalKernel{Tα<:Real,M} <: SimpleKernel α::Vector{Tα} metric::M - function RationalKernel(α::Real, metric) - @check_args(RationalKernel, α, α > zero(α), "α > 0") + function RationalKernel(α::Real, metric; check_args::Bool=true) + check_args && @check_args(RationalKernel, α, α > zero(α), "α > 0") return new{typeof(α),typeof(metric)}([α], metric) end end -function RationalKernel(; alpha::Real=2.0, α::Real=alpha, metric=Euclidean()) - return RationalKernel(α, metric) +function RationalKernel(; + alpha::Real=2.0, α::Real=alpha, metric=Euclidean(), check_args::Bool=true +) + return RationalKernel(α, metric; check_args) end @functor RationalKernel @@ -83,8 +85,10 @@ struct RationalQuadraticKernel{Tα<:Real,M} <: SimpleKernel α::Vector{Tα} metric::M - function RationalQuadraticKernel(; alpha::Real=2.0, α::Real=alpha, metric=Euclidean()) - @check_args(RationalQuadraticKernel, α, α > zero(α), "α > 0") + function RationalQuadraticKernel(; + alpha::Real=2.0, α::Real=alpha, metric=Euclidean(), check_args::Bool=true + ) + check_args && @check_args(RationalQuadraticKernel, α, α > zero(α), "α > 0") return new{typeof(α),typeof(metric)}([α], metric) end end @@ -172,10 +176,15 @@ struct GammaRationalKernel{Tα<:Real,Tγ<:Real,M} <: SimpleKernel metric::M function GammaRationalKernel(; - alpha::Real=2.0, gamma::Real=1.0, α::Real=alpha, γ::Real=gamma, metric=Euclidean() + alpha::Real=2.0, + gamma::Real=1.0, + α::Real=alpha, + γ::Real=gamma, + metric=Euclidean(), + check_args::Bool=true, ) - @check_args(GammaRationalKernel, α, α > zero(α), "α > 0") - @check_args(GammaRationalKernel, γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]") + check_args && @check_args(GammaRationalKernel, α, α > zero(α), "α > 0") + check_args && @check_args(GammaRationalKernel, γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]") return new{typeof(α),typeof(γ),typeof(metric)}([α], [γ], metric) end end diff --git a/src/basekernels/wiener.jl b/src/basekernels/wiener.jl index 14d330850..be511fbff 100644 --- a/src/basekernels/wiener.jl +++ b/src/basekernels/wiener.jl @@ -39,8 +39,9 @@ The [`WhiteKernel`](@ref) is recovered for ``i = -1``. [^SDH]: Schober, Duvenaud & Hennig (2014). Probabilistic ODE Solvers with Runge-Kutta Means. """ struct WienerKernel{I} <: Kernel - function WienerKernel{I}() where {I} - @check_args(WienerKernel, I, I ∈ (-1, 0, 1, 2, 3), "I ∈ {-1, 0, 1, 2, 3}") + function WienerKernel{I}(; check_args::Bool=true) where {I} + check_args && + @check_args(WienerKernel, I, I ∈ (-1, 0, 1, 2, 3), "I ∈ {-1, 0, 1, 2, 3}") if I == -1 return WhiteKernel() end diff --git a/src/kernels/scaledkernel.jl b/src/kernels/scaledkernel.jl index 0a17943ac..9bfafaa68 100644 --- a/src/kernels/scaledkernel.jl +++ b/src/kernels/scaledkernel.jl @@ -16,8 +16,10 @@ struct ScaledKernel{Tk<:Kernel,Tσ²<:Real} <: Kernel σ²::Vector{Tσ²} end -function ScaledKernel(kernel::Tk, σ²::Tσ²=1.0) where {Tk<:Kernel,Tσ²<:Real} - @check_args(ScaledKernel, σ², σ² > zero(Tσ²), "σ² > 0") +function ScaledKernel( + kernel::Tk, σ²::Tσ²=1.0; check_args::Bool=true +) where {Tk<:Kernel,Tσ²<:Real} + check_args && @check_args(ScaledKernel, σ², σ² > zero(Tσ²), "σ² > 0") return ScaledKernel{Tk,Tσ²}(kernel, [σ²]) end diff --git a/src/mokernels/independent.jl b/src/mokernels/independent.jl index 1f7811b14..554fdb08f 100644 --- a/src/mokernels/independent.jl +++ b/src/mokernels/independent.jl @@ -39,16 +39,14 @@ function kernelmatrix( return _kernelmatrix_kron_helper(MOI, Kfeatures, Koutputs) end -if VERSION >= v"1.6" - function kernelmatrix!( - K::AbstractMatrix, k::IndependentMOKernel, x::MOI, y::MOI - ) where {MOI<:IsotopicMOInputsUnion} - x.out_dim == y.out_dim || - throw(DimensionMismatch("`x` and `y` must have the same `out_dim`")) - Kfeatures = kernelmatrix(k.kernel, x.x, y.x) - Koutputs = _mo_output_covariance(k, x.out_dim) - return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs) - end +function kernelmatrix!( + K::AbstractMatrix, k::IndependentMOKernel, x::MOI, y::MOI +) where {MOI<:IsotopicMOInputsUnion} + x.out_dim == y.out_dim || + throw(DimensionMismatch("`x` and `y` must have the same `out_dim`")) + Kfeatures = kernelmatrix(k.kernel, x.x, y.x) + Koutputs = _mo_output_covariance(k, x.out_dim) + return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs) end function Base.show(io::IO, k::IndependentMOKernel) diff --git a/src/mokernels/intrinsiccoregion.jl b/src/mokernels/intrinsiccoregion.jl index 0a940796b..0bd17f55c 100644 --- a/src/mokernels/intrinsiccoregion.jl +++ b/src/mokernels/intrinsiccoregion.jl @@ -19,8 +19,10 @@ struct IntrinsicCoregionMOKernel{K<:Kernel,T<:AbstractMatrix} <: MOKernel kernel::K B::T - function IntrinsicCoregionMOKernel{K,T}(kernel::K, B::T) where {K,T} - @check_args( + function IntrinsicCoregionMOKernel{K,T}( + kernel::K, B::T; check_args::Bool=true + ) where {K,T} + check_args && @check_args( IntrinsicCoregionMOKernel, B, eigmin(B) >= 0, @@ -30,12 +32,14 @@ struct IntrinsicCoregionMOKernel{K<:Kernel,T<:AbstractMatrix} <: MOKernel end end -function IntrinsicCoregionMOKernel(; kernel::Kernel, B::AbstractMatrix) - return IntrinsicCoregionMOKernel{typeof(kernel),typeof(B)}(kernel, B) +function IntrinsicCoregionMOKernel(; + kernel::Kernel, B::AbstractMatrix; check_args::Bool=true +) + return IntrinsicCoregionMOKernel{typeof(kernel),typeof(B)}(kernel, B; check_args) end -function IntrinsicCoregionMOKernel(kernel::Kernel, B::AbstractMatrix) - return IntrinsicCoregionMOKernel{typeof(kernel),typeof(B)}(kernel, B) +function IntrinsicCoregionMOKernel(kernel::Kernel, B::AbstractMatrix; check_args::Bool=true) + return IntrinsicCoregionMOKernel{typeof(kernel),typeof(B)}(kernel, B; check_args) end function (k::IntrinsicCoregionMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{Any,Int}) @@ -57,16 +61,14 @@ function kernelmatrix( return _kernelmatrix_kron_helper(MOI, Kfeatures, Koutputs) end -if VERSION >= v"1.6" - function kernelmatrix!( - K::AbstractMatrix, k::IntrinsicCoregionMOKernel, x::MOI, y::MOI - ) where {MOI<:IsotopicMOInputsUnion} - x.out_dim == y.out_dim || - throw(DimensionMismatch("`x` and `y` must have the same `out_dim`")) - Kfeatures = kernelmatrix(k.kernel, x.x, y.x) - Koutputs = _mo_output_covariance(k, x.out_dim) - return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs) - end +function kernelmatrix!( + K::AbstractMatrix, k::IntrinsicCoregionMOKernel, x::MOI, y::MOI +) where {MOI<:IsotopicMOInputsUnion} + x.out_dim == y.out_dim || + throw(DimensionMismatch("`x` and `y` must have the same `out_dim`")) + Kfeatures = kernelmatrix(k.kernel, x.x, y.x) + Koutputs = _mo_output_covariance(k, x.out_dim) + return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs) end function Base.show(io::IO, k::IntrinsicCoregionMOKernel) diff --git a/src/mokernels/mokernel.jl b/src/mokernels/mokernel.jl index 30b6c2d9b..474e193ea 100644 --- a/src/mokernels/mokernel.jl +++ b/src/mokernels/mokernel.jl @@ -13,16 +13,14 @@ function _kernelmatrix_kron_helper(::Type{<:MOInputIsotopicByOutputs}, Kfeatures return kron(Koutputs, Kfeatures) end -if VERSION >= v"1.6" - function _kernelmatrix_kron_helper!( - K, ::Type{<:MOInputIsotopicByFeatures}, Kfeatures, Koutputs - ) - return kron!(K, Kfeatures, Koutputs) - end +function _kernelmatrix_kron_helper!( + K, ::Type{<:MOInputIsotopicByFeatures}, Kfeatures, Koutputs +) + return kron!(K, Kfeatures, Koutputs) +end - function _kernelmatrix_kron_helper!( - K, ::Type{<:MOInputIsotopicByOutputs}, Kfeatures, Koutputs - ) - return kron!(K, Koutputs, Kfeatures) - end +function _kernelmatrix_kron_helper!( + K, ::Type{<:MOInputIsotopicByOutputs}, Kfeatures, Koutputs +) + return kron!(K, Koutputs, Kfeatures) end diff --git a/test/chainrules.jl b/test/chainrules.jl index 03c2c3b1f..4edd532d0 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -21,11 +21,7 @@ compare_gradient(:Zygote, [x, y]) do xy KernelFunctions.Sinus(r)(xy[1], xy[2]) end - if VERSION < v"1.6" - @test_broken "Chain rule of SqMahalanobis is broken in Julia pre-1.6" - else - compare_gradient(:Zygote, [Q, x, y]) do Qxy - SqMahalanobis(Qxy[1])(Qxy[2], Qxy[3]) - end + compare_gradient(:Zygote, [Q, x, y]) do Qxy + SqMahalanobis(Qxy[1])(Qxy[2], Qxy[3]) end end From 0b7461f0ac310df3fcaf922b60093936c723951c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 18 Apr 2023 11:55:39 +0200 Subject: [PATCH 2/9] remove double semi colon --- src/mokernels/intrinsiccoregion.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mokernels/intrinsiccoregion.jl b/src/mokernels/intrinsiccoregion.jl index 0bd17f55c..9a3e893a3 100644 --- a/src/mokernels/intrinsiccoregion.jl +++ b/src/mokernels/intrinsiccoregion.jl @@ -33,7 +33,7 @@ struct IntrinsicCoregionMOKernel{K<:Kernel,T<:AbstractMatrix} <: MOKernel end function IntrinsicCoregionMOKernel(; - kernel::Kernel, B::AbstractMatrix; check_args::Bool=true + kernel::Kernel, B::AbstractMatrix, check_args::Bool=true ) return IntrinsicCoregionMOKernel{typeof(kernel),typeof(B)}(kernel, B; check_args) end From b7d86626182f929fe475f18823fae79ac7a716c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 18 Apr 2023 12:49:27 +0200 Subject: [PATCH 3/9] New @check_args syntax --- src/basekernels/constant.jl | 2 +- src/basekernels/exponential.jl | 2 +- src/basekernels/fbm.jl | 2 +- src/basekernels/matern.jl | 2 +- src/basekernels/periodic.jl | 3 +-- src/basekernels/polynomial.jl | 10 ++++++---- src/basekernels/rational.jl | 11 +++++++---- src/basekernels/wiener.jl | 3 +-- src/kernels/scaledkernel.jl | 2 +- src/mokernels/intrinsiccoregion.jl | 6 ++---- 10 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/basekernels/constant.jl b/src/basekernels/constant.jl index 977c07a94..cfbd8abc6 100644 --- a/src/basekernels/constant.jl +++ b/src/basekernels/constant.jl @@ -99,7 +99,7 @@ struct ConstantKernel{Tc<:Real} <: SimpleKernel c::Vector{Tc} function ConstantKernel(; c::Real=1.0, check_args::Bool=true) - check_args && @check_args(ConstantKernel, c, c >= zero(c), "c ≥ 0") + @check_args(ConstantKernel, (c, c >= zero(c), "c ≥ 0")) return new{typeof(c)}([c]) end end diff --git a/src/basekernels/exponential.jl b/src/basekernels/exponential.jl index 0173861bb..0fc4b313f 100644 --- a/src/basekernels/exponential.jl +++ b/src/basekernels/exponential.jl @@ -126,7 +126,7 @@ struct GammaExponentialKernel{Tγ<:Real,M} <: SimpleKernel metric::M function GammaExponentialKernel(γ::Real, metric; check_args::Bool=true) - check_args && @check_args(GammaExponentialKernel, γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]") + @check_args(GammaExponentialKernel, (γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]")) return new{typeof(γ),typeof(metric)}([γ], metric) end end diff --git a/src/basekernels/fbm.jl b/src/basekernels/fbm.jl index ef7dfe093..b93933a9a 100644 --- a/src/basekernels/fbm.jl +++ b/src/basekernels/fbm.jl @@ -15,7 +15,7 @@ k(x, x'; h) = \\frac{\\|x\\|_2^{2h} + \\|x'\\|_2^{2h} - \\|x - x'\\|^{2h}}{2}. struct FBMKernel{T<:Real} <: Kernel h::Vector{T} function FBMKernel(h::Real; check_args::Bool=true) - check_args && @check_args(FBMKernel, h, zero(h) ≤ h ≤ one(h), "h ∈ [0, 1]") + @check_args(FBMKernel, (h, zero(h) ≤ h ≤ one(h), "h ∈ [0, 1]")) return new{typeof(h)}([h]) end end diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index 7d74b9853..e1e7b8486 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -28,7 +28,7 @@ struct MaternKernel{Tν<:Real,M} <: SimpleKernel metric::M function MaternKernel(ν::Real, metric; check_args::Bool=true) - check_args && @check_args(MaternKernel, ν, ν > zero(ν), "ν > 0") + @check_args(MaternKernel, (ν, ν > zero(ν), "ν > 0")) return new{typeof(ν),typeof(metric)}([ν], metric) end end diff --git a/src/basekernels/periodic.jl b/src/basekernels/periodic.jl index 7b7dd5160..d657de90e 100644 --- a/src/basekernels/periodic.jl +++ b/src/basekernels/periodic.jl @@ -18,8 +18,7 @@ struct PeriodicKernel{T} <: SimpleKernel function PeriodicKernel(; r::AbstractVector{<:Real}=ones(Float64, 1), check_args::Bool=true ) - check_args && - @check_args(PeriodicKernel, r, all(ri > zero(ri) for ri in r), "r > 0") + @check_args(PeriodicKernel, (r, all(ri > zero(ri) for ri in r), "r > 0")w) return new{eltype(r)}(r) end end diff --git a/src/basekernels/polynomial.jl b/src/basekernels/polynomial.jl index 938528638..f2385efa5 100644 --- a/src/basekernels/polynomial.jl +++ b/src/basekernels/polynomial.jl @@ -17,7 +17,7 @@ struct LinearKernel{Tc<:Real} <: SimpleKernel c::Vector{Tc} function LinearKernel(c::Real; check_args::Bool=true) - check_args && @check_args(LinearKernel, c, c >= zero(c), "c ≥ 0") + @check_args(LinearKernel, (c, c >= zero(c), "c ≥ 0")) return new{typeof(c)}([c]) end end @@ -72,9 +72,11 @@ struct PolynomialKernel{Tc<:Real} <: SimpleKernel function PolynomialKernel{Tc}( degree::Int, c::Vector{Tc}; check_args::Bool=true ) where {Tc} - check_args && - @check_args(PolynomialKernel, degree, degree >= one(degree), "degree ≥ 1") - check_args && @check_args(PolynomialKernel, c, only(c) >= zero(Tc), "c ≥ 0") + @check_args( + PolynomialKernel, + (degree, degree >= one(degree), "degree ≥ 1"), + (c, only(c) >= zero(Tc), "c ≥ 0") + ) return new{Tc}(degree, c) end end diff --git a/src/basekernels/rational.jl b/src/basekernels/rational.jl index c532072a7..4ba0db22d 100644 --- a/src/basekernels/rational.jl +++ b/src/basekernels/rational.jl @@ -20,7 +20,7 @@ struct RationalKernel{Tα<:Real,M} <: SimpleKernel metric::M function RationalKernel(α::Real, metric; check_args::Bool=true) - check_args && @check_args(RationalKernel, α, α > zero(α), "α > 0") + @check_args(RationalKernel, (α, α > zero(α), "α > 0")) return new{typeof(α),typeof(metric)}([α], metric) end end @@ -88,7 +88,7 @@ struct RationalQuadraticKernel{Tα<:Real,M} <: SimpleKernel function RationalQuadraticKernel(; alpha::Real=2.0, α::Real=alpha, metric=Euclidean(), check_args::Bool=true ) - check_args && @check_args(RationalQuadraticKernel, α, α > zero(α), "α > 0") + @check_args(RationalQuadraticKernel, (α, α > zero(α), "α > 0")) return new{typeof(α),typeof(metric)}([α], metric) end end @@ -183,8 +183,11 @@ struct GammaRationalKernel{Tα<:Real,Tγ<:Real,M} <: SimpleKernel metric=Euclidean(), check_args::Bool=true, ) - check_args && @check_args(GammaRationalKernel, α, α > zero(α), "α > 0") - check_args && @check_args(GammaRationalKernel, γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]") + @check_args( + GammaRationalKernel, + (α, α > zero(α), "α > 0"), + (γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]") + ) return new{typeof(α),typeof(γ),typeof(metric)}([α], [γ], metric) end end diff --git a/src/basekernels/wiener.jl b/src/basekernels/wiener.jl index be511fbff..980b243ab 100644 --- a/src/basekernels/wiener.jl +++ b/src/basekernels/wiener.jl @@ -40,8 +40,7 @@ The [`WhiteKernel`](@ref) is recovered for ``i = -1``. """ struct WienerKernel{I} <: Kernel function WienerKernel{I}(; check_args::Bool=true) where {I} - check_args && - @check_args(WienerKernel, I, I ∈ (-1, 0, 1, 2, 3), "I ∈ {-1, 0, 1, 2, 3}") + @check_args(WienerKernel, (I, I ∈ (-1, 0, 1, 2, 3), "I ∈ {-1, 0, 1, 2, 3}")) if I == -1 return WhiteKernel() end diff --git a/src/kernels/scaledkernel.jl b/src/kernels/scaledkernel.jl index 9bfafaa68..6b20c6588 100644 --- a/src/kernels/scaledkernel.jl +++ b/src/kernels/scaledkernel.jl @@ -19,7 +19,7 @@ end function ScaledKernel( kernel::Tk, σ²::Tσ²=1.0; check_args::Bool=true ) where {Tk<:Kernel,Tσ²<:Real} - check_args && @check_args(ScaledKernel, σ², σ² > zero(Tσ²), "σ² > 0") + @check_args(ScaledKernel, (σ², σ² > zero(Tσ²), "σ² > 0")) return ScaledKernel{Tk,Tσ²}(kernel, [σ²]) end diff --git a/src/mokernels/intrinsiccoregion.jl b/src/mokernels/intrinsiccoregion.jl index 9a3e893a3..cfa1118f0 100644 --- a/src/mokernels/intrinsiccoregion.jl +++ b/src/mokernels/intrinsiccoregion.jl @@ -22,11 +22,9 @@ struct IntrinsicCoregionMOKernel{K<:Kernel,T<:AbstractMatrix} <: MOKernel function IntrinsicCoregionMOKernel{K,T}( kernel::K, B::T; check_args::Bool=true ) where {K,T} - check_args && @check_args( + @check_args( IntrinsicCoregionMOKernel, - B, - eigmin(B) >= 0, - "B has to be positive semi-definite" + (B, eigmin(B) >= 0, "B has to be positive semi-definite") ) return new{K,T}(kernel, B) end From 505ec302167595259cadc0d95abc13050de09534 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 18 Apr 2023 12:51:42 +0200 Subject: [PATCH 4/9] Copy from Distributions.jl --- src/utils.jl | 109 ++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 90 insertions(+), 19 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index bb805b15f..78177297b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,24 +1,95 @@ -# Macro for checking arguments -macro check_args(K, param, cond, desc=string(cond)) - quote - if !($(esc(cond))) - throw( - ArgumentError( - string( - $(string(K)), - ": ", - $(string(param)), - " = ", - $(esc(param)), - " does not ", - "satisfy the constraint ", - $(string(desc)), - ".", - ), - ), - ) +# strip_linenos and check_args both come from Distributions.jl, +# see https://github.com/JuliaStats/Distributions.jl/blob/master/src/utils.jl + + +macro strip_linenos(expr) + return esc(Base.remove_linenums!(expr)) +end + +""" + @check_args( + D, + @setup(statements...), + (arg₁, cond₁, message₁), + (cond₂, message₂), + ..., + ) + +A convenience macro that generates checks of arguments for a distribution of type `D`. + +The macro expects that a boolean variable of name `check_args` is defined and generates +the following Julia code: +```julia +Distributions.check_args(check_args) do + \$(statements...) + cond₁ || throw(DomainError(arg₁, \$(string(D, ": ", message₁)))) + cond₂ || throw(ArgumentError(\$(string(D, ": ", message₂)))) + ... +end +``` + +The `@setup` argument can be elided if no setup code is needed. Moreover, error messages +can be omitted. In this case the message `"the condition \$(cond) is not satisfied."` is +used. +""" +macro check_args(D, setup_or_check, checks...) + # Extract setup statements + if Meta.isexpr(setup_or_check, :macrocall) && setup_or_check.args[1] == Symbol("@setup") + setup_stmts = Any[esc(ex) for ex in setup_or_check.args[3:end]] + else + setup_stmts = [] + checks = (setup_or_check, checks...) + end + + # Generate expressions for each condition + conds_exprs = map(checks) do check + if Meta.isexpr(check, :tuple, 3) + # argument, condition, and message specified + arg = check.args[1] + cond = check.args[2] + message = string(D, ": ", check.args[3]) + return :(($(esc(cond))) || throw(DomainError($(esc(arg)), $message))) + elseif Meta.isexpr(check, :tuple, 2) + cond_or_message = check.args[2] + if cond_or_message isa String + # only condition and message specified + cond = check.args[1] + message = string(D, ": ", cond_or_message) + return :(($(esc(cond))) || throw(ArgumentError($message))) + else + # only argument and condition specified + arg = check.args[1] + cond = cond_or_message + message = string(D, ": the condition ", cond, " is not satisfied.") + return :(($(esc(cond))) || throw(DomainError($(esc(arg)), $message))) + end + else + # only condition specified + cond = check + message = string(D, ": the condition ", cond, " is not satisfied.") + return :(($(esc(cond))) || throw(ArgumentError($message))) end end + + return @strip_linenos quote + Distributions.check_args($(esc(:check_args))) do + $(__source__) + $(setup_stmts...) + $(conds_exprs...) + end + end +end + +""" + check_args(f, check::Bool) + +Perform check of arguments by calling a function `f`. + +If `check` is `false`, the checks are skipped. +""" +function check_args(f::F, check::Bool) where {F} + check && f() + nothing end function deprecated_obsdim(obsdim::Union{Int,Nothing}) From 7a3fccb05c02c6b9f8309a6f6f434b4774748184 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 18 Apr 2023 12:54:36 +0200 Subject: [PATCH 5/9] Leftover Distributions --- src/utils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 78177297b..da760777a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -8,19 +8,19 @@ end """ @check_args( - D, + K, @setup(statements...), (arg₁, cond₁, message₁), (cond₂, message₂), ..., ) -A convenience macro that generates checks of arguments for a distribution of type `D`. +A convenience macro that generates checks of arguments for a kernel of type `K`. The macro expects that a boolean variable of name `check_args` is defined and generates the following Julia code: ```julia -Distributions.check_args(check_args) do +KernelFunctions.check_args(check_args) do \$(statements...) cond₁ || throw(DomainError(arg₁, \$(string(D, ": ", message₁)))) cond₂ || throw(ArgumentError(\$(string(D, ": ", message₂)))) @@ -72,7 +72,7 @@ macro check_args(D, setup_or_check, checks...) end return @strip_linenos quote - Distributions.check_args($(esc(:check_args))) do + KernelFunctions.check_args($(esc(:check_args))) do $(__source__) $(setup_stmts...) $(conds_exprs...) From 0e5b6358ba793fd8dd40e6cb85d227ff3b800c34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 18 Apr 2023 13:01:45 +0200 Subject: [PATCH 6/9] Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/utils.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index da760777a..c68ad70dd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,7 +1,6 @@ # strip_linenos and check_args both come from Distributions.jl, # see https://github.com/JuliaStats/Distributions.jl/blob/master/src/utils.jl - macro strip_linenos(expr) return esc(Base.remove_linenums!(expr)) end From b9dd6df153c2e849b5706186dd9059ebded0e355 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 18 Apr 2023 13:01:50 +0200 Subject: [PATCH 7/9] Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index c68ad70dd..7d01f668f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -88,7 +88,7 @@ If `check` is `false`, the checks are skipped. """ function check_args(f::F, check::Bool) where {F} check && f() - nothing + return nothing end function deprecated_obsdim(obsdim::Union{Int,Nothing}) From 4f247184306ac9012d76fb9d3e21ebcab645e718 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 18 Apr 2023 13:36:18 +0200 Subject: [PATCH 8/9] Typo --- src/basekernels/periodic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/basekernels/periodic.jl b/src/basekernels/periodic.jl index d657de90e..8729e9606 100644 --- a/src/basekernels/periodic.jl +++ b/src/basekernels/periodic.jl @@ -18,7 +18,7 @@ struct PeriodicKernel{T} <: SimpleKernel function PeriodicKernel(; r::AbstractVector{<:Real}=ones(Float64, 1), check_args::Bool=true ) - @check_args(PeriodicKernel, (r, all(ri > zero(ri) for ri in r), "r > 0")w) + @check_args(PeriodicKernel, (r, all(ri > zero(ri) for ri in r), "r > 0")) return new{eltype(r)}(r) end end From 0c8185ac1adb955bd8211d459e390c26d769a3bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 18 Apr 2023 13:52:04 +0200 Subject: [PATCH 9/9] Adapt error tests --- test/basekernels/polynomial.jl | 6 +++--- test/basekernels/wiener.jl | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/basekernels/polynomial.jl b/test/basekernels/polynomial.jl index 80a3b1bb8..4c9104288 100644 --- a/test/basekernels/polynomial.jl +++ b/test/basekernels/polynomial.jl @@ -14,7 +14,7 @@ @test repr(k) == "Linear Kernel (c = 0.0)" # Errors. - @test_throws ArgumentError LinearKernel(; c=-0.5) + @test_throws DomainError LinearKernel(; c=-0.5) # Standardised tests. TestUtils.test_interface(k, Float64) @@ -36,8 +36,8 @@ @test metric(PolynomialKernel(; degree=3, c=c)) == KernelFunctions.DotProduct() # Errors. - @test_throws ArgumentError PolynomialKernel(; degree=0) - @test_throws ArgumentError PolynomialKernel(; c=-0.5) + @test_throws DomainError PolynomialKernel(; degree=0) + @test_throws DomainError PolynomialKernel(; c=-0.5) # Standardised tests. TestUtils.test_interface(k, Float64) diff --git a/test/basekernels/wiener.jl b/test/basekernels/wiener.jl index 9dd60ba43..6d8c16c9a 100644 --- a/test/basekernels/wiener.jl +++ b/test/basekernels/wiener.jl @@ -14,8 +14,8 @@ k3 = WienerKernel(; i=3) @test k3 isa WienerKernel{3} - @test_throws ArgumentError WienerKernel(; i=4) - @test_throws ArgumentError WienerKernel(; i=-2) + @test_throws DomainError WienerKernel(; i=4) + @test_throws DomainError WienerKernel(; i=-2) v1 = rand(4) v2 = rand(4)