Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit de43bf5

Browse files
committedFeb 13, 2022
add outputsize special case for NNlib.gather
1 parent 383bc02 commit de43bf5

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed
 

‎src/layers/basic.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ Embedding(dims::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(las
515515

516516

517517
(m::Embedding)(x::Integer) = m.weight[:, x]
518-
(m::Embedding)(x::AbstractVector{<:Integer}) = NNlib.gather(m.weight, x)
518+
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
519519
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
520520
(m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)
521521
(m::Embedding)(x::AbstractVecOrMat{Bool}) = m.weight * x # handles OneHotLikeVector, OneHotLikeMatrix

‎src/outputsize.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -169,4 +169,4 @@ for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims))
169169
end
170170
end
171171

172-
(m::Embedding)(x::AbstractVecOrMat{<:Nil}) = fill(nil, size(m.weight, 1), length(x))
172+
NNlib.gather!(dst::AbstractArray, ::AbstractArray, ::AbstractArray{<:Nil}) = fill(nil, size(dst)...)

0 commit comments

Comments
 (0)
Please sign in to comment.