
Commit 2670684

fix bug
1 parent 0a75474 commit 2670684

4 files changed, +7 −18 lines


src/layers/basic.jl (+1 −1)

@@ -470,7 +470,7 @@ end
 
 (m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
 (m::Embedding)(x::Union{Int,AbstractVector}) = m.weight[:, x]
-(m::Embedding)(x::AbstractArray) = reshape(m(mat(x)), :, size(x)[2:end]...)
+(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
 
 function Base.show(io::IO, m::Embedding)
   print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
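The changed line is the `Embedding` forward pass for index arrays of rank 2 or higher. The old version fed `mat(x)` back into the layer; since `mat(x)` is still an `AbstractArray` rather than an `AbstractVector`, that call appears to dispatch straight back to the same method, and the trailing reshape also dropped the first input dimension. The fix flattens the indices with `vec` (hitting the `AbstractVector` method) and prepends the embedding dimension to the full input shape. A standalone sketch of the fixed behavior, using a stripped-down stand-in for the layer (the struct and the sizes are illustrative, not from the diff):

```julia
# Minimal stand-in for Flux's Embedding; the real layer has more methods.
struct Embedding{W}
  weight::W                           # embed_size × vocab_size matrix
end

(m::Embedding)(x::Union{Int,AbstractVector}) = m.weight[:, x]
# The fixed method: flatten the index array, look everything up in one
# column-indexing pass, then restore the input shape behind the embedding dim.
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)

m = Embedding(rand(Float32, 4, 10))   # embed_size = 4, vocab_size = 10
x = rand(1:10, 3, 5)                  # rank-2 batch of indices
y = m(x)
@assert size(y) == (4, 3, 5)          # (embed_size, size(x)...)
```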

src/layers/stateless.jl (+1 −13)

@@ -4,7 +4,7 @@
 Reshape arbitrarly-shaped input into a matrix-shaped output,
 preserving the size of the last dimension.
 
-See also [`unsqueeze`](@ref) and [`mat`](@ref).
+See also [`unsqueeze`](@ref).
 
 # Examples
 ```jldoctest
@@ -26,18 +26,6 @@ function flatten(x::AbstractArray)
   return reshape(x, :, size(x)[end])
 end
 
-"""
-    mat(x::AbstractArray)
-
-Reshape arbitrarly-shaped input into a matrix-shaped output,
-preserving the size of the first dimension.
-
-See also [`flatten`](@ref) and [`unsqueeze`](@ref).
-"""
-function mat(x::AbstractArray)
-  return reshape(x, size(x,1), :)
-end
-
 """
     normalise(x; dims=ndims(x), ϵ=1e-5)
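For reference, `flatten` (which stays) and the removed `mat` differ only in which dimension they preserve. A quick sketch of the two reshapes on a rank-3 input, with both one-liners inlined from the diff above:

```julia
x = rand(Float32, 2, 3, 4)

# flatten keeps the last (batch) dimension: 2*3 rows, 4 columns.
flatten(x::AbstractArray) = reshape(x, :, size(x)[end])
@assert size(flatten(x)) == (6, 4)

# mat (deleted in this commit) kept the first dimension instead.
mat(x::AbstractArray) = reshape(x, size(x, 1), :)
@assert size(mat(x)) == (2, 12)
```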

test/cuda/layers.jl (−3)

@@ -114,9 +114,6 @@ pixelshuffle = [PixelShuffle]
 gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
 gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)
 
-embedding = [Embedding]
-gpu_gradtest("Embedding", embedding, rand(1:10, 3), 10, 4)
-
 @testset "function layers" begin
   x = rand(Float32, 3,3)
   gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=1)), x)

test/layers/basic.jl (+5 −1)

@@ -201,13 +201,17 @@ import Flux: activations
   y = m(x)
   @test y isa Matrix{Float32}
   @test y ≈ m.weight[:,x]
-
   x2 = OneHotMatrix(x, vocab_size)
   y2 = m(x2)
   @test y2 isa Matrix{Float32}
   @test y2 ≈ y
   @test_throws DimensionMismatch m(OneHotMatrix(x, 1000))
 
+  x = rand(1:vocab_size, 3, 4)
+  y = m(x)
+  @test y isa Array{Float32, 3}
+  @test size(y) == (embed_size, 3, 4)
+
   @test m(2) ≈ m.weight[:,2]
   @test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
   @test_throws DimensionMismatch m(OneHotVector(3, 1000))
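The new assertions exercise the rank-2 index path added in src/layers/basic.jl; `m`, `vocab_size`, and `embed_size` come from earlier in the testset, outside this hunk. A plausible setup for running the hunk on its own (the concrete values and the constructor form are assumptions, not taken from the file):

```julia
using Flux, Test
using Flux: Embedding, OneHotVector, OneHotMatrix

vocab_size, embed_size = 10, 4          # assumed sizes, not from the hunk
m = Embedding(vocab_size, embed_size)   # assumes the two-argument constructor
x = rand(1:vocab_size, 3)               # index vector the hunk's first lines test
```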
