Skip to content

Commit 1a3fa8b

Browse files
committed
_map -> Base.map
1 parent 0f20455 commit 1a3fa8b

10 files changed

+27
-28
lines changed

src/kernels/transformedkernel.jl

+8-8
Original file line numberDiff line numberDiff line change
@@ -82,37 +82,37 @@ end
8282
# Kernel matrix operations
8383

8484
function kernelmatrix_diag!(K::AbstractVector, κ::TransformedKernel, x::AbstractVector)
85-
return kernelmatrix_diag!(K, κ.kernel, _map.transform, x))
85+
return kernelmatrix_diag!(K, κ.kernel, map.transform, x))
8686
end
8787

8888
function kernelmatrix_diag!(
8989
K::AbstractVector, κ::TransformedKernel, x::AbstractVector, y::AbstractVector
9090
)
91-
return kernelmatrix_diag!(K, κ.kernel, _map.transform, x), _map.transform, y))
91+
return kernelmatrix_diag!(K, κ.kernel, map.transform, x), map.transform, y))
9292
end
9393

9494
function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector)
95-
return kernelmatrix!(K, κ.kernel, _map.transform, x))
95+
return kernelmatrix!(K, κ.kernel, map.transform, x))
9696
end
9797

9898
function kernelmatrix!(
9999
K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector, y::AbstractVector
100100
)
101-
return kernelmatrix!(K, κ.kernel, _map.transform, x), _map.transform, y))
101+
return kernelmatrix!(K, κ.kernel, map.transform, x), map.transform, y))
102102
end
103103

104104
function kernelmatrix_diag::TransformedKernel, x::AbstractVector)
105-
return kernelmatrix_diag.kernel, _map.transform, x))
105+
return kernelmatrix_diag.kernel, map.transform, x))
106106
end
107107

108108
function kernelmatrix_diag::TransformedKernel, x::AbstractVector, y::AbstractVector)
109-
return kernelmatrix_diag.kernel, _map.transform, x), _map.transform, y))
109+
return kernelmatrix_diag.kernel, map.transform, x), map.transform, y))
110110
end
111111

112112
function kernelmatrix::TransformedKernel, x::AbstractVector)
113-
return kernelmatrix.kernel, _map.transform, x))
113+
return kernelmatrix.kernel, map.transform, x))
114114
end
115115

116116
function kernelmatrix::TransformedKernel, x::AbstractVector, y::AbstractVector)
117-
return kernelmatrix.kernel, _map.transform, x), _map.transform, y))
117+
return kernelmatrix.kernel, map.transform, x), map.transform, y))
118118
end

src/transform/ardtransform.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ dim(t::ARDTransform) = length(t.v)
3535
(t::ARDTransform)(x::Real) = only(t.v) * x
3636
(t::ARDTransform)(x) = t.v .* x
3737

38-
_map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
39-
_map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
40-
_map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)
38+
Base.map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
39+
Base.map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
40+
Base.map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)
4141

4242
Base.isequal(t::ARDTransform, t2::ARDTransform) = isequal(t.v, t2.v)
4343

src/transform/chaintransform.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform(vcat(t, tc.transfor
3939

4040
(t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x)
4141

42-
function _map(t::ChainTransform, x::AbstractVector)
42+
function Base.map(t::ChainTransform, x::AbstractVector)
4343
return foldl((x, t) -> map(t, x), t.transforms; init=x)
4444
end
4545

src/transform/functiontransform.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,16 @@ end
2323

2424
(t::FunctionTransform)(x) = t.f(x)
2525

26-
_map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)
26+
Base.map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)
2727

28-
function _map(t::FunctionTransform, x::ColVecs)
28+
function Base.map(t::FunctionTransform, x::ColVecs)
2929
vals = map(axes(x.X, 2)) do i
3030
t.f(view(x.X, :, i))
3131
end
3232
return ColVecs(reduce(hcat, vals))
3333
end
3434

35-
function _map(t::FunctionTransform, x::RowVecs)
35+
function Base.map(t::FunctionTransform, x::RowVecs)
3636
vals = map(axes(x.X, 1)) do i
3737
t.f(view(x.X, i, :))
3838
end

src/transform/lineartransform.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ end
3434
(t::LinearTransform)(x::Real) = vec(t.A * x)
3535
(t::LinearTransform)(x::AbstractVector{<:Real}) = t.A * x
3636

37-
_map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * collect(x'))
38-
_map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
39-
_map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')
37+
Base.map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * collect(x'))
38+
Base.map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
39+
Base.map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')
4040

4141
function Base.show(io::IO, t::LinearTransform)
4242
return print(io::IO, "Linear transform (size(A) = ", size(t.A), ")")

src/transform/periodic_transform.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ dim(t::PeriodicTransform) = 2
2727

2828
(t::PeriodicTransform)(x::Real) = [sinpi(2 * only(t.f) * x), cospi(2 * only(t.f) * x)]
2929

30-
function _map(t::PeriodicTransform, x::AbstractVector{<:Real})
30+
function Base.map(t::PeriodicTransform, x::AbstractVector{<:Real})
3131
return RowVecs(hcat(sinpi.((2 * only(t.f)) .* x), cospi.((2 * only(t.f)) .* x)))
3232
end
3333

src/transform/scaletransform.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ set!(t::ScaleTransform, ρ::Real) = t.s .= [ρ]
2626

2727
(t::ScaleTransform)(x) = only(t.s) * x
2828

29-
_map(t::ScaleTransform, x::AbstractVector{<:Real}) = only(t.s) .* x
30-
_map(t::ScaleTransform, x::ColVecs) = ColVecs(only(t.s) .* x.X)
31-
_map(t::ScaleTransform, x::RowVecs) = RowVecs(only(t.s) .* x.X)
29+
Base.map(t::ScaleTransform, x::AbstractVector{<:Real}) = only(t.s) .* x
30+
Base.map(t::ScaleTransform, x::ColVecs) = ColVecs(only(t.s) .* x.X)
31+
Base.map(t::ScaleTransform, x::RowVecs) = RowVecs(only(t.s) .* x.X)
3232

3333
Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(only(t.s), only(t2.s))
3434

src/transform/selecttransform.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ duplicate(t::SelectTransform, θ) = t
2525
_maybe_unwrap(x) = x
2626
_maybe_unwrap(x::AbstractArray{<:Any,0}) = x[]
2727

28-
_map(t::SelectTransform, x::ColVecs) = _wrap(view(x.X, t.select, :), ColVecs)
29-
_map(t::SelectTransform, x::RowVecs) = _wrap(view(x.X, :, t.select), RowVecs)
28+
Base.map(t::SelectTransform, x::ColVecs) = _wrap(view(x.X, t.select, :), ColVecs)
29+
Base.map(t::SelectTransform, x::RowVecs) = _wrap(view(x.X, :, t.select), RowVecs)
3030

3131
_wrap(x::AbstractVector{<:Real}, ::Any) = x
3232
_wrap(X::AbstractMatrix{<:Real}, ::Type{T}) where {T} = T(X)

src/transform/transform.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@ Abstract type defining a transformation of the input.
55
"""
66
abstract type Transform end
77

8-
Base.map(t::Transform, x::AbstractVector) = _map(t, x)
9-
_map(t::Transform, x::AbstractVector) = t.(x)
8+
Base.map(t::Transform, x::AbstractVector) = t.(x)
109

1110
"""
1211
IdentityTransform()
@@ -16,7 +15,7 @@ Transformation that returns exactly the input.
1615
struct IdentityTransform <: Transform end
1716

1817
(t::IdentityTransform)(x) = x
19-
_map(::IdentityTransform, x::AbstractVector) = x
18+
Base.map(::IdentityTransform, x::AbstractVector) = x
2019

2120
### TODO Maybe defining adjoints could help but so far it's not working
2221

test/transform/selecttransform.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@
117117
("ColVecs", ColVecs(randn(5, 10))),
118118
("RowVecs", RowVecs(randn(11, 4))),
119119
]
120-
@test KernelFunctions._map(t, x) isa AbstractVector{Float64}
120+
@test map(t, x) isa AbstractVector{Float64}
121121
end
122122
end
123123
end

0 commit comments

Comments
 (0)