Skip to content

Commit 92987b6

Browse files
kshyattKatharine Hyatt
andauthored
Little bit more documentation for the noble @easy_rule (#2695)
* Add a really simple example to @easy_rule * Another easy rule example and xref --------- Co-authored-by: Katharine Hyatt <[email protected]>
1 parent 2cd186b commit 92987b6

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ Enzyme also supports a second way to mark things inactive, where the marker is "
307307
EnzymeRules.inactive_noinl(::typeof(det), ::UnitaryMatrix) = true
308308
```
309309

310-
### Easy Rules
310+
### [Easy Rules](@id man-easy-rule)
311311

312312
The recommended way for writing rules for most use cases is through the [`EnzymeRules.@easy_rule`](@ref) macro. This macro enables users to write derivatives for any functions which only read from their arguments (e.g. do not overwrite memory), and has numbers, matricies of numbers, or tuples thereof as arguments/result types.
313313

lib/EnzymeCore/src/easyrules.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,31 @@ If a specific argument has no partial derivative, then all corresponding argumen
719719
...)
720720
```
721721
722+
# Examples
723+
724+
Let's write an `@easy_rule` for a simple trigonometric function. Enzyme already has rules for `sin` and `cos`, but for the sake of illustration we can define a new pass-through function to oen of them to demostrate the `@easy_rule` interface.
725+
726+
```julia
727+
mycos(x) = cos(x)
728+
729+
# forward-mode rule for the new function
730+
EnzymeRules.@easy_rule(
731+
mycos(x::AbstractFloat),
732+
@setup(),
733+
(-sin(x),)
734+
)
735+
```
736+
737+
Then this rule can be tested by running `Enzyme.autodiff` and comparing with the result for the regular `cos` function:
738+
739+
```julia
740+
myderiv = autodiff(Forward, mycos, Duplicated(2.0f0, 1.2f0))[1] # -1.091157f0
741+
truederiv = autodiff(Forward, cos, Duplicated(2.0f0, 1.2f0))[1] # -1.091157f0
742+
@assert myderiv = truederiv
743+
```
744+
745+
For more information about easy rules, see the [manual](@ref man-easy-rule).
746+
722747
"""
723748
macro easy_rule(call, maybe_setup, partials...)
724749
call, setup_stmts, inputs, input_names, normal_inputs, partials = _normalize_scalarrules_macro_input(
@@ -735,4 +760,4 @@ macro easy_rule(call, maybe_setup, partials...)
735760
$(frule_expr)
736761
$(rrule_expr)
737762
end
738-
end
763+
end

0 commit comments

Comments
 (0)