Skip to content

Conversation

oxinabox
Copy link
Member

@oxinabox oxinabox commented Aug 25, 2021

@mzgubic implemented zygote2differential as a better version of wrap_chainrules_inputs and added it to use in the code for rrule_via_ad.
But it was not added to the normal path for when Zygote uses ChainRules.
I guess because it requires keeping the primal values in memory.
Which is probably a lot?

Anyway this would give us more consistent chainrules types.
No more Tangent{Any} or nothings that are hidden with-in arrays.

We probably do not want to merge this as is because of the extra memory use.
or maybe it is not too bad. Do we have a benchmark for it?

But hopefully this will fix the problems in TuringLang/DistributionsAD.jl#197
cc @devmotion .
If it does we can look at reworking zygote2differential to not have to store so much.
We learnt a lot about doing that for ProjectTo
same techniques can be applied here.

NB: I am putting this PR up at 9:30 at night, and I have not even run it locally.
Might have typos etc and just not work.
It also has no tests, yet.

@oxinabox
Copy link
Member Author

well this breaks some tests in weird ways.
Clearly some edge cases that zygote2differential doesn't yet cover.
Still might work for TuringLang/DistributionsAD.jl#197
and we can fix those other things

@devmotion
Copy link
Collaborator

I just checked TuringLang/DistributionsAD.jl#198 (the CR1 version) locally and it still fails with the same error messages ("adjoint for constructor ..."), even with this PR.

@oxinabox
Copy link
Member Author

Yeah won't fix that.
I meant fixing
TuringLang/DistributionsAD.jl#197 (comment)

@oxinabox
Copy link
Member Author

There was a matching differential2zygote that@mzgubic wrote.
Which might fix that, if the underlying cause was that the Tangent escaped by hiding in an arrays.

@devmotion
Copy link
Collaborator

Ah sorry, I misunderstood your comment. Unfortunately, the example is not fixed either.

@mzgubic
Copy link
Collaborator

mzgubic commented Aug 26, 2021

Here it is, in case you find it useful:

differential2legacy(x) = unthunk(x) # TODO eventually remove this
differential2legacy(::AbstractZero) = nothing
differential2legacy(t::Union{Tuple, NamedTuple}) = map(differential2legacy, t)
differential2legacy(::Nothing) = (legacytype_warn(Nothing); return nothing)
differential2legacy(a::AbstractArray) = differential2legacy.(a) # TODO: what to do with arrays with nothing?
differential2legacy(a::AbstractArray{<:Number}) = a
for T_outer in (:Tuple, :NamedTuple)
  # we create separate methods rather than using a `Union` + an `if` so that we avoid a
  # branch that changes output type, because nested AD on that kinda thing makes Zygote less
  # than happy.
  @eval @inline function differential2legacy(x::Composite{P, T}) where {P, T<:$T_outer}
    xp = map(differential2legacy, canonicalize(x))
    convert($T_outer, xp)
  end
end

I do recall getting into some kind of trouble when using this instead of wrap_chainrules_outputs. Don't think I've solved it though

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ChainRules adjoint -> rrule, and further integration
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants