Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add support for ChainRules Composite type #806

Open
wants to merge 38 commits into
base: chainrules_types
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
8b4826c
do not convert chain rules output to named tuple
Sep 21, 2020
ee9cc0a
using Composite
Sep 30, 2020
407eb34
add the old changes
Oct 9, 2020
07598a0
Update with changes to mz/cr-types
Oct 9, 2020
bac7fa9
remove some changes
Oct 9, 2020
0d65140
add legacytype_warn
Oct 9, 2020
5cb1036
Merge branch 'mz/cr-types' into mz/cr-composite
Oct 9, 2020
dbb7675
add Composite to allowed gradients
Oct 9, 2020
33df078
Merge branch 'mz/cr-types' into mz/cr-composite
Oct 9, 2020
b84d0a9
__new__ and __splatnew__ add Composite support
Oct 9, 2020
a20a6e1
improve readability
Oct 14, 2020
457826b
remove __new__ changes
Oct 16, 2020
a02a649
move to new warnings with types passed
Oct 16, 2020
d4b7e49
fix the iterator
Oct 16, 2020
78a7170
add warnings to chainrules
Oct 17, 2020
a9c0b7a
Composite{Any} -> Composite{typeof(g)}
Oct 21, 2020
09652d9
allowed gradient types change
Oct 21, 2020
6f609e9
remove some legacy2differential instances
Oct 22, 2020
98cd45e
accum_sum and unbroadcast to differential types
Oct 26, 2020
d3bb7fa
remove some gradtuple1
Oct 26, 2020
22f5ae5
fix literal getproperty as _pullback
Oct 27, 2020
8d9810f
fix cholesky adjoints
Oct 28, 2020
2ae657f
fix call overload
Nov 3, 2020
a58fc37
fix Real constructor
Nov 3, 2020
96633d1
change to diffgradtuple
Nov 3, 2020
cae7b62
Core._apply tuples to Composites
Nov 3, 2020
3ec66c9
fix tasks
Nov 5, 2020
33ed358
fix the warnings
Nov 5, 2020
1528485
change the test
Nov 5, 2020
60c3d72
improve type stability
Nov 5, 2020
2040fc1
move to new version of l2d taking a primal rather than primal type
Nov 16, 2020
697792e
fix _back
Nov 17, 2020
48102d8
nested ad fix first draft
Nov 20, 2020
cb13531
fix map a bit
Nov 20, 2020
925b471
add map tests and fix reverse of composite
Nov 24, 2020
453df12
l2d
Nov 24, 2020
db0662f
move getindex to _pullback
Nov 24, 2020
5bdc8b1
Core.getfield to _pullback
Nov 24, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix literal getproperty as _pullback
  • Loading branch information
Miha Zgubic committed Oct 27, 2020
commit 22f5ae5c6e0a93d03353d5d4a36ec4245eed7ced
8 changes: 7 additions & 1 deletion src/lib/array.jl
Original file line number Diff line number Diff line change
@@ -10,7 +10,13 @@ using Distributed: pmap

@nograd ones, zeros, Base.OneTo, Colon(), one, zero

@adjoint Base.vect(xs...) = Base.vect(xs...), Δ -> (Δ...,)
@adjoint Base.vect(xs...) = Base.vect(xs...), Δ -> (Δ...,) # TODO: need to find a way to deal with arrays in legacy2differential

# function _pullback(__context__::AContext, ::typeof(Base.vect), xs...)
# _back(::Union{Nothing,AbstractZero}) = Zero()
# _back(Δ) = (DoesNotExist(), Δ...)
# return Base.vect(xs...), _back
# end

# Adjoint for `copy` on arrays: the forward pass returns `copy(x)` and the
# pullback hands the incoming gradient `ȳ` back unchanged as a 1-tuple.
@adjoint copy(x::AbstractArray) = copy(x), ȳ -> (ȳ,)

10 changes: 5 additions & 5 deletions src/lib/base.jl
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ grad_mut(ch::Channel) = Channel(ch.sz_max)
@adjoint! function put!(ch::Channel, x)
put!(ch, x), function (ȳ)
x̄ = take!(grad_mut(__context__, ch))
(nothing, accum(x̄, ȳ), nothing)
return (nothing, accum(x̄, ȳ))
end
end

@@ -64,14 +64,14 @@ end
end
end

@adjoint! function Task(f)
function _pullback(__context__::AContext, ::Type{<:Task}, f)
t = Task(f)
t.code = function ()
y, back = _pullback(__context__, f)
cache(__context__)[t] = Task(back)
y, _back = _pullback(__context__, f)
cache(__context__)[t] = Task(_back)
return y
end
t, _ -> fetch(cache(__context__)[t])
return t, _ -> (DoesNotExist(), fetch(cache(__context__)[t]))
end

function runadjoint(cx, t, ȳ = DoesNotExist())
42 changes: 20 additions & 22 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
@@ -206,26 +206,25 @@ function deref!(x::Ref)
return d
end

@generated nt_nothing(x) = Expr(:tuple, [:($f=nothing) for f in fieldnames(x)]...)

@generated nt_zero(x) = Expr(:tuple, [:($f=Zero()) for f in fieldnames(x)]...)

@generated pair(::Val{k}, v) where k = :($k = v,)

@adjoint function literal_getproperty(x, ::Val{f}) where f # TODO rewrite as explicit pullback
val = getproperty(x, f)
function back(Δ)
accum_param(__context__, val, Δ) isa AbstractZero && return
if isimmutable(x)
((;nt_nothing(x)...,pair(Val(f), Δ)...), nothing)
else
dx = grad_mut(__context__, x)
dx[] = (;dx[]...,pair(Val(f),accum(getfield(dx[], f), Δ))...)
return (dx,nothing)
function _pullback(__context__::AContext, ::typeof(literal_getproperty), x, ::Val{f}) where f
val = getproperty(x, f)
_back(::Union{Nothing,AbstractZero}) = Zero()
function _back(Δ)
accum_param(__context__, val, Δ) isa AbstractZero && return Zero()
if isimmutable(x)
return (DoesNotExist(), Composite{typeof(x)}(;f => Δ), DoesNotExist())
else
dx = grad_mut(__context__, x)
dx[] += Composite{typeof(x)}(;f => Δ) # is += the right thing to do? (a=1, b=2, :a=>3) gives (a = 3, b = 2)
return (DoesNotExist(), dx, DoesNotExist())
end
end
unwrap(val), _back
end
unwrap(val), back
end

# `getproperty` called with a runtime `Symbol` is forwarded to the
# `literal_getproperty` pullback by lifting the field name into a `Val`.
_pullback(cx::Context, ::typeof(getproperty), x, f::Symbol) =
_pullback(cx, literal_getproperty, x, Val(f))
@@ -239,7 +238,6 @@ _pullback(cx::Context, ::typeof(literal_getindex), x::NamedTuple, ::Val{f}) wher
# For a `Tuple`, `literal_getproperty` is forwarded to the `literal_getindex`
# pullback with the same `Val(f)`.
_pullback(cx::Context, ::typeof(literal_getproperty), x::Tuple, ::Val{f}) where f =
_pullback(cx, literal_getindex, x, Val(f))

#grad_mut(x) = Ref{Any}(nt_zero(x)) # TODO
# Context-free `grad_mut`: wraps an empty `Composite{T}()` in a `Ref{Any}`.
# NOTE(review): presumably this Ref is the accumulator for field gradients of a
# mutable `x` — confirm against the `grad_mut(cx::Context, x)` method.
grad_mut(x::T) where T = Ref{Any}(Composite{T}())

function grad_mut(cx::Context, x)
@@ -251,13 +249,13 @@ function grad_mut(cx::Context, x)
end
end

@adjoint! function setfield!(x, f, val) # TODO change to _pullback
function _pullback(__context__::AContext, ::typeof(setfield!), x, f, val)
y = setfield!(x, f, val)
g = grad_mut(__context__, x)
y, function (_)
Δ = differential2legacy(getfield(g[], f))
g[] = (;g[]...,pair(Val(f),Zero())...)
(nothing, nothing, Δ)
y, function _back(_)
Δ = getproperty(g[], f)
g[] += Composite{typeof(x)}(;f => -Δ) # i.e. g[].f = Zero(), but that is not implemented
return (DoesNotExist(), DoesNotExist(), DoesNotExist(), Δ)
end
end

@@ -284,7 +282,7 @@ const allowed_gradient_T = Union{
Nothing,
AbstractZero,
RefValue,
ChainRules.Composite#{Any, T} where T<:Union{Tuple, NamedTuple} #TODO
ChainRules.Composite
}

# TODO captured mutables + multiple calls to `back`
@@ -296,7 +294,7 @@ const allowed_gradient_T = Union{
Δ_expr = if G <: AbstractZero
elseif Δ <: RefValue
:(back.g[]) # TODO: is this right? Why don't we need to accum?
:(back.g[]) # TODO: is this right? Why don't we need to accum?
else
:(accum(back.g[], Δ))
end
2 changes: 0 additions & 2 deletions src/lib/number.jl
Original file line number Diff line number Diff line change
@@ -18,8 +18,6 @@ end

# Complex Numbers

@adjoint (T::Type{<:Complex})(re, im) = T(re, im), c̄ -> (nothing, real(c̄), imag(c̄))

# we define these here because ChainRules.jl only defines them for x::Union{Real,Complex}

# Adjoint for `abs2` on any `Number`: the pullback returns
# `real(Δ) * (x + x)` as the single gradient entry for `x`.
@adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),)