From 600e78fb897fc887acd8ca192747ddf65ecfa519 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 14 Feb 2023 23:37:02 -0500 Subject: [PATCH] Optimize power rule --- src/rulesets/Base/fastmath_able.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index d8afb630a..18ee18151 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -181,8 +181,14 @@ let ## power # literal_pow is in base.jl function frule((_, Δx, Δp), ::typeof(^), x::Number, p::Number) - y = x ^ p - _dx = _pow_grad_x(x, p, float(y)) + if isinteger(p) + tmp = x ^ (p - 1) + y = x * tmp + _dx = p * tmp + else + y = x ^ p + _dx = _pow_grad_x(x, p, float(y)) + end if iszero(Δp) # Treat this as a strong zero, to avoid NaN, and save the cost of log return y, _dx * Δx