Skip to content

Commit

Permalink
restore zeroing out grad cache
Browse files Browse the repository at this point in the history
  • Loading branch information
lxvm committed Feb 12, 2024
1 parent d99c2ba commit c1384ae
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/lib/buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ end
@adjoint! function copyto!(b::Buffer, src::AbstractArray)
function copyto!_buffer_array_pullback(_)
grad = grad_mut(__context__, b)
return (nothing, copy(grad))
xs = copy(grad)
grad .= eltype(grad) <: Number ? zero(eltype(grad)) : nothing
return (nothing, xs)
end
copyto!(b, src), copyto!_buffer_array_pullback
end
Expand All @@ -43,6 +45,7 @@ end
function copyto!_buffer_broadcast_pullback(_)
grad = grad_mut(__context__, b)
d, = map_pullback(reshape(first(grad, length(xs)), size(xs)))
grad .= eltype(grad) <: Number ? zero(eltype(grad)) : nothing
return (nothing, d.bc)
end
copyto!(b, xs), copyto!_buffer_broadcast_pullback
Expand All @@ -53,6 +56,7 @@ function _pullback(cx::AContext, ::typeof(copyto!), b::Buffer, g::Base.Generator
function copyto!_buffer_generator_pullback(_)
grad = grad_mut(cx, b)
_, dg = collect_pullback(reshape(first(grad, length(xs)), size(xs)))
grad .= eltype(grad) <: Number ? zero(eltype(grad)) : nothing
return (nothing, nothing, dg)
end
copyto!(b, xs), copyto!_buffer_generator_pullback
Expand Down

0 comments on commit c1384ae

Please sign in to comment.