diff --git a/docs/src/index.md b/docs/src/index.md index 13ff1d7bb5..b69b9106c6 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -307,7 +307,7 @@ Enzyme also supports a second way to mark things inactive, where the marker is " EnzymeRules.inactive_noinl(::typeof(det), ::UnitaryMatrix) = true ``` -### Easy Rules +### [Easy Rules](@id man-easy-rule) The recommended way for writing rules for most use cases is through the [`EnzymeRules.@easy_rule`](@ref) macro. This macro enables users to write derivatives for any functions which only read from their arguments (e.g. do not overwrite memory), and has numbers, matricies of numbers, or tuples thereof as arguments/result types. diff --git a/lib/EnzymeCore/src/easyrules.jl b/lib/EnzymeCore/src/easyrules.jl index 58bcd49101..e93f2283ac 100644 --- a/lib/EnzymeCore/src/easyrules.jl +++ b/lib/EnzymeCore/src/easyrules.jl @@ -719,6 +719,31 @@ If a specific argument has no partial derivative, then all corresponding argumen ...) ``` +# Examples + +Let's write an `@easy_rule` for a simple trigonometric function. Enzyme already has rules for `sin` and `cos`, but for the sake of illustration we can define a new pass-through function to oen of them to demostrate the `@easy_rule` interface. + +```julia +mycos(x) = cos(x) + +# forward-mode rule for the new function +EnzymeRules.@easy_rule( + mycos(x::AbstractFloat), + @setup(), + (-sin(x),) +) +``` + +Then this rule can be tested by running `Enzyme.autodiff` and comparing with the result for the regular `cos` function: + +```julia +myderiv = autodiff(Forward, mycos, Duplicated(2.0f0, 1.2f0))[1] # -1.091157f0 +truederiv = autodiff(Forward, cos, Duplicated(2.0f0, 1.2f0))[1] # -1.091157f0 +@assert myderiv = truederiv +``` + +For more information about easy rules, see the [manual](@ref man-easy-rule). + """ macro easy_rule(call, maybe_setup, partials...) call, setup_stmts, inputs, input_names, normal_inputs, partials = _normalize_scalarrules_macro_input( @@ -735,4 +760,4 @@ macro easy_rule(call, maybe_setup, partials...) $(frule_expr) $(rrule_expr) end -end \ No newline at end of file +end