Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

_map -> Base.map #453

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
_map -> Base.map
st-- committed Apr 14, 2022
commit 1a3fa8bc8082ba46c1b1025bb3dc712487d7021f
16 changes: 8 additions & 8 deletions src/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
@@ -82,37 +82,37 @@ end
# Kernel matrix operations

function kernelmatrix_diag!(K::AbstractVector, κ::TransformedKernel, x::AbstractVector)
return kernelmatrix_diag!(K, κ.kernel, _map(κ.transform, x))
return kernelmatrix_diag!(K, κ.kernel, map(κ.transform, x))
end

function kernelmatrix_diag!(
K::AbstractVector, κ::TransformedKernel, x::AbstractVector, y::AbstractVector
)
return kernelmatrix_diag!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
return kernelmatrix_diag!(K, κ.kernel, map(κ.transform, x), map(κ.transform, y))
end

function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector)
return kernelmatrix!(K, κ.kernel, _map(κ.transform, x))
return kernelmatrix!(K, κ.kernel, map(κ.transform, x))
end

function kernelmatrix!(
K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector, y::AbstractVector
)
return kernelmatrix!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
return kernelmatrix!(K, κ.kernel, map(κ.transform, x), map(κ.transform, y))
end

function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector)
return kernelmatrix_diag(κ.kernel, _map(κ.transform, x))
return kernelmatrix_diag(κ.kernel, map(κ.transform, x))
end

function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector, y::AbstractVector)
return kernelmatrix_diag(κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
return kernelmatrix_diag(κ.kernel, map(κ.transform, x), map(κ.transform, y))
end

function kernelmatrix(κ::TransformedKernel, x::AbstractVector)
return kernelmatrix(κ.kernel, _map(κ.transform, x))
return kernelmatrix(κ.kernel, map(κ.transform, x))
end

function kernelmatrix(κ::TransformedKernel, x::AbstractVector, y::AbstractVector)
return kernelmatrix(κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
return kernelmatrix(κ.kernel, map(κ.transform, x), map(κ.transform, y))
end
6 changes: 3 additions & 3 deletions src/transform/ardtransform.jl
Original file line number Diff line number Diff line change
@@ -35,9 +35,9 @@ dim(t::ARDTransform) = length(t.v)
(t::ARDTransform)(x::Real) = only(t.v) * x
(t::ARDTransform)(x) = t.v .* x

_map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
_map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
_map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)
Base.map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
Base.map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
Base.map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)

Base.isequal(t::ARDTransform, t2::ARDTransform) = isequal(t.v, t2.v)

2 changes: 1 addition & 1 deletion src/transform/chaintransform.jl
Original file line number Diff line number Diff line change
@@ -39,7 +39,7 @@ Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform(vcat(t, tc.transfor

(t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x)

function _map(t::ChainTransform, x::AbstractVector)
function Base.map(t::ChainTransform, x::AbstractVector)
return foldl((x, t) -> map(t, x), t.transforms; init=x)
end

6 changes: 3 additions & 3 deletions src/transform/functiontransform.jl
Original file line number Diff line number Diff line change
@@ -23,16 +23,16 @@ end

(t::FunctionTransform)(x) = t.f(x)

_map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)
Base.map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)

function _map(t::FunctionTransform, x::ColVecs)
function Base.map(t::FunctionTransform, x::ColVecs)
vals = map(axes(x.X, 2)) do i
t.f(view(x.X, :, i))
end
return ColVecs(reduce(hcat, vals))
end

function _map(t::FunctionTransform, x::RowVecs)
function Base.map(t::FunctionTransform, x::RowVecs)
vals = map(axes(x.X, 1)) do i
t.f(view(x.X, i, :))
end
6 changes: 3 additions & 3 deletions src/transform/lineartransform.jl
Original file line number Diff line number Diff line change
@@ -34,9 +34,9 @@ end
(t::LinearTransform)(x::Real) = vec(t.A * x)
(t::LinearTransform)(x::AbstractVector{<:Real}) = t.A * x

_map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * collect(x'))
_map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
_map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')
Base.map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * collect(x'))
Base.map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
Base.map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')

function Base.show(io::IO, t::LinearTransform)
return print(io::IO, "Linear transform (size(A) = ", size(t.A), ")")
2 changes: 1 addition & 1 deletion src/transform/periodic_transform.jl
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@ dim(t::PeriodicTransform) = 2

(t::PeriodicTransform)(x::Real) = [sinpi(2 * only(t.f) * x), cospi(2 * only(t.f) * x)]

function _map(t::PeriodicTransform, x::AbstractVector{<:Real})
function Base.map(t::PeriodicTransform, x::AbstractVector{<:Real})
return RowVecs(hcat(sinpi.((2 * only(t.f)) .* x), cospi.((2 * only(t.f)) .* x)))
end

6 changes: 3 additions & 3 deletions src/transform/scaletransform.jl
Original file line number Diff line number Diff line change
@@ -26,9 +26,9 @@ set!(t::ScaleTransform, ρ::Real) = t.s .= [ρ]

(t::ScaleTransform)(x) = only(t.s) * x

_map(t::ScaleTransform, x::AbstractVector{<:Real}) = only(t.s) .* x
_map(t::ScaleTransform, x::ColVecs) = ColVecs(only(t.s) .* x.X)
_map(t::ScaleTransform, x::RowVecs) = RowVecs(only(t.s) .* x.X)
Base.map(t::ScaleTransform, x::AbstractVector{<:Real}) = only(t.s) .* x
Base.map(t::ScaleTransform, x::ColVecs) = ColVecs(only(t.s) .* x.X)
Base.map(t::ScaleTransform, x::RowVecs) = RowVecs(only(t.s) .* x.X)

Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(only(t.s), only(t2.s))

4 changes: 2 additions & 2 deletions src/transform/selecttransform.jl
Original file line number Diff line number Diff line change
@@ -25,8 +25,8 @@ duplicate(t::SelectTransform, θ) = t
_maybe_unwrap(x) = x
_maybe_unwrap(x::AbstractArray{<:Any,0}) = x[]

_map(t::SelectTransform, x::ColVecs) = _wrap(view(x.X, t.select, :), ColVecs)
_map(t::SelectTransform, x::RowVecs) = _wrap(view(x.X, :, t.select), RowVecs)
Base.map(t::SelectTransform, x::ColVecs) = _wrap(view(x.X, t.select, :), ColVecs)
Base.map(t::SelectTransform, x::RowVecs) = _wrap(view(x.X, :, t.select), RowVecs)

_wrap(x::AbstractVector{<:Real}, ::Any) = x
_wrap(X::AbstractMatrix{<:Real}, ::Type{T}) where {T} = T(X)
5 changes: 2 additions & 3 deletions src/transform/transform.jl
Original file line number Diff line number Diff line change
@@ -5,8 +5,7 @@ Abstract type defining a transformation of the input.
"""
abstract type Transform end

Base.map(t::Transform, x::AbstractVector) = _map(t, x)
_map(t::Transform, x::AbstractVector) = t.(x)
Base.map(t::Transform, x::AbstractVector) = t.(x)

"""
IdentityTransform()
@@ -16,7 +15,7 @@ Transformation that returns exactly the input.
struct IdentityTransform <: Transform end

(t::IdentityTransform)(x) = x
_map(::IdentityTransform, x::AbstractVector) = x
Base.map(::IdentityTransform, x::AbstractVector) = x

### TODO Maybe defining adjoints could help but so far it's not working

2 changes: 1 addition & 1 deletion test/transform/selecttransform.jl
Original file line number Diff line number Diff line change
@@ -117,7 +117,7 @@
("ColVecs", ColVecs(randn(5, 10))),
("RowVecs", RowVecs(randn(11, 4))),
]
@test KernelFunctions._map(t, x) isa AbstractVector{Float64}
@test map(t, x) isa AbstractVector{Float64}
end
end
end