Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lib/EnzymeTestUtils/src/EnzymeTestUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using FiniteDifferences: FiniteDifferences
using Random: Random
using Test

export test_forward, test_reverse, are_activities_compatible
export test_forward, test_reverse, test_rewind, are_activities_compatible

include("output_control.jl")
include("to_vec.jl")
Expand All @@ -17,5 +17,6 @@ include("finite_difference_calls.jl")
include("generate_tangent.jl")
include("test_forward.jl")
include("test_reverse.jl")
include("test_rewind.jl")

end # module
114 changes: 114 additions & 0 deletions lib/EnzymeTestUtils/src/test_rewind.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
test_rewind(f, Activity, args...; kwargs...)

Test `Enzyme.autodiff` of `f` in `Forward`-mode by backtracking using `Reverse`-mode,
which itself is checked against finite differences. This mode can be useful when computing
derivatives on functions such as matrix factorizations, where a particular choice of
gauge is important and the finite-differences approach generates tangents in an arbitrary
gauge. In effect, this plays the derivatives _forward_, then in _reverse_, "rewinding" the
Comment on lines +7 to +8
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I must admit this is the first time I encounter the term gauge.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah sorry it's physicist brain

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://en.wikipedia.org/wiki/Gauge_fixing if you're interested. I'll try to rephrase

tape.

`f` has all constraints of the same argument passed to `Enzyme.autodiff`, with additional
constraints:
- If it mutates one of its arguments, it _must_ return that argument.

To use this test mode, `f` _must_ have both forward and reverse rules defined.

# Arguments

- `Activity`: the activity of the return value of `f`
- `args`: Each entry is either an argument to `f`, an activity type accepted by `autodiff`,
or a tuple of the form `(arg, Activity)`, where `Activity` is the activity type of
`arg`. If the activity type specified requires a tangent, a random tangent will be
automatically generated.

# Keywords

- `rng::AbstractRNG`: The random number generator to use for generating random tangents.
- `fdm=FiniteDifferences.central_fdm(5, 1)`: The finite differences method to use.
- `fkwargs`: Keyword arguments to pass to `f`.
- `rtol`: Relative tolerance for `isapprox`.
- `atol`: Absolute tolerance for `isapprox`.
- `testset_name`: Name to use for a testset in which all tests are evaluated.
- `output_tangent`: Optional final tangent to provide at the beginning of the reverse-mode differentiation

# Examples

Here we test a rule for a function of scalars. Because we don't provide an activity
annotation for `y`, it is assumed to be `Const`.

```julia
using Enzyme, EnzymeTestUtils

x, y = randn(2)
for Tret in (Const, Duplicated, DuplicatedNoNeed), Tx in (Const, Duplicated)
test_forward(*, Tret, (x, Tx), y)
end
```

Here we test a rule for a function of an array in batch forward-mode:

```julia
x = randn(3)
y = randn()
for Tret in (Const, BatchDuplicated, BatchDuplicatedNoNeed),
Tx in (Const, BatchDuplicated),
Ty in (Const, BatchDuplicated)

test_forward(*, Tret, (x, Tx), (y, Ty))
end
```
"""

function test_rewind(
f,
fwd_ret_activity,
rvs_ret_activity,
args...;
rng::Random.AbstractRNG=Random.default_rng(),
fdm=FiniteDifferences.central_fdm(5, 1),
fkwargs::NamedTuple=NamedTuple(),
rtol::Real=1e-9,
atol::Real=1e-9,
testset_name=nothing,
runtime_activity::Bool=false,
output_tangent=nothing,
)
# first, test reverse as normal with finite differences
test_reverse(f, rvs_ret_activity, args...; rng=rng, fdm=fdm, fkwargs=fkwargs, rtol=rtol, atol=atol, testset_name=testset_name, runtime_activity=runtime_activity, output_tangent=output_tangent)
# now, use the reverse rule to compare with the forward result
if testset_name === nothing
testset_name = "test_rewind: $f with return activity $fwd_ret_activity on $(_string_activity(args))"
end
@testset "$testset_name" begin
# test reverse rule to make sure it works with FD
# run fwd mode first

# format arguments for autodiff and FiniteDifferences
activities = map(Base.Fix1(auto_activity, rng), (f, args...))
primals = map(x -> x.val, activities)
# call primal, avoid mutating original arguments
fcopy = deepcopy(first(primals))
args_copy = deepcopy(Base.tail(primals))
y = fcopy(args_copy...; deepcopy(fkwargs)...)
mode = if fwd_ret_activity <: Union{DuplicatedNoNeed, BatchDuplicatedNoNeed, Const}
Forward
else
ForwardWithPrimal
end
mode = set_runtime_activity(mode, runtime_activity)
ret_activity2 = if fwd_ret_activity <: DuplicatedNoNeed
Duplicated
elseif fwd_ret_activity <: BatchDuplicatedNoNeed
BatchDuplicated
else
fwd_ret_activity
end
call_with_kwargs(f, xs...) = f(xs...; fkwargs...)
y_and_dy_ad = autodiff(mode, call_with_kwargs, ret_activity2, activities...)
dy_ad = y_and_dy_ad[1]
# now run this back through reverse mode, using dy_ad from forward mode
# as the output tangent
test_reverse(f, rvs_ret_activity, args...; rng=rng, fdm=fdm, fkwargs=fkwargs, rtol=rtol, atol=atol, testset_name=testset_name, runtime_activity=runtime_activity, output_tangent=dy_ad)
end
end
Loading