Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 2398dfc

Browse files
committedFeb 23, 2021
fix bug
1 parent 49e42a7 commit 2398dfc

File tree

4 files changed

+10
-20
lines changed

4 files changed

+10
-20
lines changed
 

‎src/layers/basic.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,13 @@ end
120120

121121
@functor Dense
122122

123-
function (a::Dense)(x::Union{AbstractVector, AbstractMatrix})
123+
function (a::Dense)(x::AbstractVecOrMat)
124124
W, b, σ = a.W, a.b, a.σ
125125
return σ.(W*x .+ b)
126126
end
127127

128-
(a::Dense)(x::AbstractArray) = reshape(a(mat(x)), :, size(x)[2:end]...)
128+
(a::Dense)(x::AbstractArray) =
129+
reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
129130

130131
function Base.show(io::IO, l::Dense)
131132
print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
@@ -418,7 +419,7 @@ end
418419

419420
(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
420421
(m::Embedding)(x::Union{Int,AbstractVector}) = m.weight[:, x]
421-
(m::Embedding)(x::AbstractArray) = reshape(m(mat(x)), :, size(x)[2:end]...)
422+
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
422423

423424
function Base.show(io::IO, m::Embedding)
424425
print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")

‎src/layers/stateless.jl

+1-13
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Reshape arbitrarly-shaped input into a matrix-shaped output,
55
preserving the size of the last dimension.
66
7-
See also [`unsqueeze`](@ref) and [`mat`](@ref).
7+
See also [`unsqueeze`](@ref).
88
99
# Examples
1010
```jldoctest
@@ -26,18 +26,6 @@ function flatten(x::AbstractArray)
2626
return reshape(x, :, size(x)[end])
2727
end
2828

29-
"""
30-
mat(x::AbstractArray)
31-
32-
Reshape arbitrarly-shaped input into a matrix-shaped output,
33-
preserving the size of the first dimension.
34-
35-
See also [`flatten`](@ref) and [`unsqueeze`](@ref).
36-
"""
37-
function mat(x::AbstractArray)
38-
return reshape(x, size(x,1), :)
39-
end
40-
4129
"""
4230
normalise(x; dims=ndims(x), ϵ=1e-5)
4331

‎test/cuda/layers.jl

-3
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,6 @@ pixelshuffle = [PixelShuffle]
8787
gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
8888
gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)
8989

90-
embedding = [Embedding]
91-
gpu_gradtest("Embedding", embedding, rand(1:10, 3), 10, 4)
92-
9390
@testset "function layers" begin
9491
x = rand(Float32, 3,3)
9592
gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=1)), x)

‎test/layers/basic.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,17 @@ import Flux: activations
163163
y = m(x)
164164
@test y isa Matrix{Float32}
165165
@test y m.weight[:,x]
166-
167166
x2 = OneHotMatrix(x, vocab_size)
168167
y2 = m(x2)
169168
@test y2 isa Matrix{Float32}
170169
@test y2 y
171170
@test_throws DimensionMismatch m(OneHotMatrix(x, 1000))
172171

172+
x = rand(1:vocab_size, 3, 4)
173+
y = m(x)
174+
@test y isa Array{Float32, 3}
175+
@test size(y) == (embed_size, 3, 4)
176+
173177
@test m(2) m.weight[:,2]
174178
@test m(OneHotVector(3, vocab_size)) m.weight[:,3]
175179
@test_throws DimensionMismatch m(OneHotVector(3, 1000))

0 commit comments

Comments
 (0)
Please sign in to comment.