
Return NaN for negative ModeResult variance estimates #2471

Merged: 6 commits merged into master, Feb 18, 2025

Conversation

frankier (Contributor)

Here's a modified example that produces negative variance estimates for some parameters (coefficients_versicolor[3]):

using Turing
using RDatasets
using StatsPlots
using MLDataUtils: shuffleobs, splitobs, rescale!
using NNlib: softmax
using FillArrays
using LinearAlgebra
using Random
Random.seed!(0);

using Optim
using StatsBase

# Load the iris dataset and map species names to integer indices.
data = RDatasets.dataset("datasets", "iris");
data[rand(1:size(data, 1), 20), :]
species = ["setosa", "versicolor", "virginica"]
data[!, :Species_index] = indexin(data[!, :Species], species)
data[rand(1:size(data, 1), 20), [:Species, :Species_index]]

# Split into training and test sets.
trainset, testset = splitobs(shuffleobs(data), 0.5)
features = [:SepalLength, :SepalWidth, :PetalLength, :PetalWidth]
target = :Species_index

train_features = Matrix(trainset[!, features])
test_features = Matrix(testset[!, features])
train_target = trainset[!, target]
test_target = testset[!, target]

# Standardise the features using the training-set statistics.
μ, σ = rescale!(train_features; obsdim=1)
rescale!(test_features, μ, σ; obsdim=1);

@model function logistic_regression(x, y, σ)
    n = size(x, 1)
    length(y) == n ||
        throw(DimensionMismatch("number of observations in `x` and `y` is not equal"))

    # Priors of intercepts and coefficients.
    intercept_versicolor ~ Normal(0, σ)
    intercept_virginica ~ Normal(0, σ)
    coefficients_versicolor ~ MvNormal(Zeros(4), σ^2 * I)
    coefficients_virginica ~ MvNormal(Zeros(4), σ^2 * I)

    # Compute the likelihood of the observations.
    values_versicolor = intercept_versicolor .+ x * coefficients_versicolor
    values_virginica = intercept_virginica .+ x * coefficients_virginica
    for i in 1:n
        # the 0 corresponds to the base category `setosa`
        v = softmax([0, values_versicolor[i], values_virginica[i]])
        y[i] ~ Categorical(v)
    end
end;

# Fit by maximum likelihood and print the coefficient table; computing the
# stderr column is where the DomainError arises without this PR.
model = logistic_regression(train_features, train_target, 1)
mle_estimate = Optim.optimize(model, MLE())
println(coeftable(mle_estimate))

Without this PR, this throws a DomainError in coeftable when computing the stderr of coefficients_versicolor[3].
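
For context, here is a minimal sketch (with hypothetical numbers, not values taken from the fit above) of where the DomainError comes from: the standard errors are square roots of the diagonal of the estimated variance-covariance matrix, and sqrt of a negative Float64 throws.

variances = [0.12, 0.08, -1.3e-5, 0.05]   # hypothetical vcov diagonal; the third entry is negative

try
    sqrt.(variances)                       # throws DomainError at the negative entry
catch err
    @show err
end

# What this PR does conceptually: return NaN for the offending entry so the
# rest of the coeftable stays inspectable.
stderrs = map(v -> v >= 0 ? sqrt(v) : NaN, variances)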

@frankier (Contributor, Author)

This is related to #2048

I don't fully agree with the conclusion that there is nothing to fix in Turing.jl here.

Propagating a NaN makes it easier to inspect the coeftable and see that something has gone wrong with the optimisation process.


codecov bot commented Jan 20, 2025

Codecov Report

Attention: Patch coverage is 93.33333% with 2 lines in your changes missing coverage. Please review.

Project coverage is 84.59%. Comparing base (ddd74b1) to head (7503f60).
Report is 1 commit behind head on master.

Files with missing lines | Patch % | Lines
src/optimisation/Optimisation.jl | 93.33% | 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2471      +/-   ##
==========================================
+ Coverage   84.45%   84.59%   +0.13%     
==========================================
  Files          21       21              
  Lines        1570     1597      +27     
==========================================
+ Hits         1326     1351      +25     
- Misses        244      246       +2     


@coveralls

coveralls commented Jan 20, 2025

Pull Request Test Coverage Report for Build 13376341310

Details

  • 28 of 30 (93.33%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+0.3%) to 76.116%

Changes missing coverage | Covered lines | Changed/added lines | %
src/optimisation/Optimisation.jl | 28 | 30 | 93.33%

Totals (Coverage Status)
Change from base Build 13224946772: 0.3%
Covered Lines: 1211
Relevant Lines: 1591

💛 - Coveralls

@yebai (Member) left a comment

Thanks, @frankier. It looks like a nice improvement!

@mhauru (Member) commented Jan 21, 2025

Thanks @frankier, I agree that the current situation where coeftable fails with a DomainError is not optimal, and this is an improvement. I wonder if we could be even more explicit though, and in our method for coeftable, catch the DomainError, print a warning explaining that the solution seems to have negative variance and thus you should be very suspicious of your result, and then return a table with stderr as NaN. @frankier, as a user, do you think that would be helpful? This would also save us from introducing a new dependency that we only use on a single line.
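
A rough sketch of the kind of handling described above; the helper name safe_stderror and its structure are hypothetical, not the actual Turing.jl implementation:

using LinearAlgebra: diag

# Hypothetical helper illustrating the suggestion: warn on negative variance
# estimates and fall back to NaN instead of letting sqrt throw a DomainError.
function safe_stderror(vcov_matrix)
    vars = diag(vcov_matrix)
    if any(<(0), vars)
        @warn "Negative variance estimates found; the optimum is likely not a " *
              "proper mode, so treat these results with suspicion."
    end
    return map(v -> v >= 0 ? sqrt(v) : NaN, vars)
end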

@frankier (Contributor, Author)

Yes, I think on balance that would be better.

I think there is also the possibility of getting a SingularException in inv, which I guess also indicates model identifiability (and thus optimization) problems. So I guess it's better to catch these, aggregate them and report how it failed alongside the table.

I'll update this PR to work this way soon.
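
For reference, the SingularException failure mode can be reproduced on its own with a rank-deficient matrix (hypothetical values):

using LinearAlgebra

# Hypothetical rank-deficient "information matrix": the second row is twice the
# first, so inversion fails with a SingularException rather than a DomainError.
H = [1.0 2.0;
     2.0 4.0]

try
    inv(H)
catch err
    @show err    # LinearAlgebra.SingularException
end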

@frankier force-pushed the neg-var-moderesult branch 2 times, most recently from 3ad099c to 937c1b6, on February 10, 2025 15:34
@frankier requested a review from yebai on February 10, 2025 15:34
@mhauru (Member) left a comment

This is looking really nice @frankier! I had a few small proposals. Let me know once you're done making edits and I'm happy to merge. Also, I realise you're helping out on a volunteer basis, so if you don't have time to attend to my comments that's fine too, we can merge as is and I can add a couple of tests myself and call it done.

@mhauru (Member) left a comment

Presuming tests pass, happy to merge when you are @frankier.

@frankier (Contributor, Author)

Learnt something new from the test. LGTM!

@mhauru (Member) commented Feb 18, 2025

The failing tests are the known x86 OOM issues we are working on; they have nothing to do with this PR. Merging.

Thanks so much @frankier! Always a pleasure to receive a community PR, and especially so when it's so well written.

@mhauru merged commit 7b43f58 into TuringLang:master on Feb 18, 2025
58 of 61 checks passed