Skip to content

Commit 5e5f6bb

Browse files
committed
rm typed mode
1 parent 63f1816 commit 5e5f6bb

11 files changed

+33
-125
lines changed

.travis.yml

-5
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,6 @@ notifications:
1313
git:
1414
depth: 99999999
1515

16-
env:
17-
matrix:
18-
- ZYGOTE_TYPED=true
19-
- ZYGOTE_TYPED=false
20-
2116
## uncomment the following lines to allow failures on nightly julia
2217
## (tests will run but not make your overall status red)
2318
matrix:

src/Zygote.jl

+1-10
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,6 @@ using LinearAlgebra: copytri!, AbstractTriangular
55

66
import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback
77

8-
# This flag enables Zygote to grab extra type inference information during
9-
# compiles. When control flow is present, this can give gradient code a
10-
# performance boost.
11-
12-
# HOWEVER, this is not Jameson-approved, nor well supported by the compiler, and
13-
# has several caveats. Recursion will cause inference to stack overflow.
14-
# Gradient redefinitions may result in ugly type errors. And Jameson *will* know.
15-
const usetyped = get(ENV, "ZYGOTE_TYPED", false) == "true"
16-
178
using IRTools
189
using MacroTools, Requires
1910
using MacroTools: @forward
@@ -49,7 +40,7 @@ include("profiler/Profile.jl")
4940
include("flux.jl")
5041
end
5142

52-
precompile() = usetyped || include(joinpath(@__DIR__, "precompile.jl"))
43+
precompile() = include(joinpath(@__DIR__, "precompile.jl"))
5344

5445
# precompile()
5546
@init Requires.isprecompiling() || precompile()

src/compiler/emit.jl

+6-23
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222

2323
# Emit
2424

25-
xstack(T) = stmt(Expr(:call, Vector{T}), type = Vector{T})
25+
xstack(T) = Expr(:call, Vector{T})
2626

2727
function alphauses(b)
2828
us = Set{Alpha}()
@@ -45,24 +45,17 @@ function forward_stacks!(adj, F)
4545
if runonce(b)
4646
push!(recs, Variable(α))
4747
else
48-
T = exprtype(pr, Variable(α))
49-
stk = pushfirst!(pr, xstack(T))
48+
stk = pushfirst!(pr, xstack(Any))
5049
push!(recs, stk)
5150
push!(b, xcall(Zygote, :_push!, stk, Variable(α)))
5251
end
5352
push!(stks, (b.id, alpha(α)))
5453
end
5554
args = arguments(pr)[3:end]
56-
T = Tuple{concrete.(exprtype.((pr,), recs))...}
57-
isconcretetype(T) || (T = Any)
5855
rec = push!(pr, xtuple(recs...))
59-
if usetyped && length(pr.blocks) > 1
60-
rec = push!(pr, Expr(:call, Pullback{F,T}, rec))
61-
else
62-
P = length(pr.blocks) == 1 ? Pullback{F} : Pullback{F,Any}
63-
# P = Pullback{F,Any} # reduce specialisation
64-
rec = push!(pr, Expr(:call, P, rec))
65-
end
56+
P = length(pr.blocks) == 1 ? Pullback{F} : Pullback{F,Any}
57+
# P = Pullback{F,Any} # reduce specialisation
58+
rec = push!(pr, Expr(:call, P, rec))
6659
ret = xtuple(pr.blocks[end].branches[end].args[1], rec)
6760
ret = push!(pr, ret)
6861
pr.blocks[end].branches[end].args[1] = ret
@@ -102,18 +95,8 @@ end
10295

10396
varargs(m::Method, n) = m.isva ? n - m.nargs + 1 : nothing
10497

105-
meta(T) = (usetyped ? IRTools.typed_meta : IRTools.meta)(T)
106-
107-
function getmeta(T)
108-
m = meta(T)
109-
(usetyped && m != nothing) || return m
110-
any(x -> isexpr(x, :goto, :gotoifnot), m.code.code) || return IRTools.meta(T)
111-
return m
112-
end
113-
11498
function _lookup_grad(T)
115-
(m = getmeta(T)) == nothing && return
116-
m isa IRTools.TypedMeta && m.ret == Union{} && return
99+
(m = meta(T)) == nothing && return
117100
va = varargs(m.method, length(T.parameters))
118101
forw, back = stacks!(Adjoint(IR(m), varargs = va, normalise = false), T)
119102
m, forw, back

src/compiler/reverse.jl

+9-30
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using IRTools: IR, Variable, Pipe, xcall, var, prewalk, postwalk,
22
blocks, predecessors, successors, argument!, arguments, branches,
3-
exprtype, insertafter!, finish, expand!, prune!, substitute!, substitute,
4-
block, block!, branch!, return!, stmt
3+
insertafter!, finish, expand!, prune!, substitute!, substitute,
4+
block, block!, branch!, return!, stmt, meta
55
using Base: @get!
66

77
@inline tuple_va(N, xs) = xs
@@ -75,7 +75,7 @@ function instrument_global!(ir, v, ex)
7575
else
7676
ir[v] = prewalk(ex) do x
7777
istrackable(x) || return x
78-
insert!(ir, v, stmt(xcall(Zygote, :unwrap, QuoteNode(x), x), type = exprtype(x)))
78+
insert!(ir, v, xcall(Zygote, :unwrap, QuoteNode(x), x))
7979
end
8080
end
8181
end
@@ -125,16 +125,6 @@ ignored_f(ir, f::Variable) = ignored_f(get(ir, f, nothing))
125125
ignored(ir, ex) = isexpr(ex, :call) && ignored_f(ir, ex.args[1])
126126
ignored(ir, ex::Variable) = ignored(ir, ir[ex])
127127

128-
# TODO: remove this once we don't mess with type inference
129-
function _forward_type(Ts)
130-
usetyped || return Any
131-
all(T -> isconcretetype(T) || T <: DataType, Ts) || return Any
132-
T = Core.Compiler.return_type(_pullback, Tuple{Context,Ts...})
133-
return T == Union{} ? Any : T
134-
end
135-
136-
isvalidtype(jT, yT) = jT <: Tuple && length(jT.parameters) == 2 && jT.parameters[1] <: yT
137-
138128
function primal(ir::IR)
139129
pr = Pipe(ir)
140130
pbs = Dict{Variable,Variable}()
@@ -143,23 +133,12 @@ function primal(ir::IR)
143133
for (v, st) in pr
144134
ex = st.expr
145135
if isexpr(ex, :call) && !ignored(ir, ex)
146-
yT = exprtype(ir, v)
147-
T = _forward_type(exprtype.((ir,), ex.args))
148-
if yT == Any || isvalidtype(T, yT)
149-
yJ = insert!(pr, v, stmt(xcall(Zygote, :_pullback, cx, ex.args...),
150-
line = ir[v].line))
151-
pr[v] = xgetindex(yJ, 1)
152-
J = insertafter!(pr, v, stmt(xgetindex(yJ, 2),
153-
type = T == Any ? Any : T.parameters[2],
154-
line = ir[v].line))
155-
pbs[v] = substitute(pr, J)
156-
else
157-
yJ = insert!(pr, v, xcall(Zygote, :_pullback, cx, ex.args...))
158-
y = insert!(pr, v, xgetindex(yJ, 1))
159-
J = insert!(pr, v, stmt(xgetindex(yJ, 2), line = ir[v].line))
160-
pr[v] = xcall(Zygote, :typeassert, y, yT)
161-
pbs[v] = substitute(pr, J)
162-
end
136+
yJ = insert!(pr, v, stmt(xcall(Zygote, :_pullback, cx, ex.args...),
137+
line = ir[v].line))
138+
pr[v] = xgetindex(yJ, 1)
139+
J = insertafter!(pr, v, stmt(xgetindex(yJ, 2),
140+
line = ir[v].line))
141+
pbs[v] = substitute(pr, J)
163142
end
164143
end
165144
pr = finish(pr)

src/flux.jl

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
using .Tracker: TrackedArray, TrackedReal
22

3-
if !usetyped
4-
unwrap(x::Union{TrackedArray,TrackedReal}) = Tracker.data(x)
5-
end
3+
unwrap(x::Union{TrackedArray,TrackedReal}) = Tracker.data(x)
64

75
pullback(f, ps::Tracker.Params) = pullback(f, Params(ps))

test/compiler.jl

+1-8
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ end
1919
bad(x) = x
2020
@adjoint bad(x) = x, Δ -> error("bad")
2121

22-
Zygote.usetyped && Zygote.refresh()
23-
2422
function badly(x)
2523
x = x + 1
2624
x = bad(x)
@@ -33,7 +31,7 @@ y, back = pullback(badly, 2)
3331
bt = try back(1) catch e stacktrace(catch_backtrace()) end
3432

3533
@test trace_contains(bt, nothing, "compiler.jl", 20)
36-
@test trace_contains(bt, :badly, "compiler.jl", 26)
34+
@test trace_contains(bt, :badly, "compiler.jl", 24)
3735

3836
# Type inference checks
3937

@@ -76,8 +74,3 @@ y, back = @test_inferred pullback(x->x[1], (5,:a))
7674

7775
y, back = @test_inferred pullback(((a,b),) -> a, (5, 10))
7876
@test_inferred back(1)
79-
80-
# Checks that use control flow
81-
if Zygote.usetyped
82-
include("typed.jl")
83-
end

test/complex.jl

+2-4
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@ using Zygote, Test
77

88
@test gradient(a -> real((a*conj(a))), 0.3im)[1] == 0.6im
99
@test gradient(a -> real((a.*conj(a))), 0.3im)[1] == 0.6im
10-
if !Zygote.usetyped
11-
@test gradient(a -> real(([a].*conj([a])))[], 0.3im)[1] == 0.6im #TODO
12-
@test gradient(a -> real(([a].*conj.([a])))[], 0.3im)[1] == 0.6im #TODO
13-
end
10+
@test gradient(a -> real(([a].*conj([a])))[], 0.3im)[1] == 0.6im
11+
@test gradient(a -> real(([a].*conj.([a])))[], 0.3im)[1] == 0.6im
1412
@test gradient(a -> real.(([a].*conj.([a])))[], 0.3im)[1] == 0.6im
1513

1614
fs_C_to_R = (real,

test/features.jl

+11-20
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,7 @@ end == (4,)
137137

138138
pow_rec(x, n) = n == 0 ? 1 : x*pow_rec(x, n-1)
139139

140-
if !Zygote.usetyped
141-
@test gradient(pow_rec, 2, 3) == (12, nothing)
142-
end
140+
@test gradient(pow_rec, 2, 3) == (12, nothing)
143141

144142
# For nested AD, until we support errors
145143
function grad(f, args...)
@@ -187,17 +185,14 @@ y, back = pullback(() -> layer(x), Params([W]))
187185
sum(H)
188186
end[1] == 1
189187

190-
# FIXME
191-
if !Zygote.usetyped
192-
@test gradient(2) do x
193-
if x < 0
194-
throw("foo")
195-
end
196-
return x*5
197-
end[1] == 5
188+
@test gradient(2) do x
189+
if x < 0
190+
throw("foo")
191+
end
192+
return x*5
193+
end[1] == 5
198194

199-
@test gradient(x -> one(eltype(x)), rand(10))[1] == nothing
200-
end
195+
@test gradient(x -> one(eltype(x)), rand(10))[1] == nothing
201196

202197
# Thre-way control flow merge
203198
@test gradient(1) do x
@@ -214,15 +209,11 @@ grad_closure(x) = 2x
214209

215210
Zygote.@adjoint (f::typeof(grad_closure))(x) = f(x), Δ -> (1, 2)
216211

217-
Zygote.usetyped && Zygote.refresh()
218-
219212
@test gradient((f, x) -> f(x), grad_closure, 5) == (1, 2)
220213

221-
if !Zygote.usetyped
222-
invokable(x) = 2x
223-
invokable(x::Integer) = 3x
224-
@test gradient(x -> invoke(invokable, Tuple{Any}, x), 5) == (2,)
225-
end
214+
invokable(x) = 2x
215+
invokable(x::Integer) = 3x
216+
@test gradient(x -> invoke(invokable, Tuple{Any}, x), 5) == (2,)
226217

227218
y, back = Zygote.pullback(x->tuple(x...), [1, 2, 3])
228219
@test back((1, 1, 1)) == ((1,1,1),)

test/gradcheck.jl

+2-4
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ end
121121
bar = (x) -> x*y
122122
sum(map(bar, 1:5))
123123
end
124-
Zygote.usetyped || @test gradtest(foo, 3) #TODO
124+
@test gradtest(foo, 3)
125125
@test gradient(v -> sum([x for x in v]), [1.1,2.2,3.3]) == ([1, 1, 1],)
126126
end
127127

@@ -711,9 +711,7 @@ end
711711
end
712712

713713
@testset "broadcast" begin
714-
if !Zygote.usetyped
715-
@test gradient(x -> sum(sin.(x)), Diagonal(randn(3)))[1][2] == 1
716-
end
714+
@test gradient(x -> sum(sin.(x)), Diagonal(randn(3)))[1][2] == 1
717715
end
718716

719717
using Zygote: Buffer

test/runtests.jl

-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,6 @@
11
using Zygote, Test
22
using Zygote: gradient
33

4-
if Zygote.usetyped
5-
@info "Testing Zygote in type-hacks mode."
6-
else
7-
@info "Testing Zygote in normal mode."
8-
end
9-
104
@testset "Zygote" begin
115

126
@testset "Features" begin

test/typed.jl

-12
This file was deleted.

0 commit comments

Comments
 (0)