diff --git a/edward2/tensorflow/layers/utils.py b/edward2/tensorflow/layers/utils.py index dfaaa676..4f08400b 100644 --- a/edward2/tensorflow/layers/utils.py +++ b/edward2/tensorflow/layers/utils.py @@ -345,6 +345,27 @@ def smart_constant_value(pred): return pred_value +def mean_field_binary_predictive_variance(logits, covmat, mean_field_factor=1.): + """Compute predictive variance for Laplace-approximated logit posterior, assuming sigmoid link. + + Arguments: + logits: A float tensor of shape (batch_size, num_classes). + covmat: A float tensor of shape (batch_size, batch_size). + mean_field_factor: The scale factor for mean-field approximation, used to + adjust the influence of posterior variance in posterior mean + approximation. + + Returns: + Mean-field posterior variance. + + """ + logits_scale = tf.sqrt(1. + tf.linalg.diag_part(covmat) * mean_field_factor) + logits = logits / tf.expand_dims(logits_scale, axis=-1) + posterior_mean = tf.sigmoid(tf.squeeze(logits, axis=(1,))) + + return posterior_mean * (1 - posterior_mean) * (1 / logits_scale) + + def mean_field_logits(logits, covmat, mean_field_factor=1.): """Adjust the SNGP logits so its softmax approximates posterior mean [1]. @@ -356,7 +377,7 @@ def mean_field_logits(logits, covmat, mean_field_factor=1.): approximation. Returns: - True or False if `pred` has a constant boolean value, None otherwise. + Calibrated logits. """ logits_scale = tf.sqrt(1. + tf.linalg.diag_part(covmat) * mean_field_factor)