Skip to content

Commit 05fe340

Browse files
authored
use only() instead of first() (#403)
* use only() instead of first() for 1-"vectors" that were for the benefit of Flux * fix one test that should not have worked as it was * add missing scalar Sinus constructor
1 parent 2d17212 commit 05fe340

18 files changed

+48
-43
lines changed

src/basekernels/constant.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ end
7373

7474
@functor ConstantKernel
7575

76-
kappa::ConstantKernel, x::Real) = first.c) * one(x)
76+
kappa::ConstantKernel, x::Real) = only.c) * one(x)
7777

7878
metric(::ConstantKernel) = Delta()
7979

80-
Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", first.c), ")")
80+
Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", only.c), ")")

src/basekernels/exponential.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,14 @@ end
137137

138138
@functor GammaExponentialKernel
139139

140-
kappa::GammaExponentialKernel, d::Real) = exp(-d^first.γ))
140+
kappa::GammaExponentialKernel, d::Real) = exp(-d^only.γ))
141141

142142
metric(k::GammaExponentialKernel) = k.metric
143143

144144
iskroncompatible(::GammaExponentialKernel) = true
145145

146146
function Base.show(io::IO, κ::GammaExponentialKernel)
147147
return print(
148-
io, "Gamma Exponential Kernel (γ = ", first.γ), ", metric = ", κ.metric, ")"
148+
io, "Gamma Exponential Kernel (γ = ", only.γ), ", metric = ", κ.metric, ")"
149149
)
150150
end

src/basekernels/fbm.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,16 @@ function (κ::FBMKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
2828
modX = sum(abs2, x)
2929
modY = sum(abs2, y)
3030
modXY = sqeuclidean(x, y)
31-
h = first.h)
31+
h = only.h)
3232
return (modX^h + modY^h - modXY^h) / 2
3333
end
3434

3535
function::FBMKernel)(x::Real, y::Real)
36-
return (abs2(x)^first.h) + abs2(y)^first.h) - abs2(x - y)^first.h)) / 2
36+
return (abs2(x)^only.h) + abs2(y)^only.h) - abs2(x - y)^only.h)) / 2
3737
end
3838

3939
function Base.show(io::IO, κ::FBMKernel)
40-
return print(io, "Fractional Brownian Motion Kernel (h = ", first.h), ")")
40+
return print(io, "Fractional Brownian Motion Kernel (h = ", only.h), ")")
4141
end
4242

4343
_fbm(modX, modY, modXY, h) = (modX^h + modY^h - modXY^h) / 2

src/basekernels/matern.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν,
3434
@functor MaternKernel
3535

3636
@inline function kappa::MaternKernel, d::Real)
37-
result = _matern(first.ν), d)
37+
result = _matern(only.ν), d)
3838
return ifelse(iszero(d), one(result), result)
3939
end
4040

@@ -46,7 +46,7 @@ end
4646
metric(k::MaternKernel) = k.metric
4747

4848
function Base.show(io::IO, κ::MaternKernel)
49-
return print(io, "Matern Kernel (ν = ", first.ν), ", metric = ", κ.metric, ")")
49+
return print(io, "Matern Kernel (ν = ", only.ν), ", metric = ", κ.metric, ")")
5050
end
5151

5252
## Matern12Kernel = ExponentialKernel aliased in exponential.jl

src/basekernels/polynomial.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ LinearKernel(; c::Real=0.0) = LinearKernel(c)
2626

2727
@functor LinearKernel
2828

29-
kappa::LinearKernel, xᵀy::Real) = xᵀy + first.c)
29+
kappa::LinearKernel, xᵀy::Real) = xᵀy + only.c)
3030

3131
metric(::LinearKernel) = DotProduct()
3232

33-
Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", first.c), ")")
33+
Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", only.c), ")")
3434

3535
"""
3636
PolynomialKernel(; degree::Int=2, c::Real=0.0)
@@ -53,7 +53,7 @@ struct PolynomialKernel{Tc<:Real} <: SimpleKernel
5353

5454
function PolynomialKernel{Tc}(degree::Int, c::Vector{Tc}) where {Tc}
5555
@check_args(PolynomialKernel, degree, degree >= one(degree), "degree ≥ 1")
56-
@check_args(PolynomialKernel, c, first(c) >= zero(Tc), "c ≥ 0")
56+
@check_args(PolynomialKernel, c, only(c) >= zero(Tc), "c ≥ 0")
5757
return new{Tc}(degree, c)
5858
end
5959
end
@@ -68,10 +68,10 @@ function Functors.functor(::Type{<:PolynomialKernel}, x)
6868
return (c=x.c,), reconstruct_polynomialkernel
6969
end
7070

71-
kappa::PolynomialKernel, xᵀy::Real) = (xᵀy + first.c))^κ.degree
71+
kappa::PolynomialKernel, xᵀy::Real) = (xᵀy + only.c))^κ.degree
7272

7373
metric(::PolynomialKernel) = DotProduct()
7474

7575
function Base.show(io::IO, κ::PolynomialKernel)
76-
return print(io, "Polynomial Kernel (c = ", first.c), ", degree = ", κ.degree, ")")
76+
return print(io, "Polynomial Kernel (c = ", only.c), ", degree = ", κ.degree, ")")
7777
end

src/basekernels/rational.jl

+8-8
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ end
3232
@functor RationalKernel
3333

3434
function kappa::RationalKernel, d::Real)
35-
return (one(d) + d / first.α))^(-first.α))
35+
return (one(d) + d / only.α))^(-only.α))
3636
end
3737

3838
metric(k::RationalKernel) = k.metric
3939

4040
function Base.show(io::IO, κ::RationalKernel)
41-
return print(io, "Rational Kernel (α = ", first.α), ", metric = ", κ.metric, ")")
41+
return print(io, "Rational Kernel (α = ", only.α), ", metric = ", κ.metric, ")")
4242
end
4343

4444
"""
@@ -72,18 +72,18 @@ end
7272
@functor RationalQuadraticKernel
7373

7474
function kappa::RationalQuadraticKernel, d::Real)
75-
return (one(d) + d^2 / (2 * first.α)))^(-first.α))
75+
return (one(d) + d^2 / (2 * only.α)))^(-only.α))
7676
end
7777
function kappa::RationalQuadraticKernel{<:Real,<:Euclidean}, d²::Real)
78-
return (one(d²) +/ (2 * first.α)))^(-first.α))
78+
return (one(d²) +/ (2 * only.α)))^(-only.α))
7979
end
8080

8181
metric(k::RationalQuadraticKernel) = k.metric
8282
metric(::RationalQuadraticKernel{<:Real,<:Euclidean}) = SqEuclidean()
8383

8484
function Base.show(io::IO, κ::RationalQuadraticKernel)
8585
return print(
86-
io, "Rational Quadratic Kernel (α = ", first.α), ", metric = ", κ.metric, ")"
86+
io, "Rational Quadratic Kernel (α = ", only.α), ", metric = ", κ.metric, ")"
8787
)
8888
end
8989

@@ -122,7 +122,7 @@ end
122122
@functor GammaRationalKernel
123123

124124
function kappa::GammaRationalKernel, d::Real)
125-
return (one(d) + d^first.γ) / first.α))^(-first.α))
125+
return (one(d) + d^only.γ) / only.α))^(-only.α))
126126
end
127127

128128
metric(k::GammaRationalKernel) = k.metric
@@ -131,9 +131,9 @@ function Base.show(io::IO, κ::GammaRationalKernel)
131131
return print(
132132
io,
133133
"Gamma Rational Kernel (α = ",
134-
first.α),
134+
only.α),
135135
", γ = ",
136-
first.γ),
136+
only.γ),
137137
", metric = ",
138138
κ.metric,
139139
")",

src/distances/sinus.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ struct Sinus{T} <: Distances.UnionSemiMetric
22
r::Vector{T}
33
end
44

5+
Sinus(r::Real) = Sinus([r])
6+
57
Distances.parameters(d::Sinus) = d.r
68
@inline Distances.eval_op(::Sinus, a::Real, b::Real, p::Real) = abs2(sinpi(a - b) / p)
79
@inline (dist::Sinus)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
8-
@inline (dist::Sinus)(a::Number, b::Number) = abs2(sinpi(a - b) / first(dist.r))
10+
@inline (dist::Sinus)(a::Number, b::Number) = abs2(sinpi(a - b) / only(dist.r))
911

1012
Distances.result_type(::Sinus{T}, Ta::Type, Tb::Type) where {T} = promote_type(T, Ta, Tb)
1113

src/kernels/scaledkernel.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ end
2323

2424
@functor ScaledKernel
2525

26-
(k::ScaledKernel)(x, y) = first(k.σ²) * k.kernel(x, y)
26+
(k::ScaledKernel)(x, y) = only(k.σ²) * k.kernel(x, y)
2727

2828
function kernelmatrix::ScaledKernel, x::AbstractVector, y::AbstractVector)
2929
return κ.σ² .* kernelmatrix.kernel, x, y)
@@ -75,5 +75,5 @@ Base.show(io::IO, κ::ScaledKernel) = printshifted(io, κ, 0)
7575

7676
function printshifted(io::IO, κ::ScaledKernel, shift::Int)
7777
printshifted(io, κ.kernel, shift)
78-
return print(io, "\n" * ("\t"^(shift + 1)) * "- σ² = $(first.σ²))")
78+
return print(io, "\n" * ("\t"^(shift + 1)) * "- σ² = $(only.σ²))")
7979
end

src/kernels/transformedkernel.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ function (k::TransformedKernel{<:SimpleKernel,<:ScaleTransform})(
2828
end
2929

3030
function _scale(t::ScaleTransform, metric::Euclidean, x, y)
31-
return first(t.s) * evaluate(metric, x, y)
31+
return only(t.s) * evaluate(metric, x, y)
3232
end
3333
function _scale(t::ScaleTransform, metric::Union{SqEuclidean,DotProduct}, x, y)
34-
return first(t.s)^2 * evaluate(metric, x, y)
34+
return only(t.s)^2 * evaluate(metric, x, y)
3535
end
3636
_scale(t::ScaleTransform, metric, x, y) = evaluate(metric, t(x), t(y))
3737

src/transform/ardtransform.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ end
3232

3333
dim(t::ARDTransform) = length(t.v)
3434

35-
(t::ARDTransform)(x::Real) = first(t.v) * x
35+
(t::ARDTransform)(x::Real) = only(t.v) * x
3636
(t::ARDTransform)(x) = t.v .* x
3737

3838
_map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x

src/transform/periodic_transform.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@ PeriodicTransform(f::Real) = PeriodicTransform([f])
2525

2626
dim(t::PeriodicTransform) = 2
2727

28-
(t::PeriodicTransform)(x::Real) = [sinpi(2 * first(t.f) * x), cospi(2 * first(t.f) * x)]
28+
(t::PeriodicTransform)(x::Real) = [sinpi(2 * only(t.f) * x), cospi(2 * only(t.f) * x)]
2929

3030
function _map(t::PeriodicTransform, x::AbstractVector{<:Real})
31-
return RowVecs(hcat(sinpi.((2 * first(t.f)) .* x), cospi.((2 * first(t.f)) .* x)))
31+
return RowVecs(hcat(sinpi.((2 * only(t.f)) .* x), cospi.((2 * only(t.f)) .* x)))
3232
end
3333

3434
function Base.isequal(t1::PeriodicTransform, t2::PeriodicTransform)
35-
return isequal(first(t1.f), first(t2.f))
35+
return isequal(only(t1.f), only(t2.f))
3636
end
3737

3838
function Base.show(io::IO, t::PeriodicTransform)
39-
return print(io, "Periodic Transform with frequency $(first(t.f))")
39+
return print(io, "Periodic Transform with frequency $(only(t.f))")
4040
end

src/transform/scaletransform.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ end
2424

2525
set!(t::ScaleTransform, ρ::Real) = t.s .= [ρ]
2626

27-
(t::ScaleTransform)(x) = first(t.s) * x
27+
(t::ScaleTransform)(x) = only(t.s) * x
2828

29-
_map(t::ScaleTransform, x::AbstractVector{<:Real}) = first(t.s) .* x
30-
_map(t::ScaleTransform, x::ColVecs) = ColVecs(first(t.s) .* x.X)
31-
_map(t::ScaleTransform, x::RowVecs) = RowVecs(first(t.s) .* x.X)
29+
_map(t::ScaleTransform, x::AbstractVector{<:Real}) = only(t.s) .* x
30+
_map(t::ScaleTransform, x::ColVecs) = ColVecs(only(t.s) .* x.X)
31+
_map(t::ScaleTransform, x::RowVecs) = RowVecs(only(t.s) .* x.X)
3232

33-
Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(first(t.s), first(t2.s))
33+
Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(only(t.s), only(t2.s))
3434

35-
Base.show(io::IO, t::ScaleTransform) = print(io, "Scale Transform (s = ", first(t.s), ")")
35+
Base.show(io::IO, t::ScaleTransform) = print(io, "Scale Transform (s = ", only(t.s), ")")

test/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
3+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
34
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
45
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
56
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"

test/basekernels/constant.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,6 @@
3636

3737
# Standardised tests.
3838
TestUtils.test_interface(k, Float64)
39-
test_ADs(c -> ConstantKernel(; c=first(c)), [c])
39+
test_ADs(c -> ConstantKernel(; c=only(c)), [c])
4040
end
4141
end

test/basekernels/exponential.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
@test metric(k2) isa WeightedEuclidean
5757
@test k2(v1, v2) k(v1, v2)
5858

59-
test_ADs-> GammaExponentialKernel(; gamma=first(γ)), [1 + 0.5 * rand()])
59+
test_ADs-> GammaExponentialKernel(; gamma=only(γ)), [1 + 0.5 * rand()])
6060
test_params(k, ([γ],))
6161
TestUtils.test_interface(GammaExponentialKernel(; γ=1.36))
6262

test/distances/sinus.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@
55
d = KernelFunctions.Sinus(p)
66
@test Distances.parameters(d) == p
77
@test evaluate(d, A, B) == sum(abs2.(sinpi.(A - B) ./ p))
8-
@test d(3.0, 2.0) == abs2(sinpi(3.0 - 2.0) / first(p))
8+
d1 = KernelFunctions.Sinus(first(p))
9+
@test d1(3.0, 2.0) == abs2(sinpi(3.0 - 2.0) / first(p))
910
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ using Zygote: Zygote
1414
using ForwardDiff: ForwardDiff
1515
using ReverseDiff: ReverseDiff
1616
using FiniteDifferences: FiniteDifferences
17+
using Compat: only
1718

1819
using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs, TestUtils
1920

test/test_utils.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ const FDM = FiniteDifferences.central_fdm(5, 1)
4545
gradient(f, s::Symbol, args) = gradient(f, Val(s), args)
4646

4747
function gradient(f, ::Val{:Zygote}, args)
48-
g = first(Zygote.gradient(f, args))
48+
g = only(Zygote.gradient(f, args))
4949
if isnothing(g)
5050
if args isa AbstractArray{<:Real}
5151
return zeros(size(args)) # To respect the same output as other ADs
@@ -66,7 +66,7 @@ function gradient(f, ::Val{:ReverseDiff}, args)
6666
end
6767

6868
function gradient(f, ::Val{:FiniteDiff}, args)
69-
return first(FiniteDifferences.grad(FDM, f, args))
69+
return only(FiniteDifferences.grad(FDM, f, args))
7070
end
7171

7272
function compare_gradient(f, ::Val{:FiniteDiff}, args)

0 commit comments

Comments
 (0)