Skip to content

Commit ef13026

Browse files
committed
updated tests and outputsize gather
1 parent a2f0961 commit ef13026

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

src/outputsize.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -169,4 +169,6 @@ for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims))
169169
end
170170
end
171171

172-
NNlib.gather!(dst::AbstractArray, ::AbstractArray, ::AbstractArray{<:Nil}) = fill(nil, size(dst)...)
172+
function NNlib.gather(src::AbstractArray{Tsrc, Nsrc}, idx::AbstractArray{<:Nil}) where {Tsrc, Nsrc}
173+
fill(nil, (size(src)[1:Nsrc-1]..., size(idx)...))
174+
end

test/cuda/layers.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,12 @@ gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)
128128

129129
embedding = [Flux.Embedding]
130130
gpu_gradtest("Embedding", embedding, [1,3,5], 5 => 2)
131-
gpu_gradtest("Embedding repeated indices", embedding, rand(1:50, 10^6), 50 => 2)
131+
gpu_gradtest("Embedding repeated indices", embedding, rand(1:10, 10^3), 10 => 2)
132132
gpu_gradtest("Embedding integer index", embedding, 1, 5 => 2)
133133
gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5 => 2)
134134
gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5 => 2)
135135
gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5 => 2)
136-
gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix(rand(1:50, 10^6), 50), 50 => 2)
136+
gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix(rand(1:10, 10^3), 10), 10 => 2)
137137

138138
@testset "function layers" begin
139139
x = rand(Float32, 3,3)

test/outputsize.jl

+7
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,10 @@ end
155155
@test outputsize(m, (32, 32, 16, 16)) == (32, 32, 16, 16)
156156
@test outputsize(m, (32, 32, 16); padbatch=true) == (32, 32, 16, 1)
157157
end
158+
159+
@testset "embedding" begin
160+
m = Embedding(3=>5)
161+
@test outputsize(m, (2,)) == (5, 2)
162+
@test outputsize(m, (2, 3)) == (5, 2, 3)
163+
@test outputsize(m, (2, 3, 4)) == (5, 2, 3, 4)
164+
end

0 commit comments

Comments
 (0)