From 7021b6fee596d33f4d7b7cbe7e7f244d6ca4b898 Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Sat, 22 Jul 2023 11:10:03 +0300 Subject: [PATCH 1/6] redesign onehot vectors initial commit --- src/array.jl | 217 +++++++++++++++++++++------------------------------ 1 file changed, 90 insertions(+), 127 deletions(-) diff --git a/src/array.jl b/src/array.jl index 6350eb9..2e80ad9 100644 --- a/src/array.jl +++ b/src/array.jl @@ -1,158 +1,121 @@ """ - OneHotArray{T, N, M, I} <: AbstractArray{Bool, M} - OneHotArray(indices, L) + OneHotVector{T,L} + OneHotVector(index, L) -A one-hot `M`-dimensional array with `L` labels (i.e. `size(A, 1) == L` and `sum(A, dims=1) == 1`) -stored as a compact `N == M-1`-dimensional array of indices. - -Typically constructed by [`onehot`](@ref) and [`onehotbatch`](@ref). -Parameter `I` is the type of the underlying storage, and `T` its eltype. +A one-hot vector with `L` labels (i.e. `length(A) == L` and `count(A) == 1`). +Stored efficiently as a single index of type `T`, usually `Int32`. """ -struct OneHotArray{T<:Integer, N, var"N+1", I<:Union{T, AbstractArray{T, N}}} <: AbstractArray{Bool, var"N+1"} - indices::I - nlabels::Int +struct OneHotVector{T,L} # <: AbstractVector{Bool} # this does everything but is too limiting + index::Integer + function OneHotVector{T,L}(index) where {T,L} + @assert 1 <= index <= L "OneHotVector index $(index) out of range [1,$(L)]" + return new(index) + end end -OneHotArray{T, N, I}(indices, L::Int) where {T, N, I} = OneHotArray{T, N, N+1, I}(indices, L) -OneHotArray(indices::T, L::Int) where {T<:Integer} = OneHotArray{T, 0, 1, T}(indices, L) -OneHotArray(indices::I, L::Int) where {T, N, I<:AbstractArray{T, N}} = OneHotArray{T, N, N+1, I}(indices, L) - -_indices(x::OneHotArray) = x.indices -_indices(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) = - reshape(parent(x).indices, x.dims[2:end]) - -""" - OneHotVector{T} = OneHotArray{T, 0, 1, T} - OneHotVector(indices, L) - -A one-hot vector with `L` labels (i.e. `length(A) == L` and `count(A) == 1`) typically constructed by [`onehot`](@ref). -Stored efficiently as a single index of type `T`, usually `UInt32`. -""" -const OneHotVector{T} = OneHotArray{T, 0, 1, T} -OneHotVector(idx, L) = OneHotArray(idx, L) - -""" - OneHotMatrix{T, I} = OneHotArray{T, 1, 2, I} - OneHotMatrix(indices, L) -A one-hot matrix (with `L` labels) typically constructed using [`onehotbatch`](@ref). -Stored efficiently as a vector of indices with type `I` and eltype `T`. -""" -const OneHotMatrix{T, I} = OneHotArray{T, 1, 2, I} -OneHotMatrix(indices, L) = OneHotArray(indices, L) - -# use this type so reshaped arrays hit fast paths -# e.g. argmax -const OneHotLike{T, N, var"N+1", I} = - Union{OneHotArray{T, N, var"N+1", I}, - Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, <:Any, <:Any, I}}} - -_isonehot(x::OneHotArray) = true -_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) = (size(x, 1) == parent(x).nlabels) - -_check_nlabels(L, xs::OneHotLike...) = all(size.(xs, 1) .== L) - -_nlabels(x::OneHotArray) = size(x, 1) -function _nlabels(x::OneHotLike, xs::OneHotLike...) - L = size(x, 1) - _check_nlabels(L, xs...) || - throw(DimensionMismatch("The number of labels are not the same for all one-hot arrays.")) - - return L +OneHotVector(t::Type, index, nlabels) = OneHotVector{t,nlabels}(index) +OneHotVector(index, nlabels) = OneHotVector{Int32,nlabels}(index) +Base.size(x::OneHotVector{T,L}) where {T,L} = (L,) +function Base.getindex(x::OneHotVector{T,L}, i::Integer) where {T,L} + @boundscheck 1 <= i <= L + i == x.index || throw(BoundsError(x, i)) +end +Base.show(io::IO, x::OneHotVector{T,L}) where {T,L} = Base.show(io, setindex!(zeros(T, L), convert(T, 1), x.index)) +Base.argmax(x::OneHotVector; dims = Colon()) = x.index + +struct OneHotArray{T,N,M,L,A} <: AbstractArray{T,N} + onehotvectors::AbstractArray{OneHotVector{T,L},M} + function OneHotArray(onehotaxis, onehotvectors::AbstractArray{OneHotVector{T,L},M}) where {T,L,M} + N = M+1 + @assert onehotaxis isa Integer "onehot axis must be integer" + @assert 1 <= onehotaxis <= N "onehot axis out of range [1,$N]" + new{T,N,M,L,onehotaxis}(onehotaxis, onehotvectors) + end +end +onehotaxis(x::OneHotArray{T,N,M,L,A}) where {T,N,M,L,A} = A +function size_selector(i, shape, onehotaxis, L) + if i < onehotaxis + shape[i] + elseif i > onehotaxis + shape[i-1] + else + L + end end -Base.size(x::OneHotArray) = (x.nlabels, size(x.indices)...) - -function Base.getindex(x::OneHotArray{<:Any, N}, i::Int, I::Vararg{Int, N}) where N - @boundscheck (1 <= i <= x.nlabels) || throw(BoundsError(x, (i, I...))) - return x.indices[I...] .== i +Base.size(x::OneHotArray{T,N,M,L,A}) where {T,N,M,L,A} = ntuple(i -> size_selector(i, size(x.onehotvectors), onehotaxis(x), L), N) +function Base.getindex(x::OneHotArray{T,N,M,L}, i::Integer) where {T,N,M,L} + flat_onehotvectors = x.onehotvectors[:] + h_ind, L_ind = fldmod(i, L) + @boundscheck 1 <= h_ind <= length(flat_onehotvectors) || throw(BoundsError(x, i)) + @boundscheck 1 <= L_ind <= L || throw(BoundsError(x, i)) + return flat_onehotvectors[h_ind].index == L_ind end -# the method above is faster on the CPU but will scalar index on the GPU -# so we define the method below to pass the extra indices directly to GPU array -function Base.getindex(x::OneHotArray{<:Any, N, <:Any, <:AbstractGPUArray}, - i::Int, - I::Vararg{Any, N}) where N - @boundscheck (1 <= i <= x.nlabels) || throw(BoundsError(x, (i, I...))) - return x.indices[I...] .== i +function Base.getindex(x::OneHotArray{T,N,M,L}, i::Vararg{Integer,N}) where {T,N,M,L} + @boundscheck all(1 .<= i .<= size(x)) || throw(BoundsError(x, (i...))) + index_pre = i[1:onehotaxis(x)-1] + index_post = i[onehotaxis(x)+1:end] + intern_ind = x.onehotvectors[index_pre..., index_post...].index + return convert(T, intern_ind == i[onehotaxis(x)]) end -function Base.getindex(x::OneHotArray{<:Any, N}, ::Colon, I::Vararg{Any, N}) where N - return OneHotArray(x.indices[I...], x.nlabels) + +function Base.show(io::IO, x::OneHotArray{T,N,M,L}) where {T,N,M,L} + z = zeros(T, size(x)) + # loop efficiently over only the ones + for ext_ind in eachindex(IndexCartesian(), x.onehotvectors) + index_pre = Tuple(ext_ind)[1:onehotaxis(x)] + index_post = Tuple(ext_ind)[onehotaxis(x)+1:end] + intern_ind = x.onehotvectors[ext_ind].index + ind = CartesianIndex(CartesianIndex(index_pre..., intern_ind, index_post...)) + setindex!(z, convert(T, 1), ind) + end + Base.show(io, z) end -Base.getindex(x::OneHotArray, ::Colon) = BitVector(reshape(x, :)) -Base.getindex(x::OneHotArray{<:Any, N}, ::Colon, ::Vararg{Colon, N}) where N = x -function Base.showarg(io::IO, x::OneHotArray, toplevel) - print(io, ndims(x) == 1 ? "OneHotVector(" : ndims(x) == 2 ? "OneHotMatrix(" : "OneHotArray(") - Base.showarg(io, x.indices, false) - print(io, ')') - toplevel && print(io, " with eltype Bool") +function Base.showarg(io::IO, x::OneHotArray{T,N,M,L}, toplevel) where {T,N,M,L} + print(io, "$(size(x)) OneHotArray") + toplevel && print(io, " with one hot axis $A and eltype $T") return nothing end # this is from /LinearAlgebra/src/diagonal.jl, official way to print the dots: function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::AbstractString) - x[i,j] ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s + x[i,j] ? s : Base.replace_with_centered_mark(s) end # copy CuArray versions back before trying to print them: for fun in (:show, :print_array) # print_array is used by 3-arg show @eval begin - Base.$fun(io::IO, X::OneHotLike{T, N, var"N+1", <:AbstractGPUArray}) where {T, N, var"N+1"} = + Base.$fun(io::IO, X::OneHotArray{T,N,M,L, <:AbstractGPUArray}) where {T, N, M,L} = Base.$fun(io, adapt(Array, X)) - Base.$fun(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, N, <:Any, <:AbstractGPUArray}}) where {T, N} = + Base.$fun(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotArray{T,N,M,L,<:AbstractGPUArray}}) where {T, N} = Base.$fun(io, adapt(Array, X)) end end -_onehot_bool_type(::OneHotLike{<:Any, <:Any, var"N+1", <:Union{Integer, AbstractArray}}) where {var"N+1"} = Array{Bool, var"N+1"} -_onehot_bool_type(::OneHotLike{<:Any, <:Any, var"N+1", <:AbstractGPUArray}) where {var"N+1"} = AbstractGPUArray{Bool, var"N+1"} - -_notall_onehot(x::OneHotArray, xs::OneHotArray...) = false -_notall_onehot(x::OneHotLike, xs::OneHotLike...) = any(x -> !_isonehot(x), (x, xs...)) - -function Base.cat(x::OneHotLike{<:Any, <:Any, N}, xs::OneHotLike...; dims::Int) where N - if isone(dims) || _notall_onehot(x, xs...) - return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims) - else - L = _nlabels(x, xs...) - - return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = dims - 1), L) - end -end +# Adapt.adapt_structure(T, x::OneHotArray) = OneHotArray(adapt(T, _indices(x)), x.nlabels) # TODO: edit this -Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 2) -Base.vcat(x::OneHotLike, xs::OneHotLike...) = - vcat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...) +# function Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, var"N+1", T}}) where {var"N+1", T <: AbstractGPUArray} # TODO: edit this +# # We want CuArrayStyle{N+1}(). There's an AbstractGPUArrayStyle but it doesn't do what we need. +# S = Base.BroadcastStyle(T) +# # S has dim N not N+1. The following hack to fix it relies on the arraystyle having N as its first type parameter, which +# # isn't guaranteed, but there are not so many GPU broadcasting styles in the wild. (Far fewer than there are array wrappers.) +# (typeof(S).name.wrapper){var"N+1"}() +# end -# optimized concatenation for matrices and vectors of same parameters -Base.hcat(x::OneHotMatrix, xs::OneHotMatrix...) = - OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), _nlabels(x, xs...)) -Base.hcat(x::OneHotVector, xs::OneHotVector...) = - OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), _nlabels(x, xs...)) +# Base.map(f, x::OneHotLike) = Base.broadcast(f, x) -if isdefined(Base, :stack) - import Base: _stack -else - import Compat: _stack -end -function _stack(::Colon, xs::AbstractArray{<:OneHotArray}) - n = _nlabels(first(xs)) - all(x -> _nlabels(x)==n, xs) || throw(DimensionMismatch("The number of labels are not the same for all one-hot arrays.")) - OneHotArray(Compat.stack(_indices, xs), n) -end - -Adapt.adapt_structure(T, x::OneHotArray) = OneHotArray(adapt(T, _indices(x)), x.nlabels) - -function Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, var"N+1", T}}) where {var"N+1", T <: AbstractGPUArray} - # We want CuArrayStyle{N+1}(). There's an AbstractGPUArrayStyle but it doesn't do what we need. - S = Base.BroadcastStyle(T) - # S has dim N not N+1. The following hack to fix it relies on the arraystyle having N as its first type parameter, which - # isn't guaranteed, but there are not so many GPU broadcasting styles in the wild. (Far fewer than there are array wrappers.) - (typeof(S).name.wrapper){var"N+1"}() -end +Base.argmax(x::OneHotArray; dims = Colon()) = + dims == onehotaxis(x) ? + argmax.(x.onehotvectors) : + invoke(argmax, Tuple{AbstractArray}, x; dims = dims) -Base.map(f, x::OneHotLike) = Base.broadcast(f, x) +""" + OneHotMatrix{T, L, A} = OneHotArray{T,2,1,L,1} + OneHotMatrix(indices, L) -Base.argmax(x::OneHotLike; dims = Colon()) = - (_isonehot(x) && dims == 1) ? - reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) : - invoke(argmax, Tuple{AbstractArray}, x; dims = dims) +A one-hot matrix (with `L` labels) typically constructed using [`onehotbatch`](@ref). +Stored efficiently as a vector of indices with type `I` and eltype `T`. +""" +const OneHotMatrix{T, I} = OneHotArray{T, 1, 2, I} +OneHotMatrix(indices, L) = OneHotArray{T,2,1,L,1}([OneHotVector{T,L}(index) for index in indices]) \ No newline at end of file From 3a9c41d9d23df94379229cce4a1cb2419b89f390 Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Sat, 22 Jul 2023 17:21:46 +0300 Subject: [PATCH 2/6] adapt onehot --- src/onehot.jl | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index d2d5e9d..01a6d48 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -31,13 +31,13 @@ julia> hcat(αβγ...) # preserves sparsity function onehot(x, labels) i = _findval(x, labels) isnothing(i) && error("Value $x is not in labels") - OneHotVector{UInt32}(i, length(labels)) + OneHotVector(i, length(labels)) end function onehot(x, labels, default) i = _findval(x, labels) isnothing(i) && return onehot(default, labels) - OneHotVector{UInt32}(i, length(labels)) + OneHotVector(i, length(labels)) end _findval(val, labels) = findfirst(isequal(val), labels) @@ -90,14 +90,16 @@ function _onehotbatch(data, labels) isnothing(_findval(x, labels)) && error("Value $x not found in labels") end end - return OneHotArray(indices, length(labels)) + L = length(labels) + return OneHotArray(1, (index -> OneHotVector(Int32, index, L)).(indices)) end function _onehotbatch(data, labels, default) default_index = _findval(default, labels) isnothing(default_index) && error("Default value $default is not in labels") indices = UInt32[something(_findval(i, labels), default_index) for i in data] - return OneHotArray(indices, length(labels)) + L = length(labels) + return OneHotArray(1, (index -> OneHotVector(index, L)).(indices)) end function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer}) @@ -106,7 +108,8 @@ function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{< hi > last(labels) && error("Value $hi not found in labels") offset = 1 - first(labels) indices = UInt32.(data .+ offset) - return OneHotArray(indices, length(labels)) + L = length(labels) + return OneHotArray(1, (index -> OneHotVector(index, L)).(indices)) end # That bounds check with extrema synchronises on GPU, much slower than rest of the function, # hence add a special method, with a less helpful error message: @@ -117,7 +120,8 @@ function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRang checkbounds(labels, i) i end - return OneHotArray(indices, length(labels)) + L = length(labels) + return OneHotArray(1, (index -> OneHotVector(index, L)).(indices)) end """ @@ -165,14 +169,7 @@ function onecold(y::AbstractArray, labels = 1:size(y, 1)) end _fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1) -_fast_argmax(x::OneHotArray) = _indices(x) -function _fast_argmax(x::OneHotLike) - if _isonehot(x) - return _indices(x) - else - return _fast_argmax(convert(_onehot_bool_type(x), x)) - end -end +_fast_argmax(x::Union{OneHotVector, OneHotArray}) = argmax(x) ChainRulesCore.@non_differentiable onehot(::Any...) ChainRulesCore.@non_differentiable onehotbatch(::Any...) From c58117dfb2d550ca57b5eaa026bdc3fb60c99dd9 Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Sat, 22 Jul 2023 17:22:01 +0300 Subject: [PATCH 3/6] adapt linalg --- src/linalg.jl | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 07896e2..0ab0598 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -1,13 +1,13 @@ -function Base.:(*)(A::AbstractMatrix, B::OneHotLike) - _isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B) +function Base.:(*)(A::AbstractMatrix, B::OneHotArray) + onehotaxis(B) == 1 || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B) size(A, 2) == size(B, 1) || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(size(B, 1))")) - return A[:, onecold(B)] + return A[:, argmax(B;dims=1)] end -function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, 1}) - _isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B) +function Base.:(*)(A::AbstractMatrix, B::OneHotArray{T,2,1,L,Ax}) where {T,L,Ax} + onehotaxis(B) == 1 || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B) size(A, 2) == size(B, 1) || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(size(B, 1))")) - return NNlib.gather(A, _indices(B)) + return NNlib.gather(A, [v.index for v in B.onehotvectors]) end function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix}) @@ -18,18 +18,16 @@ end for wrapper in [:Adjoint, :Transpose] @eval begin - function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector) where T - size(A, 2) == length(b) || - throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(length(b))")) - - return A[:, onecold(b)] + function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector{T,L}) where {T,L} + size(A, 2) == L || + throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(L)")) + return A[:, argmax(b)] end - function Base.:*(A::$wrapper{<:Number, <:AbstractVector{T}}, b::OneHotVector) where T - size(A, 2) == length(b) || - throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(length(b))")) - - return A[onecold(b)] + function Base.:*(A::$wrapper{<:Number, <:AbstractVector{T}}, b::OneHotVector{T,L}) where {T,L} + size(A, 2) == L || + throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(L)")) + return A[argmax(b)] end end end From c6fd9f37c47c55989e32235e8170b18a927bf459 Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Sat, 22 Jul 2023 17:22:09 +0300 Subject: [PATCH 4/6] fix dtypes --- src/array.jl | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/array.jl b/src/array.jl index 2e80ad9..d259509 100644 --- a/src/array.jl +++ b/src/array.jl @@ -3,7 +3,6 @@ OneHotVector(index, L) A one-hot vector with `L` labels (i.e. `length(A) == L` and `count(A) == 1`). -Stored efficiently as a single index of type `T`, usually `Int32`. """ struct OneHotVector{T,L} # <: AbstractVector{Bool} # this does everything but is too limiting index::Integer @@ -14,7 +13,7 @@ struct OneHotVector{T,L} # <: AbstractVector{Bool} # this does everything but is end OneHotVector(t::Type, index, nlabels) = OneHotVector{t,nlabels}(index) -OneHotVector(index, nlabels) = OneHotVector{Int32,nlabels}(index) +OneHotVector(index, nlabels) = OneHotVector{Float32,nlabels}(index) Base.size(x::OneHotVector{T,L}) where {T,L} = (L,) function Base.getindex(x::OneHotVector{T,L}, i::Integer) where {T,L} @boundscheck 1 <= i <= L @@ -29,7 +28,7 @@ struct OneHotArray{T,N,M,L,A} <: AbstractArray{T,N} N = M+1 @assert onehotaxis isa Integer "onehot axis must be integer" @assert 1 <= onehotaxis <= N "onehot axis out of range [1,$N]" - new{T,N,M,L,onehotaxis}(onehotaxis, onehotvectors) + new{T,N,M,L,onehotaxis}(onehotvectors) end end onehotaxis(x::OneHotArray{T,N,M,L,A}) where {T,N,M,L,A} = A @@ -60,35 +59,35 @@ function Base.getindex(x::OneHotArray{T,N,M,L}, i::Vararg{Integer,N}) where {T,N end function Base.show(io::IO, x::OneHotArray{T,N,M,L}) where {T,N,M,L} - z = zeros(T, size(x)) + z = zeros(Int32, size(x)) # loop efficiently over only the ones for ext_ind in eachindex(IndexCartesian(), x.onehotvectors) index_pre = Tuple(ext_ind)[1:onehotaxis(x)] index_post = Tuple(ext_ind)[onehotaxis(x)+1:end] intern_ind = x.onehotvectors[ext_ind].index ind = CartesianIndex(CartesianIndex(index_pre..., intern_ind, index_post...)) - setindex!(z, convert(T, 1), ind) + setindex!(z, convert(Int32, 1), ind) end Base.show(io, z) end -function Base.showarg(io::IO, x::OneHotArray{T,N,M,L}, toplevel) where {T,N,M,L} +function Base.showarg(io::IO, x::OneHotArray{T,N,M,L,A}, toplevel) where {T,N,M,L,A} print(io, "$(size(x)) OneHotArray") toplevel && print(io, " with one hot axis $A and eltype $T") return nothing end # this is from /LinearAlgebra/src/diagonal.jl, official way to print the dots: -function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::AbstractString) - x[i,j] ? s : Base.replace_with_centered_mark(s) +function Base.replace_in_print_matrix(x::OneHotArray, i::Integer, j::Integer, s::AbstractString) + x[i,j] > 0 ? s : Base.replace_with_centered_mark(s) end # copy CuArray versions back before trying to print them: for fun in (:show, :print_array) # print_array is used by 3-arg show @eval begin - Base.$fun(io::IO, X::OneHotArray{T,N,M,L, <:AbstractGPUArray}) where {T, N, M,L} = + Base.$fun(io::IO, X::OneHotArray{T,N,M,L, <:AbstractGPUArray}) where {T,N,M,L} = Base.$fun(io, adapt(Array, X)) - Base.$fun(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotArray{T,N,M,L,<:AbstractGPUArray}}) where {T, N} = + Base.$fun(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotArray{T,N,M,L,<:AbstractGPUArray}}) where {T,N,M,L} = Base.$fun(io, adapt(Array, X)) end end @@ -115,7 +114,6 @@ Base.argmax(x::OneHotArray; dims = Colon()) = OneHotMatrix(indices, L) A one-hot matrix (with `L` labels) typically constructed using [`onehotbatch`](@ref). -Stored efficiently as a vector of indices with type `I` and eltype `T`. """ -const OneHotMatrix{T, I} = OneHotArray{T, 1, 2, I} -OneHotMatrix(indices, L) = OneHotArray{T,2,1,L,1}([OneHotVector{T,L}(index) for index in indices]) \ No newline at end of file +const OneHotMatrix{T, L} = OneHotArray{T, 2,1,L,1} +OneHotMatrix(indices, L) = OneHotArray(1, [OneHotVector(index,L) for index in indices]) \ No newline at end of file From 5051a05eb083393a1644ffea043abbc7364cb946 Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Sat, 22 Jul 2023 18:41:58 +0300 Subject: [PATCH 5/6] temporary linalg fix --- src/linalg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/linalg.jl b/src/linalg.jl index 0ab0598..30c3dec 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -10,7 +10,7 @@ function Base.:(*)(A::AbstractMatrix, B::OneHotArray{T,2,1,L,Ax}) where {T,L,Ax} return NNlib.gather(A, [v.index for v in B.onehotvectors]) end -function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix}) +function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotArray}) B_dim = length(_indices(parent(B))) size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim")) return NNlib.scatter(+, A, _indices(parent(B)), dstsize=(size(A,1), size(B,2))) From d2dbc55d3744f644923d8bce1b20f23171ada36d Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Sat, 22 Jul 2023 18:42:58 +0300 Subject: [PATCH 6/6] add type stable concat --- src/array.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/array.jl b/src/array.jl index d259509..988cf71 100644 --- a/src/array.jl +++ b/src/array.jl @@ -104,6 +104,15 @@ end # Base.map(f, x::OneHotLike) = Base.broadcast(f, x) +function Base.cat(xs::OneHotArray{T,N,M,L,A}...; dims) where {T,N,M,L,A} + if dims != A + onehotvectors_dim = dims < A ? dims : dims - 1 + OneHotArray(A, cat((x.onehotvectors for x in xs)...; dims=onehotvectors_dim)) + else + invoke(cat, AbstractArray, xs...; dims=dims) + end +end + Base.argmax(x::OneHotArray; dims = Colon()) = dims == onehotaxis(x) ? argmax.(x.onehotvectors) :