diff --git a/src/lib/buffer.jl b/src/lib/buffer.jl index 7d5bf460f..0666172b4 100644 --- a/src/lib/buffer.jl +++ b/src/lib/buffer.jl @@ -33,7 +33,9 @@ end @adjoint! function copyto!(b::Buffer, src::AbstractArray) function copyto!_buffer_array_pullback(_) grad = grad_mut(__context__, b) - return (nothing, copy(grad)) + xs = copy(grad) + grad .= eltype(grad) <: Number ? zero(eltype(grad)) : nothing + return (nothing, xs) end copyto!(b, src), copyto!_buffer_array_pullback end @@ -43,6 +45,7 @@ end function copyto!_buffer_broadcast_pullback(_) grad = grad_mut(__context__, b) d, = map_pullback(reshape(first(grad, length(xs)), size(xs))) + grad .= eltype(grad) <: Number ? zero(eltype(grad)) : nothing return (nothing, d.bc) end copyto!(b, xs), copyto!_buffer_broadcast_pullback @@ -53,6 +56,7 @@ function _pullback(cx::AContext, ::typeof(copyto!), b::Buffer, g::Base.Generator function copyto!_buffer_generator_pullback(_) grad = grad_mut(cx, b) _, dg = collect_pullback(reshape(first(grad, length(xs)), size(xs))) + grad .= eltype(grad) <: Number ? zero(eltype(grad)) : nothing return (nothing, nothing, dg) end copyto!(b, xs), copyto!_buffer_generator_pullback