Skip to content

Commit 7b43f58

Browse files
frankier, mhauru, sunxd3
authored
Return NaN for negative ModeResult variance estimates (#2471)
* Return NaN for negative ModeResult variance estimates
* Apply suggestions from mhauru
  Co-authored-by: Markus Hauru <[email protected]>
* Add doc to StatsBase.coefTable(::ModeResult, ...) for numerrors_warnonly
* Add test for same coeftable with/without numerrors_warnonly
* Add a test for coeftable with negative variance
---------
Co-authored-by: Markus Hauru <[email protected]>
Co-authored-by: Xianda Sun <[email protected]>
Co-authored-by: Markus Hauru <[email protected]>
1 parent ddd74b1 commit 7b43f58

File tree

2 files changed

+117
-3
lines changed

2 files changed

+117
-3
lines changed

src/optimisation/Optimisation.jl

+58-3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ using Printf: Printf
1616
using ForwardDiff: ForwardDiff
1717
using StatsAPI: StatsAPI
1818
using Statistics: Statistics
19+
using LinearAlgebra: LinearAlgebra
1920

2021
export maximum_a_posteriori, maximum_likelihood
2122
# The MAP and MLE exports are only needed for the Optim.jl interface.
@@ -228,11 +229,61 @@ end
228229

229230
# Various StatsBase methods for ModeResult
230231

231-
function StatsBase.coeftable(m::ModeResult; level::Real=0.95)
232+
"""
233+
StatsBase.coeftable(m::ModeResult; level::Real=0.95, numerrors_warnonly::Bool=true)
234+
235+
236+
Return a table with coefficients and related statistics of the model. `level` determines the
237+
level for confidence intervals (by default, 95%).
238+
239+
In case the `numerrors_warnonly` argument is true (the default) numerical errors encountered
240+
during the computation of the standard errors will be caught and reported in an extra
241+
"Error notes" column.
242+
"""
243+
function StatsBase.coeftable(m::ModeResult; level::Real=0.95, numerrors_warnonly::Bool=true)
232244
# Get columns for coeftable.
233245
terms = string.(StatsBase.coefnames(m))
234246
estimates = m.values.array[:, 1]
235-
stderrors = StatsBase.stderror(m)
247+
# If numerrors_warnonly is true, and if either the information matrix is singular or its
248+
# inverse has negative entries on its diagonal, then `notes` will be a vector of strings,
249+
# one per value in `m.values`, explaining why a standard error is NaN (empty if it is not).
250+
notes = nothing
251+
local stderrors
252+
if numerrors_warnonly
253+
infmat = StatsBase.informationmatrix(m)
254+
local vcov
255+
try
256+
vcov = inv(infmat)
257+
catch e
258+
if isa(e, LinearAlgebra.SingularException)
259+
stderrors = fill(NaN, length(m.values))
260+
notes = fill("Information matrix is singular", length(m.values))
261+
else
262+
rethrow(e)
263+
end
264+
else
265+
vars = LinearAlgebra.diag(vcov)
266+
stderrors = eltype(vars)[]
267+
if any(x -> x < 0, vars)
268+
notes = []
269+
end
270+
for var in vars
271+
if var >= 0
272+
push!(stderrors, sqrt(var))
273+
if notes !== nothing
274+
push!(notes, "")
275+
end
276+
else
277+
push!(stderrors, NaN)
278+
if notes !== nothing
279+
push!(notes, "Negative variance")
280+
end
281+
end
282+
end
283+
end
284+
else
285+
stderrors = StatsBase.stderror(m)
286+
end
236287
zscore = estimates ./ stderrors
237288
p = map(z -> StatsAPI.pvalue(Distributions.Normal(), z; tail=:both), zscore)
238289

@@ -244,7 +295,7 @@ function StatsBase.coeftable(m::ModeResult; level::Real=0.95)
244295
level_ = 100 * level
245296
level_percentage = isinteger(level_) ? Int(level_) : level_
246297

247-
cols = [estimates, stderrors, zscore, p, ci_low, ci_high]
298+
cols = Vector[estimates, stderrors, zscore, p, ci_low, ci_high]
248299
colnms = [
249300
"Coef.",
250301
"Std. Error",
@@ -253,6 +304,10 @@ function StatsBase.coeftable(m::ModeResult; level::Real=0.95)
253304
"Lower $(level_percentage)%",
254305
"Upper $(level_percentage)%",
255306
]
307+
if notes !== nothing
308+
push!(cols, notes)
309+
push!(colnms, "Error notes")
310+
end
256311
return StatsBase.CoefTable(cols, colnms, terms)
257312
end
258313

test/optimisation/Optimisation.jl

+59
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,65 @@ using Turing
635635
maximum_a_posteriori(m; adtype=adbackend)
636636
end
637637
end
638+
639+
@testset "Collinear coeftable" begin
    # `a` and `b` enter the likelihood only through `a .* x .+ b .* x`, i.e. through their
    # sum, so this exercises the "Information matrix is singular" path of `coeftable`.
    xs = [-1.0, 0.0, 1.0]
    ys = [0.0, 0.0, 0.0]

    @model function collinear(x, y)
        a ~ Normal(0, 1)
        b ~ Normal(0, 1)
        return y ~ MvNormal(a .* x .+ b .* x, 1)
    end

    model = collinear(xs, ys)
    mle_estimate = Turing.Optimisation.estimate_mode(model, MLE())
    tab = coeftable(mle_estimate)

    # Use `@test` rather than `@assert` so that each check is recorded by the enclosing
    # testset and a failure does not abort the remaining checks.
    @test isnan(tab.cols[2][1])
    @test tab.colnms[end] == "Error notes"
    @test occursin("singular", tab.cols[end][1])
end
656+
657+
@testset "Negative variance" begin
    # A model for which the likelihood has a saddle point at x=0, y=0.
    # Creating an optimisation result for this model at x=0, y=0 results in negative
    # variance for one of the variables, because the variance is calculated as the
    # diagonal of the inverse of the Hessian.
    @model function saddle_model()
        x ~ Normal(0, 1)
        y ~ Normal(x, 1)
        Turing.@addlogprob! x^2 - y^2
        return nothing
    end

    model = saddle_model()
    ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
    optim_ld = Turing.Optimisation.OptimLogDensity(model, ctx)
    vals = Turing.Optimisation.NamedArrays.NamedArray([0.0, 0.0])
    # Build the result by hand at the saddle point instead of running an optimiser.
    result = Turing.Optimisation.ModeResult(vals, nothing, 0.0, optim_ld)
    ct = coeftable(result)

    # Use `@test` rather than `@assert` so that each check is recorded by the enclosing
    # testset and a failure does not abort the remaining checks.
    @test isnan(ct.cols[2][1])
    @test ct.colnms[end] == "Error notes"
    @test occursin("Negative variance", ct.cols[end][1])
end
678+
679+
@testset "Same coeftable with/without numerrors_warnonly" begin
    # For a model where this test expects no numerical errors, both settings of
    # `numerrors_warnonly` must produce the same table, with no extra "Error notes" column.
    xs = [0.0, 1.0, 2.0]

    @model function extranormal(x)
        mean ~ Normal(0, 1)
        return x ~ Normal(mean, 1)
    end

    model = extranormal(xs)
    mle_estimate = Turing.Optimisation.estimate_mode(model, MLE())
    warnonly_coeftable = coeftable(mle_estimate; numerrors_warnonly=true)
    no_warnonly_coeftable = coeftable(mle_estimate; numerrors_warnonly=false)

    # Use `@test` rather than `@assert` so that every field comparison is reported
    # individually by the enclosing testset.
    @test warnonly_coeftable.cols == no_warnonly_coeftable.cols
    @test warnonly_coeftable.colnms == no_warnonly_coeftable.colnms
    @test warnonly_coeftable.rownms == no_warnonly_coeftable.rownms
    @test warnonly_coeftable.pvalcol == no_warnonly_coeftable.pvalcol
    @test warnonly_coeftable.teststatcol == no_warnonly_coeftable.teststatcol
end
638697
end
639698

640699
end

0 commit comments

Comments
 (0)