diff --git a/Project.toml b/Project.toml index e60124086..114038e21 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.25.0" +version = "1.25.1" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/tangent_types/thunks.jl b/src/tangent_types/thunks.jl index 236ef49f1..2556c5b85 100644 --- a/src/tangent_types/thunks.jl +++ b/src/tangent_types/thunks.jl @@ -1,3 +1,7 @@ +# Disable thunks for 2nd order AD. +_usethunks() = true +rrule(::typeof(_usethunks)) = false, Returns((NoTangent(),)) + abstract type AbstractThunk <: AbstractTangent end struct MutateThunkException <: Exception end @@ -141,7 +145,11 @@ macro thunk(body) # Basically `:(Thunk(() -> $(esc(body))))` but use the location where it is defined. # so we get useful stack traces if it errors. func = Expr(:->, Expr(:tuple), Expr(:block, __source__, body)) - return :(Thunk($(esc(func)))) + return quote + _usethunks() ? + Thunk($(esc(func))) : + $(esc(body)) + end end """ @@ -233,6 +241,12 @@ and destroy its inplacability. struct InplaceableThunk{T<:Thunk,F} <: AbstractThunk add!::F val::T + + function InplaceableThunk(add!::F, val::T) where {F, T} + _usethunks() ? + new{T, F}(add!, val) : + val + end end unthunk(x::InplaceableThunk) = unthunk(x.val)