Special case Inf and -Inf in truncated #1467

Closed
mohamed82008 wants to merge 4 commits

@mohamed82008
Contributor

mohamed82008 commented Jan 3, 2022

This PR fixes a problem that showed up recently; I haven't tracked down exactly why it appears now. Currently, the following code produces a NaN in the derivative of the log of the normalisation constant of a Truncated.

using ForwardDiff, Distributions

d = truncated(Normal(ForwardDiff.Dual(0.0, 1.0), ForwardDiff.Dual(1.0, 0.0)), 0, Inf)
d.logtp.partials[1] # NaN

Since the logpdf of a Truncated subtracts logtp, its derivative always ends up NaN as well. This is clearly wrong and is only an artefact of ForwardDiff/DiffRules behaviour being undefined for Inf values. For example, x.partials[1] is NaN in:

x = ForwardDiff.Dual(Inf, 0.0) / ForwardDiff.Dual(0.1, 0.0)

This causes the zval, (Inf - d.μ) / d.σ, to have a NaN derivative. However, this seems to be the case in previous versions of ForwardDiff as well, and I haven't tracked down which change in Distributions.truncated made this ForwardDiff behaviour significant in the above example.
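The NaN can be reproduced without ForwardDiff. The sketch below uses a hypothetical `MiniDual` type (an illustration, not ForwardDiff's actual implementation) and assumes the textbook quotient rule, which multiplies the infinite value by the zero partial of the denominator:

```julia
# Hypothetical minimal dual number: a value plus one partial derivative.
# Illustration only; this is not ForwardDiff's Dual.
struct MiniDual
    value::Float64
    partial::Float64
end

# Subtracting a dual from a plain number: d(c - a) = -da.
Base.:-(c::Real, a::MiniDual) = MiniDual(c - a.value, -a.partial)

# Quotient rule: d(a / b) = (da * b - a * db) / b^2.
function Base.:/(a::MiniDual, b::MiniDual)
    value = a.value / b.value
    partial = (a.partial * b.value - a.value * b.partial) / b.value^2
    return MiniDual(value, partial)
end

μ = MiniDual(0.0, 1.0)   # differentiate with respect to μ
σ = MiniDual(1.0, 0.0)   # σ is held constant

zval = (Inf - μ) / σ
println(zval.value)           # Inf
println(isnan(zval.partial))  # true, because Inf * 0.0 is NaN
```

The value itself is fine; only the partial is poisoned, which matches the symptom described above.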

However, this PR feels like a half fix because:

  1. It's not clear which change caused the above issue,
  2. There are other similar functions that may need a special case for Inf and -Inf but I don't need them so I didn't fix their Dual support, and
  3. It feels like ForwardDiff or DiffRules should be the place to fix this or special case Inf and -Inf but that will require a more significant effort.

That said, this PR gets the job done so I am opening it for now to get feedback.

@devmotion
Member

Hmm I think it is a bit suboptimal to have to work around such ForwardDiff specific issues since it makes the code more complex without any advantage (?) in other use cases. Does the problem occur also with the NaN-safe setting of ForwardDiff?

@mohamed82008
Contributor Author

> Does the problem occur also with the NaN-safe setting of ForwardDiff?

Could you please elaborate on this?

@devmotion
Member

There's a NaN-safe mode (https://juliadiff.org/ForwardDiff.jl/dev/user/advanced/#Fixing-NaN/Inf-Issues). It became a bit easier to use with Preferences; before that, one had to modify ForwardDiff (Chris even created a nansafe branch to make it a bit easier).
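For reference, a sketch of enabling the mode through Preferences, based on the linked documentation. The preference key `nansafe_mode` is an assumption taken from that page, so check it against your ForwardDiff version. Julia has to be restarted for the setting to take effect:

```julia
using ForwardDiff, Preferences

# Writes the preference to LocalPreferences.toml in the active project.
# Assumption: the key is "nansafe_mode", as described in the ForwardDiff docs.
set_preferences!(ForwardDiff, "nansafe_mode" => true)

# After restarting Julia, the partial of this expression should be 0.0
# instead of NaN if the mode is active:
# log(ForwardDiff.Dual{:tag}(0.0, 0.0))
```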

@mohamed82008
Contributor Author

Same behaviour with the NaN-safe mode.

@codecov-commenter

codecov-commenter commented Jan 4, 2022

Codecov Report

Merging #1467 (e878bb1) into master (93a542d) will decrease coverage by 0.00%.
The diff coverage is 80.00%.

@@            Coverage Diff             @@
##           master    #1467      +/-   ##
==========================================
- Coverage   84.44%   84.43%   -0.01%     
==========================================
  Files         124      124              
  Lines        7475     7480       +5     
==========================================
+ Hits         6312     6316       +4     
- Misses       1163     1164       +1     
| Impacted Files | Coverage Δ |
|---|---|
| src/truncate.jl | 85.18% <80.00%> (-0.35%) ⬇️ |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@mohamed82008 changed the title from "Special case Inf and -Inf in logcdf for Normal" to "Special case Inf and -Inf in truncated" Jan 4, 2022
@mohamed82008
Contributor Author

I generalised the fix here to work with any Truncated.

@@ -26,10 +26,17 @@ function truncated(d::UnivariateDistribution, l::T, u::T) where {T <: Real}
    else
        logcdf(d, l)
    end
    _T = typeof(loglcdf)
    if isfinite(loglcdf) && l == -Inf
        loglcdf = convert(_T, -Inf)

Member

Can you explain why this fixes the problem? It shouldn't change the value of loglcdf?

Contributor Author

It doesn't change the value, but it changes the partial from NaN to 0, which gives the correct behaviour, because convert builds a Dual with zero partials:

julia> using ForwardDiff

julia> d = ForwardDiff.Dual(1.0, 1.0)
Dual{Nothing}(1.0,1.0)

julia> convert(typeof(d), Inf)
Dual{Nothing}(Inf,0.0)
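The same idea as a self-contained sketch. `MiniDual` and `reset_at_infinite_bound` are hypothetical names used for illustration, and the condition is simplified; it is not the PR's exact check:

```julia
# Hypothetical minimal dual number, for illustration only (not ForwardDiff's Dual).
struct MiniDual
    value::Float64
    partial::Float64
end

# Converting a plain number yields a constant: its partial is zero.
Base.convert(::Type{MiniDual}, x::Real) = MiniDual(Float64(x), 0.0)

# Hypothetical helper loosely following the pattern in the hunk above:
# at an infinite bound the log-cdf is a constant, so rebuild it from its
# value to discard a spurious NaN partial.
function reset_at_infinite_bound(logc::MiniDual, bound::Real)
    return isinf(bound) ? convert(MiniDual, logc.value) : logc
end

spurious = MiniDual(-Inf, NaN)   # value is right, partial is not
fixed = reset_at_infinite_bound(spurious, -Inf)
println(fixed.value)    # -Inf
println(fixed.partial)  # 0.0
```

For a finite bound the helper returns its input unchanged, so ordinary derivatives are not affected.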

@mohamed82008
Contributor Author

bump

@devmotion
Member

I'm not happy with this PR, as it is quite a strange workaround that only makes sense for ForwardDiff. At the very least, it would be good to add comments explaining these checks to readers, add tests for the bug that it fixes, and bump the version.

I'd like to check and see if there's a cleaner approach but it might take a few more days since I'm moving and don't have access to my computer right now.

@devmotion
Member

> Same behaviour with the NaN-safe mode.

I can't reproduce this; NaN-safe mode fixes the issue for me (latest releases of both packages):

julia> using ForwardDiff, Distributions

julia> d = truncated(Normal(ForwardDiff.Dual(0.0, 1.0), ForwardDiff.Dual(1.0, 0.0)), 0, Inf)
Truncated(Normal{ForwardDiff.Dual{Nothing, Float64, 1}}(μ=Dual{Nothing}(0.0,1.0), σ=Dual{Nothing}(1.0,0.0)), range=(Dual{Nothing}(0.0,0.0), Dual{Nothing}(Inf,0.0)))

julia> d.logtp.partials[1]
0.7978845608028654

Maybe you forgot to restart Julia after setting the NaN-safe preference? You can check if it is enabled by running the example from the documentation:

julia> log(ForwardDiff.Dual{:tag}(0.0, 0.0))
Dual{:tag}(-Inf,0.0)

The partial is NaN if the NaN-safe mode is not enabled.
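Why the mode changes the result can be sketched in plain Julia. The rule below is an assumption about what NaN-safe mode roughly does (keep a zero partial at zero when the local derivative is not finite), not a copy of ForwardDiff's code; `plain_scale` and `nansafe_scale` are hypothetical names:

```julia
# Scaling a partial by a local derivative, as the chain rule does.
plain_scale(partial::Float64, deriv::Float64) = partial * deriv

# Assumed NaN-safe variant: a zero partial stays zero even when the
# local derivative is not finite.
function nansafe_scale(partial::Float64, deriv::Float64)
    return (!isfinite(deriv) && iszero(partial)) ? 0.0 : partial * deriv
end

x, dx = 0.0, 0.0   # the input Dual{:tag}(0.0, 0.0) from the example above
deriv = inv(x)     # d/dx log(x) = 1/x, which is Inf at x = 0

println(isnan(plain_scale(dx, deriv)))  # true: 0.0 * Inf is NaN
println(nansafe_scale(dx, deriv))       # 0.0
```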

I'll close this PR since it is not needed and the problem is fixed by using NaN-safe mode.

@devmotion closed this Jan 10, 2022
@mohamed82008
Contributor Author

I think the NaN-safe mode is still a sub-optimal solution to the problem. You are basically taking a performance hit everywhere in the code base to avoid a problem that only happens with truncated distributions or in a few cases.

@mohamed82008
Contributor Author

I think the fix in this PR, though sub-optimal in a different way, still gives better overall performance.

@devmotion
Member

> I think the NaN-safe mode is still a sub-optimal solution to the problem. You are basically taking a performance hit everywhere in the code base to avoid a problem that only happens with truncated distributions or in a few cases.

It's unclear whether and how large a performance hit it actually causes. Hence it might even become the default at some point, see e.g. JuliaDiff/ForwardDiff.jl#451 (comment). It's really a ForwardDiff issue, and I agree with Andreas that the default setting is just wrong: JuliaDiff/ForwardDiff.jl#451 (comment)
