Skip to content

Commit

Permalink
Merge branch 'master' into bc/rm-gpu-sum-adj
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored Jan 4, 2025
2 parents de07572 + bb3730e commit 7b59497
Show file tree
Hide file tree
Showing 32 changed files with 1,028 additions and 503 deletions.
45 changes: 23 additions & 22 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.64"
version = "0.7.0"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand All @@ -18,42 +18,48 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[extensions]
ZygoteColorsExt = "Colors"
ZygoteDistancesExt = "Distances"
ZygoteTrackerExt = "Tracker"

[compat]
AbstractFFTs = "1.3.1"
ChainRules = "1.44.1"
ChainRulesCore = "1.9"
ChainRules = "1.72.2"
ChainRulesCore = "1.25.1"
ChainRulesTestUtils = "1"
Colors = "0.12"
Colors = "0.12, 0.13"
DiffRules = "1.4"
Distances = "0.10"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
ForwardDiff = "0.10"
GPUArrays = "8.4.2, 9"
GPUArraysCore = "0.1.1"
IRTools = "0.4.4"
GPUArrays = "8.4.2, 9, 10, 11"
GPUArraysCore = "0.1.1, 0.2"
IRTools = "0.4.12"
LogExpFunctions = "0.3.1"
MacroTools = "0.5"
NaNMath = "0.3, 1"
Requires = "1.1"
PrecompileTools = "1"
Requires = "1.1"
SpecialFunctions = "1.6, 2"
Statistics = "1"
Tracker = "0.2"
ZygoteRules = "0.2.1"
ZygoteRules = "0.2.5"
julia = "1.6"

[extensions]
ZygoteColorsExt = "Colors"
ZygoteDistancesExt = "Distances"
ZygoteTrackerExt = "Tracker"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Expand All @@ -62,14 +68,9 @@ Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[targets]
test = ["ChainRulesTestUtils", "Conda", "CUDA", "Distances", "FFTW", "FiniteDifferences", "PyCall", "Test"]

[weakdeps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
test = ["ChainRulesTestUtils", "Conda", "CUDA", "Distances", "FFTW", "FiniteDifferences", "PythonCall", "Test"]
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Zygote

Welcome! Zygote extends the Julia language to support [differentiable programming](https://fluxml.ai/blog/2019/02/07/what-is-differentiable-programming.html). With Zygote you can write down any Julia code you feel like – including using existing Julia packages – then get gradients and optimise your program. Deep learning, ML and probabilistic programming are all different kinds of differentiable programming that you can do with Zygote.
Welcome! Zygote extends the Julia language to support [differentiable programming](https://fluxml.ai/blogposts/2019-02-07-what-is-differentiable-programming/). With Zygote you can write down any Julia code you feel like – including using existing Julia packages – then get gradients and optimise your program. Deep learning, ML and probabilistic programming are all different kinds of differentiable programming that you can do with Zygote.

At least, that's the idea. We're still in beta so expect some adventures.

Expand Down
35 changes: 19 additions & 16 deletions docs/src/limitations.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,27 +82,30 @@ julia> gradient(rand(3)) do y

## Try-catch statements

Any expressions involving `try`/`catch` statements is not supported.
```julia
function tryme(x)
try
2 * x
catch e
throw(e)
end
end
Code containting try-catch blocks can be differentiated as long as no exception is actually thrown.

julia> gradient(rand(3)) do x
sum(tryme(x))
```julia
julia> function safe_sqrt(x)
try
sqrt(x)
catch
0.
end
end
ERROR: Compiling Tuple{typeof(tryme), Vector{Float64}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
safe_sqrt (generic function with 1 method)

julia> gradient(safe_sqrt, 4.)
(0.25,)

julia> val, pull = pullback(safe_sqrt, -1.)
(0.0, Zygote.var"#76#77"{Zygote.Pullback{Tuple{typeof(safe_sqrt), Float64}, Any}}((safe_sqrt)))

julia> pull(1.)
ERROR: Can't differentiate function execution in catch block at #= REPL[2]:3 =#.
Stacktrace:
...
```
Here `tryme` uses a `try`/`catch` statement, and Zygote throws an error when trying to differentiate it as expected. `try`/`catch` expressions are used for error handling, but they are less common in Julia compared to some other languages.

Here, the `safe_sqrt` function catches DomainError from the sqrt call when the input is out of domain and safely returns 0. Zygote is able to differentiate the function when no error is thrown by the sqrt call, but fails to differentiate when the control flow goes through the catch block.

## Foreign call expressions

Expand Down
6 changes: 6 additions & 0 deletions docs/src/utils.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
```@meta
CollapsedDocStrings = true
```


# Utilities

Zygote's gradients can be used to construct a Jacobian (by repeated evaluation)
Expand Down Expand Up @@ -26,6 +31,7 @@ Zygote.hook
Zygote.Buffer
Zygote.forwarddiff
Zygote.checkpointed
Zygote.eager_update!
```

`Params` and `Grads` can be copied to and from arrays using the `copy!` function.
Expand Down
3 changes: 2 additions & 1 deletion src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ module Zygote
using LinearAlgebra, Statistics
using LinearAlgebra: copytri!, AbstractTriangular

import ZygoteRules
import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
literal_getproperty, literal_getfield, unthunk_tangent

using ChainRulesCore
using ChainRules: ChainRules, rrule, unthunk, canonicalize
using ChainRules: ChainRules, AbstractThunk, rrule, unthunk, canonicalize
using IRTools
using MacroTools, Requires
using MacroTools: @forward
Expand Down
17 changes: 15 additions & 2 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# ToDo: Move some of this to ZygoteRules, or move unthunk_tangent for Tuple and NamedTuple from
# Zygote rules here?
function unthunk_tangent end
@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x
@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x
@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x)
unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d])
@non_differentiable unthunk_tangent(::IdDict)


struct ZygoteRuleConfig{CTX<:AContext} <: RuleConfig{Union{HasReverseMode,NoForwardsMode}}
context::CTX
end
Expand Down Expand Up @@ -107,7 +118,6 @@ is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f)
Convert `x` from the differentials types ChainRules uses to the format Zygote uses internally.
"""
@inline wrap_chainrules_output(x) = x
@inline wrap_chainrules_output(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) # For now we are just not going to deal with thunks
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
# Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing.
@inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing
Expand Down Expand Up @@ -162,6 +172,7 @@ end
# For arrays, whitelist the safe ones, but always look inside Any[]:
@inline wrap_chainrules_input(dxs::AbstractArray{<:Number}) = dxs
@inline wrap_chainrules_input(dxs::AbstractArray{<:AbstractArray{<:Number}}) = dxs
@inline wrap_chainrules_input(dxs::AbstractArray{<:Union{Nothing,T}}) where T <: Number = map(x -> x === nothing ? zero(T) : x, dxs)
@inline wrap_chainrules_input(dxs::AbstractArray) = map(wrap_chainrules_input, dxs)

#=
Expand Down Expand Up @@ -260,7 +271,9 @@ function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs
_pullback(config.context, f_args...)
end

ad_pullback(Δ) = zygote2differential(pb(wrap_chainrules_output(Δ)), f_args)
ad_pullback(Δ) = zygote2differential(
pb(wrap_chainrules_output(unthunk_tangent(Δ))),
f_args)
return y, ad_pullback
end

Expand Down
14 changes: 7 additions & 7 deletions src/compiler/emit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,23 @@ concrete(T::DataType) = T
concrete(::Type{Type{T}}) where T = typeof(T)
concrete(T) = Any

runonce(b) = b.id in (1, length(b.ir.blocks))
runonce(b) = b.id in (1, length(b.ir.blocks)) &&
!any(((_,stmt),) -> isexpr(stmt.expr, :catch), b)

function forward_stacks!(adj, F)
stks, recs = [], []
pr = adj.primal
for b in blocks(pr), α in alphauses(block(adj.adjoint, b.id))
if runonce(b)
not_stack = runonce(b)
if not_stack
push!(recs, Variable(α))
else
stk = pushfirst!(pr, xstack(Any))
push!(recs, stk)
push!(b, xcall(Zygote, :_push!, stk, Variable(α)))
end
push!(stks, (b.id, alpha(α)))
push!(stks, (b.id, alpha(α), not_stack))
end
args = arguments(pr)[3:end]
rec = push!(pr, xtuple(recs...))
P = length(pr.blocks) == 1 ? Pullback{F} : Pullback{F,Any}
# P = Pullback{F,Any} # reduce specialisation
Expand All @@ -68,11 +69,10 @@ function reverse_stacks!(adj, stks)
self = argument!(entry, at = 1)
t = pushfirst!(blocks(ir)[end], xcall(:getfield, self, QuoteNode(:t)))
repl = Dict()
runonce(b) = b.id in (1, length(ir.blocks))
for b in blocks(ir)
for (i, (b′, α)) in enumerate(stks)
for (i, (b′, α, not_stack)) in enumerate(stks)
b.id == b′ || continue
if runonce(b)
if not_stack
val = insertafter!(ir, t, xcall(:getindex, t, i))
else
stk = push!(entry, xcall(:getindex, t, i))
Expand Down
Loading

0 comments on commit 7b59497

Please sign in to comment.