-
Notifications
You must be signed in to change notification settings - Fork 422
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Special case Inf and -Inf in truncated #1467
Conversation
Hmm I think it is a bit suboptimal to have to work around such ForwardDiff specific issues since it makes the code more complex without any advantage (?) in other use cases. Does the problem occur also with the NaN-safe setting of ForwardDiff? |
Could you please elaborate on this? |
There's a NaN-safe mode: https://juliadiff.org/ForwardDiff.jl/dev/user/advanced/#Fixing-NaN/Inf-Issues It became a bit easier to use with Preferences (before one had to modify ForwardDiff; Chris even created a nansafe branch to make it a bit easier). |
Same behaviour with the NaN-safe mode. |
Codecov Report
@@ Coverage Diff @@
## master #1467 +/- ##
==========================================
- Coverage 84.44% 84.43% -0.01%
==========================================
Files 124 124
Lines 7475 7480 +5
==========================================
+ Hits 6312 6316 +4
- Misses 1163 1164 +1
Continue to review full report at Codecov.
|
I generalised the fix here to work with any |
@@ -26,10 +26,17 @@ function truncated(d::UnivariateDistribution, l::T, u::T) where {T <: Real} | |||
else | |||
logcdf(d, l) | |||
end | |||
_T = typeof(loglcdf) | |||
if isfinite(loglcdf) && l == -Inf | |||
loglcdf = convert(_T, -Inf) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain why this fixes the problem? It shouldn't change the value of loglcdf
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It doesn't change its value but it changes the partial to 0 from NaN which gives the correct behaviour.
julia> using ForwardDiff
julia> d = ForwardDiff.Dual(1.0, 1.0)
Dual{Nothing}(1.0,1.0)
julia> convert(typeof(d), Inf)
Dual{Nothing}(Inf,0.0)
bump |
I'm not happy with this PR as it is a quite strange workaround that only makes sense for ForwardDiff. At least it would be good to add some comments for readers about these checks, add tests for the bug that it fixes, and bump the version. I'd like to check and see if there's a cleaner approach but it might take a few more days since I'm moving and don't have access to my computer right now. |
I can't reproduce this, NaN-safe mode fixes the issue for me (latest releases of both packages): julia> using ForwardDiff, Distributions
julia> d = truncated(Normal(ForwardDiff.Dual(0.0, 1.0), ForwardDiff.Dual(1.0, 0.0)), 0, Inf)
Truncated(Normal{ForwardDiff.Dual{Nothing, Float64, 1}}(μ=Dual{Nothing}(0.0,1.0), σ=Dual{Nothing}(1.0,0.0)), range=(Dual{Nothing}(0.0,0.0), Dual{Nothing}(Inf,0.0)))
julia> d.logtp.partials[1]
0.7978845608028654 Maybe you forgot to restart Julia after setting the NaN-safe preference? You can check if it is enabled by running the example from the documentation: julia> log(ForwardDiff.Dual{:tag}(0.0, 0.0))
Dual{:tag}(-Inf,0.0) The partial is I'll close this PR since it is not needed and the problem is fixed by using NaN-safe mode. |
I think the NaN-safe mode is still a sub-optimal solution to the problem. You are basically taking a performance hit everywhere in the code base to avoid a problem that only happens with truncated distributions or in a few cases. |
I think the fix in this PR though sub-optimal in a different way still gives a better overall performance. |
It's unclear whether and how large a performance hit it actually causes. Hence it might even become the default at some point, see e.g. JuliaDiff/ForwardDiff.jl#451 (comment). It's really a ForwardDiff issue, and I agree with Andreas that the default setting is just wrong: JuliaDiff/ForwardDiff.jl#451 (comment) |
This PR fixes a problem that showed up recently but I didn't track down exactly why it shows up now. Currently, the following code produces a
NaN
in the derivative of the log of the normalisation constant of aTruncated
.Since the
logpdf
of atruncated
subtractslogtp
, the derivative is always going to beNaN
at the end. This is clearly wrong and is only an artefact of how ForwardDiff/DiffRules behaviour is undefined forInf
values. For example,x.partials[1]
isNaN
in:causing the
zval
,(Inf - d.μ) / d.σ
, to giveNaN
in the derivative. This however seems to be the case even in previous versions of ForwardDiff and I haven't tracked down which change inDistributions.truncated
made this ForwardDiff behaviour significant in the above example.However, this PR feels like a half fix because:
Inf
and-Inf
but I don't need them so I didn't fix their Dual support, andInf
and-Inf
but that will require a more significant effort.That said, this PR gets the job done so I am opening it for now to get feedback.