Unthunk tangents (if any) before returning gradient (#1551)
pxl-th authored Jan 21, 2025
1 parent a38a4a5 · commit 1b111d8
Showing 5 changed files with 39 additions and 9 deletions.
Project.toml: 2 changes (1 addition, 1 deletion)
@@ -57,7 +57,7 @@ Requires = "1.1"
 SpecialFunctions = "1.6, 2"
 Statistics = "1"
 Tracker = "0.2"
-ZygoteRules = "0.2.5"
+ZygoteRules = "0.2.7"
 julia = "1.6"

 [extras]
src/compiler/chainrules.jl: 11 changes (5 additions, 6 deletions)
@@ -1,11 +1,10 @@
 # ToDo: Move some of this to ZygoteRules, or move unthunk_tangent for Tuple and NamedTuple from
 # Zygote rules here?
-function unthunk_tangent end
-@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
-@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x
-@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x
-@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x)
-unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d])
+@inline ZygoteRules.unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
+@inline ZygoteRules.unthunk_tangent(x::NTuple{N,<:Number}) where N = x
+@inline ZygoteRules.unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x
+@inline ZygoteRules.unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x)
+ZygoteRules.unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d])
 @non_differentiable unthunk_tangent(::IdDict)
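
For orientation (an illustration, not part of the diff): the methods above recursively force lazy ChainRulesCore thunks inside a tangent while leaving plain numeric containers alone. A rough sketch, assuming `unthunk_tangent` is reachable as `Zygote.unthunk_tangent` (the unqualified calls in this file suggest it is imported from ZygoteRules, whose 0.2.7 release is expected to own the function and its Tuple/NamedTuple methods):

using Zygote, ChainRulesCore

t = ChainRulesCore.@thunk(2.0 .* ones(3))   # a lazy tangent
Zygote.unthunk_tangent(t)                   # forces it; expected [2.0, 2.0, 2.0]
Zygote.unthunk_tangent((1.0, 2.0))          # numeric tuples pass through unchanged
Zygote.unthunk_tangent([t, t])              # non-numeric arrays are mapped element-wise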


src/compiler/interface.jl: 4 changes (2 additions, 2 deletions)
@@ -152,7 +152,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d
 function gradient(f, args...)
   y, back = pullback(f, args...)
   grad = back(sensitivity(y))
-  return _project_all(args, grad)
+  return _project_all(args, unthunk_tangent(grad))
 end

 # Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
@@ -218,7 +218,7 @@ function withgradient(f, args...)
   else
     back(sensitivity(y))
   end
-  results = _project_all(args, grad)
+  results = _project_all(args, unthunk_tangent(grad))
   (val=y, grad=results)
 end
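
For context (an illustration, not part of the diff): `back(sensitivity(y))` can return tangents that still contain lazy ChainRulesCore thunks, for example when a custom rrule wraps part of its pullback in `@thunk`. Passing the result through `unthunk_tangent` before `_project_all` forces those thunks, so callers of `gradient` and `withgradient` see plain numbers, arrays, and NamedTuples. A minimal sketch of the intended effect, using a hypothetical `scale` function written only to produce thunked tangents:

using Zygote, ChainRulesCore

scale(a, x) = a .* x   # hypothetical helper, not part of Zygote
function ChainRulesCore.rrule(::typeof(scale), a, x)
    scale(a, x), Δ -> (NoTangent(), @thunk(sum(Δ .* x)), @thunk(a .* Δ))
end

res = Zygote.withgradient((a, x) -> sum(scale(a, x)), 2.0, ones(3))
res.grad   # expected (3.0, [2.0, 2.0, 2.0]), fully materialized, with no AbstractThunk values

The struct-field case, where thunks could previously leak into the returned gradient, is exercised by the new "No thunks in the gradient" testset below.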

src/lib/lib.jl: 3 changes (3 additions, 0 deletions)
@@ -40,6 +40,9 @@ end
 accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y))
 accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y)

+accum(x::Nothing, y::AbstractThunk) = y
+accum(x::AbstractThunk, y::Nothing) = x
+
 accum(x, y::AbstractThunk) = @thunk(accum(x, unthunk(y)))
 accum(x::AbstractThunk, y) = @thunk(accum(unthunk(x), y))
 accum(x::AbstractThunk, y::AbstractThunk) = @thunk(accum(unthunk(x), unthunk(y)))
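
For context (an illustration, not part of the diff): `nothing` is Zygote's strong zero, so accumulating it against a lazy tangent should neither force the thunk nor wrap it in yet another one; the two new methods short-circuit instead of falling through to the generic thunk-wrapping methods below them. A rough check of the intended behaviour, assuming the internals are reachable as written (`accum` is not public API):

using Zygote, ChainRulesCore

t = ChainRulesCore.@thunk(error("must stay lazy"))
Zygote.accum(nothing, t) === t   # expected: the thunk comes back untouched
Zygote.accum(t, nothing) === t   # and is never forced or re-wrapped
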
test/chainrules.jl: 28 changes (28 additions, 0 deletions)
@@ -428,3 +428,31 @@ end
   @test Zygote.wrap_chainrules_input([[2.0; 4.0], [1.0; 3.0]]) == [[2.0; 4.0], [1.0; 3.0]]
   @test Zygote.wrap_chainrules_input([nothing; 4.0]) == [0.0; 4.0] # ChainRules uses the numeric zero where possible
 end
+
+@testset "Lazy" begin
+  custom_add(x, y) = x + y
+  function ChainRulesCore.rrule(::typeof(custom_add), x, y)
+    function pullback(Δ)
+      return NoTangent(), unthunk(Δ), @thunk(error("Should not compute."))
+    end
+    custom_add(x, y), pullback
+  end
+
+  x, y = 1f0, 1f0
+  Zygote.gradient(x) do x
+    sum(custom_add(x, y))
+  end
+end
+
+@testset "No thunks in the gradient" begin
+  struct CustomDense
+    w::Matrix{Float32}
+  end
+  (d::CustomDense)(x) = d.w * x
+
+  layers = [CustomDense(rand(Float32, 3, 3))]
+  x = ones(Float32, 3)
+  g = gradient(layers -> sum(layers[1](x)), layers)[1]
+  @test g[1] isa NamedTuple
+  @test g[1].w isa Array
+end
