Skip to content

Commit 3368fa3

Browse files
Edward2 Teamedward-bot
Edward2 Team
authored andcommitted
Adds util to compute binary predictive posterior variance
PiperOrigin-RevId: 339414502
1 parent 7321555 commit 3368fa3

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

edward2/tensorflow/layers/utils.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,27 @@ def smart_constant_value(pred):
345345
return pred_value
346346

347347

348+
def mean_field_binary_predictive_variance(logits, covmat, mean_field_factor=1.):
349+
"""Compute predictive variance for Laplace-approximated logit posterior, assuming sigmoid link.
350+
351+
Arguments:
352+
logits: A float tensor of shape (batch_size, num_classes).
353+
covmat: A float tensor of shape (batch_size, batch_size).
354+
mean_field_factor: The scale factor for mean-field approximation, used to
355+
adjust the influence of posterior variance in posterior mean
356+
approximation.
357+
358+
Returns:
359+
Mean-field posterior variance.
360+
361+
"""
362+
logits_scale = tf.sqrt(1. + tf.linalg.diag_part(covmat) * mean_field_factor)
363+
logits = logits / tf.expand_dims(logits_scale, axis=-1)
364+
posterior_mean = tf.sigmoid(tf.squeeze(logits, axis=(1,)))
365+
366+
return posterior_mean * (1 - posterior_mean) * (1 / logits_scale)
367+
368+
348369
def mean_field_logits(logits, covmat, mean_field_factor=1.):
349370
"""Adjust the SNGP logits so its softmax approximates posterior mean [1].
350371
@@ -356,7 +377,7 @@ def mean_field_logits(logits, covmat, mean_field_factor=1.):
356377
approximation.
357378
358379
Returns:
359-
True or False if `pred` has a constant boolean value, None otherwise.
380+
Calibrated logits.
360381
361382
"""
362383
logits_scale = tf.sqrt(1. + tf.linalg.diag_part(covmat) * mean_field_factor)

0 commit comments

Comments
 (0)