From 245de579b95180a1089c4d43573baf937aa5121c Mon Sep 17 00:00:00 2001 From: Mao Zeng Date: Tue, 7 Jan 2025 13:41:58 +0000 Subject: [PATCH 1/4] Add activation functions telu and telu_fast. --- docs/src/reference.md | 2 ++ src/activations.jl | 62 +++++++++++++++++++++++++++++++++++++++++-- test/activations.jl | 53 ++++++++++++++++++++++++++++++++---- 3 files changed, 110 insertions(+), 7 deletions(-) diff --git a/docs/src/reference.md b/docs/src/reference.md index 5edde7195..835bd1133 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -31,6 +31,8 @@ swish hardswish tanhshrink trelu +telu +telu_fast ``` ## Attention diff --git a/src/activations.jl b/src/activations.jl index 4ed58622a..5fdc536d7 100644 --- a/src/activations.jl +++ b/src/activations.jl @@ -7,8 +7,8 @@ ACTIVATIONS = [ :σ, :hardσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :hardswish, :selu, :celu, :softplus, :softsign, :logσ, :logcosh, - :mish, :tanhshrink, :softshrink, :trelu, :lisht, - :tanh_fast, :sigmoid_fast, + :mish, :tanhshrink, :softshrink, :trelu, :lisht, :telu, + :tanh_fast, :sigmoid_fast, :telu_fast ] # of type float (to allow for integer inputs) @@ -749,6 +749,62 @@ function softshrink(x, λ = 0.5) ifelse(hi > 0, ifelse(lo < 0, zero(hi), lo), hi) end +""" + telu(x) = x * tanh(exp(x)) + +See e.g. ["TeLU Activation Function for Fast and Stable Deep Learning"](https://arxiv.org/abs/2412.20269). 
+ +```julia-repl +julia> lineplot(telu, -2, 2, height=7) + ┌────────────────────────────────────────┐ + 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ telu(x) + │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│ + f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⡤⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⣤⡤⡧⠶⠯⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│ + │⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠊⠉⠉⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ + -1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ + └────────────────────────────────────────┘ + ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀ + ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ + +julia> lineplot(telu, -5, 0, height=7) + ┌────────────────────────────────────────┐ + 0 │⠤⠤⢄⣀⣀⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡸│ telu(x) + │⠀⠀⠀⠀⠀⠈⠉⠉⠒⠒⠤⢄⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠇│ + │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⠢⢄⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡎⠀│ + f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡜⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠑⠦⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠎⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⠤⣀⡀⠀⠀⠀⣀⡤⠃⠀⠀⠀⠀│ + -0.4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠉⠉⠉⠁⠀⠀⠀⠀⠀⠀│ + └────────────────────────────────────────┘ + ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀0⠀ + ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ + +``` +""" +telu(x) = x * tanh(exp(x)) + +""" + telu_fast(x) + +This is a faster but less accurate version of `telu`. This function is associated with a hard-coded derivative, +`deriv_telu_fast`, which is faster but less accurate than `deriv_telu`. + """ +telu_fast(x) = x * tanh_fast(exp(x)) + +# Adapted from the Discourse post: +function deriv_telu(x) + exp_x = exp(x) + tanh(exp_x) + 4x / (exp(exp_x - x/2) + exp(-exp_x - x/2))^2 +end + +function deriv_telu_fast(x) + tanh_exp_x = tanh(exp(x)) + sech_exp_x_squared = 1 - tanh_exp_x^2 + ifelse(x >= 4, one(x), tanh_exp_x + x * exp(x) * sech_exp_x_squared) # cut off large x to prevent `exp(x)` overflow. 
This cutoff is good for all types (Float16, 32, 64) in terms of both preventing overflow and maintaining accuracy +end + # Define broadcasts for activation functions on arrays for f in ACTIVATIONS @eval $(f)(x::AbstractArray, args...) = $(f).(x, args...) @@ -888,9 +944,11 @@ UNARY_ACTS = [ # f, dfdx # mish (:tanhshrink, :((x - Ω)^2)), (:softshrink, :(Ω != 0)), + (:telu, :(deriv_telu(x))), ## Fast variants are the same! (:tanh_fast, :(conj(1 - Ω^2))), (:sigmoid_fast, :(conj(Ω * (1 - Ω)))), + (:telu_fast, :(deriv_telu_fast(x))) ] for (f, dfdx) in UNARY_ACTS diff --git a/test/activations.jl b/test/activations.jl index 3a14bfde8..3d7846ecf 100644 --- a/test/activations.jl +++ b/test/activations.jl @@ -26,6 +26,7 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test mish(0.0) == 0.0 @test tanhshrink(0.0) == 0.0 @test softshrink(0.0) == 0.0 +@test telu(0.0) == 0.0 @test sigmoid(1.0) == 1.0 / (1.0 + exp(-1.0)) @test hardsigmoid(1.0) == max(0,min(1, (1 + 3)/6)) @@ -48,6 +49,7 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test mish(1.0) ≈ tanh(log(1.0 + exp(1.0))) @test tanhshrink(1.0) ≈ 0.23840584404423515 @test softshrink(1.0) == 0.5 +@test telu(1.0) ≈ 0.99132891580059984 @test sigmoid(-1.0) == exp(-1.0) / (1.0 + exp(-1.0)) @test hardsigmoid(-1.0) == max(0,min(1,(-1+3)/6 )) @@ -70,6 +72,7 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test mish(-1.0) ≈ -tanh(log(1.0 + exp(-1.0))) @test tanhshrink(-1.0) ≈ -0.23840584404423515 @test softshrink(-1.0) == -0.5 +@test telu(-1.0) ≈ -0.35213549054658698 @testset "Float inference" begin @testset "$(a): " for a in ACTIVATION_FUNCTIONS @@ -111,10 +114,10 @@ end # Ideally +-Inf would not lead to NaN, but perhaps # these aren't worth the complication of fixing: - a == softsign && continue + a in [softsign, telu] && continue @test !isnan(a(Inf32)) - a in [gelu, swish, hardswish, logcosh, mish] && continue + a in [gelu, 
swish, hardswish, logcosh, mish, telu, telu_fast] && continue @test !isnan(a(-Inf32)) end end @@ -211,7 +214,7 @@ end ## Faster variants -using NNlib: tanh_fast, sigmoid_fast +using NNlib: tanh_fast, sigmoid_fast, telu_fast, deriv_telu, deriv_telu_fast function countepsfrom(x::T, xtrue) where {T<:AbstractFloat} target = T(xtrue) @@ -228,7 +231,7 @@ function find_worst(f, g, xs) c, xs[i] end -@testset "tanh_fast & sigmoid_fast: Float64" begin +@testset "tanh_fast, sigmoid_fast, telu_fast & deriv_telu_fast: Float64" begin x64 = 1e-6:1e-4:5 xbig = vcat(6:3:200.0, 1000, 10^6, typemax(Float64)) @@ -262,9 +265,29 @@ end @test sigmoid_fast.(xbig) ≈ sigmoid.(xbig) @test sigmoid_fast.(-xbig) ≈ sigmoid.(-xbig) end + @testset "telu" begin + mean_eps(telu, telu, x64) # 0.1146 + worst_eps(telu, telu, x64) # 2 + + @test mean_eps(telu_fast, telu, x64) < 0.13 # 0.12204 + @test worst_eps(telu_fast, telu, x64) <= 3 # 2 + + @test telu_fast.(xbig[1:end-1]) ≈ telu.(xbig[1:end-1]) + @test telu_fast.(-xbig[1:end-1]) ≈ telu.(-xbig[1:end-1]) + end + @testset "deriv_telu" begin + mean_eps(deriv_telu, deriv_telu, x64) # 0.09304 + worst_eps(deriv_telu, deriv_telu, x64) # 2 + + @test mean_eps(deriv_telu_fast, deriv_telu, x64) < 2.1 # 2.05944 + @test worst_eps(deriv_telu_fast, deriv_telu, x64) <= 29 # 28 + + @test deriv_telu_fast.(xbig[1:end-1]) ≈ deriv_telu.(xbig[1:end-1]) + @test deriv_telu_fast.(-xbig[1:end-1]) ≈ deriv_telu.(-xbig[1:end-1]) + end end -@testset "tanh_fast & sigmoid_fast: Float32" begin +@testset "tanh_fast, sigmoid_fast, telu_fast & deriv_telu_fast: Float32" begin x32 = 1f-6:1f-4:5 xbig32 = vcat(6:3:200f0, 1000, typemax(Float32)) @@ -298,6 +321,26 @@ end @test sigmoid_fast.(xbig32) ≈ sigmoid.(xbig32) @test sigmoid_fast.(-xbig32) ≈ sigmoid.(-xbig32) end + @testset "telu" begin + mean_eps(telu, telu, x32) # 0.09418 + worst_eps(telu, telu, x32) # 2 + + @test mean_eps(telu_fast, telu, x32) < 0.26 # 0.2555 + @test worst_eps(telu_fast, telu, x32) <= 5 # 4 + + @test 
telu_fast.(xbig32[1:end-1]) ≈ telu.(xbig32[1:end-1]) + @test telu_fast.(-xbig32[1:end-1]) ≈ telu.(-xbig32[1:end-1]) + end + @testset "deriv_telu" begin + mean_eps(deriv_telu, deriv_telu, x32) # 0.07228 + worst_eps(deriv_telu, deriv_telu, x32) # 1 + + @test mean_eps(deriv_telu_fast, deriv_telu, x32) < 0.69 # 0.68772 + @test worst_eps(deriv_telu_fast, deriv_telu, x32) <= 11 # 10 + + @test deriv_telu_fast.(xbig32[1:end-1]) ≈ deriv_telu.(xbig32[1:end-1]) + @test deriv_telu_fast.(-xbig32[1:end-1]) ≈ deriv_telu.(-xbig32[1:end-1]) + end end ## Autodiff tests From a25a1d11286854867fa09e6e853b5355e976a1c6 Mon Sep 17 00:00:00 2001 From: Mao Zeng Date: Tue, 7 Jan 2025 21:34:44 +0000 Subject: [PATCH 2/4] Add @fastmath for telu_fast and reuse telu(x) to compute telu'(x) --- src/activations.jl | 41 ++++++++++++++++++++++++++++++++++++----- test/activations.jl | 22 +++++++++++----------- 2 files changed, 47 insertions(+), 16 deletions(-) diff --git a/src/activations.jl b/src/activations.jl index 5fdc536d7..93a85ab8a 100644 --- a/src/activations.jl +++ b/src/activations.jl @@ -791,7 +791,7 @@ telu(x) = x * tanh(exp(x)) This is a faster but less accurate version of `telu`. This function is associated with a hard-coded derivative, `deriv_telu_fast`, which is faster but less accurate than `deriv_telu`. 
""" -telu_fast(x) = x * tanh_fast(exp(x)) +telu_fast(x) = @fastmath x * tanh_fast(exp(x)) # Adapted from the Discourse post: function deriv_telu(x) @@ -799,12 +799,43 @@ function deriv_telu(x) tanh(exp_x) + 4x / (exp(exp_x - x/2) + exp(-exp_x - x/2))^2 end -function deriv_telu_fast(x) - tanh_exp_x = tanh(exp(x)) +# 0th and 1st order Taylor expansion for telu'(x) around x=0 +const deriv_telu_taylor_expansion = (tanh(1.0), 8*exp(1)^2 / (1+exp(1)^2)^2) + +# Various cutoffs for numerical evaluations of telu'(x) +const sqrt_eps_f16, sqrt_eps_f32, sqrt_eps_f64 = sqrt(eps(Float16)), sqrt(eps(Float32)), sqrt(eps(Float64)) +const minus_log_cutoff_f16, minus_log_cutoff_f32, minus_log_cutoff_f64 = -log(sqrt_eps_f16), -log(sqrt_eps_f32), -log(sqrt_eps_f64) # positive cutoff to e.g. prevent `exp` from overflow +@inline small_x_cutoff_deriv_telu(::Float16) = sqrt_eps_f16 +@inline small_x_cutoff_deriv_telu(::Float32) = sqrt_eps_f32 +@inline small_x_cutoff_deriv_telu(::Float64) = sqrt_eps_f64 +@inline small_x_cutoff_deriv_telu(::T) where T <: AbstractFloat = sqrt(eps(T)) +@inline minus_log_cutoff(::Float16) = minus_log_cutoff_f16 +@inline minus_log_cutoff(::Float32) = minus_log_cutoff_f32 +@inline minus_log_cutoff(::Float64) = minus_log_cutoff_f64 +@inline minus_log_cutoff(::T) where T <: AbstractFloat = -log(small_x_cutoff_deriv_telu(zero(T))) + +@inline function _deriv_telu_taylor_expansion(x::T) where {T <: Union{Float16, Float32, Float64}} + convert(T, deriv_telu_taylor_expansion[1]) + x * convert(T, deriv_telu_taylor_expansion[2]) +end + +@inline function _deriv_telu_taylor_expansion(x::T) where {T <: AbstractFloat} + tanh(one(T)) + x * 8*exp(one(T))^2 / (one(T)+exp(one(T))^2)^2 +end + +function deriv_telu_fast(x, Ω) + ifelse(abs(x) < small_x_cutoff_deriv_telu(x), _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion + ifelse(x >= minus_log_cutoff(x), one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow. 
This cutoff is good for all types (Float16, 32, 64) in terms of both preventing overflow and maintaining accuracy +end + +@inline function _deriv_telu_fast(x, Ω) + tanh_exp_x = Ω / x sech_exp_x_squared = 1 - tanh_exp_x^2 - ifelse(x >= 4, one(x), tanh_exp_x + x * exp(x) * sech_exp_x_squared) # cut off large x to prevent `exp(x)` overflow. This cutoff is good for all types (Float16, 32, 64) in terms of both preventing overflow and maintaining accuracy + tanh_exp_x + x * exp(x) * sech_exp_x_squared end +# for testing accuracy +_deriv_telu_fast(x) = deriv_telu_fast(x, telu_fast(x)) + # Define broadcasts for activation functions on arrays for f in ACTIVATIONS @eval $(f)(x::AbstractArray, args...) = $(f).(x, args...) @@ -948,7 +979,7 @@ UNARY_ACTS = [ # f, dfdx ## Fast variants are the same! (:tanh_fast, :(conj(1 - Ω^2))), (:sigmoid_fast, :(conj(Ω * (1 - Ω)))), - (:telu_fast, :(deriv_telu_fast(x))) + (:telu_fast, :(deriv_telu_fast(x, Ω))) ] for (f, dfdx) in UNARY_ACTS diff --git a/test/activations.jl b/test/activations.jl index 3d7846ecf..3d303fbce 100644 --- a/test/activations.jl +++ b/test/activations.jl @@ -214,7 +214,7 @@ end ## Faster variants -using NNlib: tanh_fast, sigmoid_fast, telu_fast, deriv_telu, deriv_telu_fast +using NNlib: tanh_fast, sigmoid_fast, telu_fast, deriv_telu, _deriv_telu_fast function countepsfrom(x::T, xtrue) where {T<:AbstractFloat} target = T(xtrue) @@ -269,8 +269,8 @@ end mean_eps(telu, telu, x64) # 0.1146 worst_eps(telu, telu, x64) # 2 - @test mean_eps(telu_fast, telu, x64) < 0.13 # 0.12204 - @test worst_eps(telu_fast, telu, x64) <= 3 # 2 + @test mean_eps(telu_fast, telu, x64) < 0.14 # 0.1338 + @test worst_eps(telu_fast, telu, x64) <= 4 # 3 @test telu_fast.(xbig[1:end-1]) ≈ telu.(xbig[1:end-1]) @test telu_fast.(-xbig[1:end-1]) ≈ telu.(-xbig[1:end-1]) @@ -279,11 +279,11 @@ end mean_eps(deriv_telu, deriv_telu, x64) # 0.09304 worst_eps(deriv_telu, deriv_telu, x64) # 2 - @test mean_eps(deriv_telu_fast, deriv_telu, x64) < 2.1 # 2.05944 - @test 
worst_eps(deriv_telu_fast, deriv_telu, x64) <= 29 # 28 + @test mean_eps(_deriv_telu_fast, deriv_telu, x64) < 4.1 # 4.06396 + @test worst_eps(_deriv_telu_fast, deriv_telu, x64) <= 125 # 120 - @test deriv_telu_fast.(xbig[1:end-1]) ≈ deriv_telu.(xbig[1:end-1]) - @test deriv_telu_fast.(-xbig[1:end-1]) ≈ deriv_telu.(-xbig[1:end-1]) + @test _deriv_telu_fast.(xbig[1:end-1]) ≈ deriv_telu.(xbig[1:end-1]) + @test _deriv_telu_fast.(-xbig[1:end-1]) ≈ deriv_telu.(-xbig[1:end-1]) end end @@ -335,11 +335,11 @@ end mean_eps(deriv_telu, deriv_telu, x32) # 0.07228 worst_eps(deriv_telu, deriv_telu, x32) # 1 - @test mean_eps(deriv_telu_fast, deriv_telu, x32) < 0.69 # 0.68772 - @test worst_eps(deriv_telu_fast, deriv_telu, x32) <= 11 # 10 + @test mean_eps(_deriv_telu_fast, deriv_telu, x32) < 2.4 # 2.31772 + @test worst_eps(_deriv_telu_fast, deriv_telu, x32) <= 70 # 66 - @test deriv_telu_fast.(xbig32[1:end-1]) ≈ deriv_telu.(xbig32[1:end-1]) - @test deriv_telu_fast.(-xbig32[1:end-1]) ≈ deriv_telu.(-xbig32[1:end-1]) + @test _deriv_telu_fast.(xbig32[1:end-1]) ≈ deriv_telu.(xbig32[1:end-1]) + @test _deriv_telu_fast.(-xbig32[1:end-1]) ≈ deriv_telu.(-xbig32[1:end-1]) end end From 72ff5df1602c2fdc0aef320bde8a48c937f04d1f Mon Sep 17 00:00:00 2001 From: Mao Zeng Date: Tue, 7 Jan 2025 22:10:33 +0000 Subject: [PATCH 3/4] Rewrite code to be friendly with dual numbers. 
--- src/activations.jl | 43 +++++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/src/activations.jl b/src/activations.jl index 93a85ab8a..7eb7eb4d8 100644 --- a/src/activations.jl +++ b/src/activations.jl @@ -805,26 +805,41 @@ const deriv_telu_taylor_expansion = (tanh(1.0), 8*exp(1)^2 / (1+exp(1)^2)^2) # Various cutoffs for numerical evaluations of telu'(x) const sqrt_eps_f16, sqrt_eps_f32, sqrt_eps_f64 = sqrt(eps(Float16)), sqrt(eps(Float32)), sqrt(eps(Float64)) const minus_log_cutoff_f16, minus_log_cutoff_f32, minus_log_cutoff_f64 = -log(sqrt_eps_f16), -log(sqrt_eps_f32), -log(sqrt_eps_f64) # positive cutoff to e.g. prevent `exp` from overflow -@inline small_x_cutoff_deriv_telu(::Float16) = sqrt_eps_f16 -@inline small_x_cutoff_deriv_telu(::Float32) = sqrt_eps_f32 -@inline small_x_cutoff_deriv_telu(::Float64) = sqrt_eps_f64 -@inline small_x_cutoff_deriv_telu(::T) where T <: AbstractFloat = sqrt(eps(T)) -@inline minus_log_cutoff(::Float16) = minus_log_cutoff_f16 -@inline minus_log_cutoff(::Float32) = minus_log_cutoff_f32 -@inline minus_log_cutoff(::Float64) = minus_log_cutoff_f64 -@inline minus_log_cutoff(::T) where T <: AbstractFloat = -log(small_x_cutoff_deriv_telu(zero(T))) - -@inline function _deriv_telu_taylor_expansion(x::T) where {T <: Union{Float16, Float32, Float64}} - convert(T, deriv_telu_taylor_expansion[1]) + x * convert(T, deriv_telu_taylor_expansion[2]) + +@inline function _deriv_telu_taylor_expansion(x::Float64) + deriv_telu_taylor_expansion[1] + x * deriv_telu_taylor_expansion[2] +end + +@inline function _deriv_telu_taylor_expansion(x::Float32) + convert(Float32, deriv_telu_taylor_expansion[1]) + x * convert(Float32, deriv_telu_taylor_expansion[2]) +end + +@inline function _deriv_telu_taylor_expansion(x::Float16) + convert(Float16, deriv_telu_taylor_expansion[1]) + x * convert(Float16, deriv_telu_taylor_expansion[2]) end @inline function _deriv_telu_taylor_expansion(x::T) where {T <: 
AbstractFloat} tanh(one(T)) + x * 8*exp(one(T))^2 / (one(T)+exp(one(T))^2)^2 end -function deriv_telu_fast(x, Ω) - ifelse(abs(x) < small_x_cutoff_deriv_telu(x), _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion - ifelse(x >= minus_log_cutoff(x), one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow. This cutoff is good for all types (Float16, 32, 64) in terms of both preventing overflow and maintaining accuracy +function deriv_telu_fast(x::Float64, Ω) + ifelse(abs(x) < sqrt_eps_f64, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion + ifelse(x >= minus_log_cutoff_f64, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow. +end + +function deriv_telu_fast(x::Float32, Ω) + ifelse(abs(x) < sqrt_eps_f32, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion + ifelse(x >= minus_log_cutoff_f32, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow. +end + +function deriv_telu_fast(x::Float16, Ω) + ifelse(abs(x) < sqrt_eps_f16, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion + ifelse(x >= minus_log_cutoff_f16, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow. +end + +function deriv_telu_fast(x::T, Ω) where T <: AbstractFloat + ifelse(abs(x) < sqrt(eps(T)), _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion + ifelse(x >= -log(sqrt(eps(T))), one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow. end @inline function _deriv_telu_fast(x, Ω) From 8701e7eb6b39c95b09bf521fc08ac027169b601c Mon Sep 17 00:00:00 2001 From: Mao Zeng Date: Tue, 7 Jan 2025 23:05:40 +0000 Subject: [PATCH 4/4] Use eps(T) for float type T to control cutoffs for telu derivative evaluations. 
--- src/activations.jl | 38 ++------------------------------------ 1 file changed, 2 insertions(+), 36 deletions(-) diff --git a/src/activations.jl b/src/activations.jl index 7eb7eb4d8..db5fd8443 100644 --- a/src/activations.jl +++ b/src/activations.jl @@ -799,45 +799,11 @@ function deriv_telu(x) tanh(exp_x) + 4x / (exp(exp_x - x/2) + exp(-exp_x - x/2))^2 end -# 0th and 1st order Taylor expansion for telu'(x) around x=0 -const deriv_telu_taylor_expansion = (tanh(1.0), 8*exp(1)^2 / (1+exp(1)^2)^2) - -# Various cutoffs for numerical evaluations of telu'(x) -const sqrt_eps_f16, sqrt_eps_f32, sqrt_eps_f64 = sqrt(eps(Float16)), sqrt(eps(Float32)), sqrt(eps(Float64)) -const minus_log_cutoff_f16, minus_log_cutoff_f32, minus_log_cutoff_f64 = -log(sqrt_eps_f16), -log(sqrt_eps_f32), -log(sqrt_eps_f64) # positive cutoff to e.g. prevent `exp` from overflow - -@inline function _deriv_telu_taylor_expansion(x::Float64) - deriv_telu_taylor_expansion[1] + x * deriv_telu_taylor_expansion[2] -end - -@inline function _deriv_telu_taylor_expansion(x::Float32) - convert(Float32, deriv_telu_taylor_expansion[1]) + x * convert(Float32, deriv_telu_taylor_expansion[2]) -end - -@inline function _deriv_telu_taylor_expansion(x::Float16) - convert(Float16, deriv_telu_taylor_expansion[1]) + x * convert(Float16, deriv_telu_taylor_expansion[2]) -end - -@inline function _deriv_telu_taylor_expansion(x::T) where {T <: AbstractFloat} +@inline function _deriv_telu_taylor_expansion(x::T) where T tanh(one(T)) + x * 8*exp(one(T))^2 / (one(T)+exp(one(T))^2)^2 end -function deriv_telu_fast(x::Float64, Ω) - ifelse(abs(x) < sqrt_eps_f64, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion - ifelse(x >= minus_log_cutoff_f64, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow. 
-end - -function deriv_telu_fast(x::Float32, Ω) - ifelse(abs(x) < sqrt_eps_f32, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion - ifelse(x >= minus_log_cutoff_f32, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow. -end - -function deriv_telu_fast(x::Float16, Ω) - ifelse(abs(x) < sqrt_eps_f16, _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion - ifelse(x >= minus_log_cutoff_f16, one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow. -end - -function deriv_telu_fast(x::T, Ω) where T <: AbstractFloat +function deriv_telu_fast(x::T, Ω) where T ifelse(abs(x) < sqrt(eps(T)), _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion ifelse(x >= -log(sqrt(eps(T))), one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow. end