Skip to content

Commit

Permalink
buffer broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
lxvm committed Jan 2, 2024
1 parent 2bf7fe7 commit e5e5354
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
32 changes: 14 additions & 18 deletions src/lib/buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,6 @@ end
end
end

@adjoint! function copyto!(b::Buffer, xs::AbstractArray)
copyto!(b, xs), function (_)
grad = grad_mut(__context__, b)
x̄s = copy(grad)
grad .= eltype(grad) <: Number ? 0 : nothing
return (nothing, x̄s)
end
end

@adjoint! function copyto!(b::Buffer, x::Number)
copyto!(b, x), function (_)
grad = grad_mut(__context__, b)
return (nothing, sum(grad))
end
end

@adjoint! function push!(b::Buffer, x)
push!(b, x), function (y)
Expand All @@ -51,9 +36,6 @@ end
end
end

function _pullback(cx::AContext, ::typeof(Broadcast.materialize!), b::Buffer, x)
_pullback(cx, copyto!, b, x)
end

@adjoint function copy(b::Buffer)
res = copy(b)
Expand All @@ -70,3 +52,17 @@ end

return res, copy_sensitivity
end

Base.BroadcastStyle(::Type{Buffer{T,A}}) where {T,A} = Base.BroadcastStyle(A)

@non_differentiable Base.Broadcast.Broadcasted(::Nothing)

function _pullback(cx::AContext, ::typeof(copyto!), b::Buffer, bc::Base.Broadcast.Broadcasted)
xs, map_pullback = ∇map(cx, i -> bc[i], eachindex(bc))
copyto!(b, xs), function (_)
grad = grad_mut(cx, b)
# ys = copy(grad)
d, = map_pullback(reshape(first(grad, length(xs)), size(xs)))
return (nothing, nothing, d.bc)
end
end
6 changes: 6 additions & 0 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1540,6 +1540,12 @@ using Zygote: Buffer
return sum(copy(b))
end == (nothing,)

@test gradient(1.1) do p
b = Zygote.Buffer(zeros(3))
b .= (p*i for i in eachindex(b))
return sum(copy(b) .* (2:4))
end[1] 1*2 + 2*3 + 3*4

@test gradient(2) do x
b = Zygote.Buffer([])
push!(b, x)
Expand Down

0 comments on commit e5e5354

Please sign in to comment.