1
1
using IRTools: IR, Variable, Pipe, xcall, var, prewalk, postwalk,
2
2
blocks, predecessors, successors, argument!, arguments, branches,
3
- exprtype, insertafter!, finish, expand!, prune!, substitute!, substitute,
4
- block, block!, branch!, return !, stmt
3
+ insertafter!, finish, expand!, prune!, substitute!, substitute,
4
+ block, block!, branch!, return !, stmt, meta
5
5
using Base: @get!
6
6
7
7
@inline tuple_va (N, xs) = xs
@@ -75,7 +75,7 @@ function instrument_global!(ir, v, ex)
75
75
else
76
76
ir[v] = prewalk (ex) do x
77
77
istrackable (x) || return x
78
- insert! (ir, v, stmt ( xcall (Zygote, :unwrap , QuoteNode (x), x), type = exprtype (x) ))
78
+ insert! (ir, v, xcall (Zygote, :unwrap , QuoteNode (x), x))
79
79
end
80
80
end
81
81
end
@@ -125,16 +125,6 @@ ignored_f(ir, f::Variable) = ignored_f(get(ir, f, nothing))
125
125
ignored (ir, ex) = isexpr (ex, :call ) && ignored_f (ir, ex. args[1 ])
126
126
ignored (ir, ex:: Variable ) = ignored (ir, ir[ex])
127
127
128
- # TODO : remove this once we don't mess with type inference
129
- function _forward_type (Ts)
130
- usetyped || return Any
131
- all (T -> isconcretetype (T) || T <: DataType , Ts) || return Any
132
- T = Core. Compiler. return_type (_pullback, Tuple{Context,Ts... })
133
- return T == Union{} ? Any : T
134
- end
135
-
136
- isvalidtype (jT, yT) = jT <: Tuple && length (jT. parameters) == 2 && jT. parameters[1 ] <: yT
137
-
138
128
function primal (ir:: IR )
139
129
pr = Pipe (ir)
140
130
pbs = Dict {Variable,Variable} ()
@@ -143,23 +133,12 @@ function primal(ir::IR)
143
133
for (v, st) in pr
144
134
ex = st. expr
145
135
if isexpr (ex, :call ) && ! ignored (ir, ex)
146
- yT = exprtype (ir, v)
147
- T = _forward_type (exprtype .((ir,), ex. args))
148
- if yT == Any || isvalidtype (T, yT)
149
- yJ = insert! (pr, v, stmt (xcall (Zygote, :_pullback , cx, ex. args... ),
150
- line = ir[v]. line))
151
- pr[v] = xgetindex (yJ, 1 )
152
- J = insertafter! (pr, v, stmt (xgetindex (yJ, 2 ),
153
- type = T == Any ? Any : T. parameters[2 ],
154
- line = ir[v]. line))
155
- pbs[v] = substitute (pr, J)
156
- else
157
- yJ = insert! (pr, v, xcall (Zygote, :_pullback , cx, ex. args... ))
158
- y = insert! (pr, v, xgetindex (yJ, 1 ))
159
- J = insert! (pr, v, stmt (xgetindex (yJ, 2 ), line = ir[v]. line))
160
- pr[v] = xcall (Zygote, :typeassert , y, yT)
161
- pbs[v] = substitute (pr, J)
162
- end
136
+ yJ = insert! (pr, v, stmt (xcall (Zygote, :_pullback , cx, ex. args... ),
137
+ line = ir[v]. line))
138
+ pr[v] = xgetindex (yJ, 1 )
139
+ J = insertafter! (pr, v, stmt (xgetindex (yJ, 2 ),
140
+ line = ir[v]. line))
141
+ pbs[v] = substitute (pr, J)
163
142
end
164
143
end
165
144
pr = finish (pr)
0 commit comments