@@ -10,7 +10,7 @@

 """

-from typing import Optional
+from typing import Optional, Union
 from .. import nn


@@ -57,6 +57,89 @@ def cross_entropy(*, target: nn.Tensor, estimated: nn.Tensor, estimated_type: st
   return -nn.dot(target, log_prob, reduce=axis)


+@nn.scoped
+def binary_cross_entropy(*,
+                         target: nn.Tensor,
+                         pos_estimated: nn.Tensor, pos_estimated_type: str,
+                         pos_weight: Optional[Union[float, nn.Tensor]] = None):
+  """
+  Binary cross entropy, also called sigmoid cross entropy.
+
+  :param target: (sparse) target labels, 0 (negative) or 1 (positive), i.e. binary.
+  :param pos_estimated: positive class logits. probs = sigmoid(logits).
+  :param pos_estimated_type: "logits" only supported currently
+  :param pos_weight: weight for positive class.
+
+  Code and documentation partly borrowed from TensorFlow.
+
+  A value `pos_weight > 1` decreases the false negative count, hence increasing
+  the recall.
+  Conversely setting `pos_weight < 1` decreases the false positive count and
+  increases the precision.
+  This can be seen from the fact that `pos_weight` is introduced as a
+  multiplicative coefficient for the positive labels term
+  in the loss expression:
+
+    labels * -log(sigmoid(logits)) * pos_weight +
+        (1 - labels) * -log(1 - sigmoid(logits))
+
+  For brevity, let `x = logits`, `z = labels`. The logistic loss is
+
+      z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
+    = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
+    = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
+    = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
+    = (1 - z) * x + log(1 + exp(-x))
+    = x - x * z + log(1 + exp(-x))
+
+  For x < 0, to avoid overflow in exp(-x), we reformulate the above
+
+      x - x * z + log(1 + exp(-x))
+    = log(exp(x)) - x * z + log(1 + exp(-x))
+    = - x * z + log(1 + exp(x))
+
+  Hence, to ensure stability and avoid overflow, the implementation uses this
+  equivalent formulation
+
+    max(x, 0) - x * z + log(1 + exp(-abs(x)))
+  """
+  if pos_estimated_type != "logits":
+    raise NotImplementedError(
+      f"binary_cross_entropy, pos_estimated_type {pos_estimated_type!r}, only 'logits' supported")
+  logits = pos_estimated
+
+  if pos_weight is not None:
+    # Code adapted from tf.nn.weighted_cross_entropy_with_logits.
+    # The logistic loss formula from above, with q = pos_weight, is
+    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
+    # For x < 0, a more numerically stable formula is
+    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(x)) - l * x,  where l = 1 + (q - 1) * z
+    # To avoid branching, we use the combined version
+    #   (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
+    log_weight = 1 + (pos_weight - 1) * target
+    return (
+      (1 - target) * logits +
+      log_weight * (nn.log1p(nn.exp(-nn.abs(logits))) + nn.relu(-logits))
+    )
+
+  # Code adapted from tf.nn.sigmoid_cross_entropy_with_logits.
+  # The logistic loss formula from above is
+  #   x - x * z + log(1 + exp(-x))
+  # For x < 0, a more numerically stable formula is
+  #   -x * z + log(1 + exp(x))
+  # Note that these two expressions can be combined into the following:
+  #   max(x, 0) - x * z + log(1 + exp(-abs(x)))
+  # To allow computing gradients at zero, we define custom versions of max and
+  # abs functions.
+  cond = (logits >= 0)
+  relu_logits = nn.where(cond, logits, 0)
+  neg_abs_logits = nn.where(cond, -logits, logits)  # pylint: disable=invalid-unary-operand-type
+  return (
+    relu_logits - logits * target +
+    nn.log1p(nn.exp(neg_abs_logits))
+  )
+
+
 @nn.scoped
 def kl_div(*, target: nn.Tensor, target_type: str,
            estimated: nn.Tensor, estimated_type: str,
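Note: the stability claim in the docstring is easy to check outside of RETURNN. Below is a small standalone NumPy sketch (not part of the commit; `naive_loss` and `stable_loss` are illustrative helper names) comparing the naive loss `x - x * z + log(1 + exp(-x))` with the combined form `max(x, 0) - x * z + log(1 + exp(-abs(x)))` used by `binary_cross_entropy`: both agree where the naive form is representable, and only the combined form stays finite for very negative logits.

```python
# Standalone NumPy sketch (not part of the commit) checking the docstring's
# stability claim for the unweighted case.
import numpy as np


def naive_loss(x, z):
  # x - x * z + log(1 + exp(-x)); exp(-x) overflows for very negative x.
  return x - x * z + np.log1p(np.exp(-x))


def stable_loss(x, z):
  # max(x, 0) - x * z + log(1 + exp(-abs(x))), the combined form from the code.
  return np.maximum(x, 0.) - x * z + np.log1p(np.exp(-np.abs(x)))


x = np.array([-3., -0.5, 0., 2., 10.])
z = np.array([1., 0., 1., 0., 1.])
assert np.allclose(naive_loss(x, z), stable_loss(x, z))

# For a very negative logit the naive form overflows, the stable form does not.
with np.errstate(over="ignore"):
  print(naive_loss(np.array([-1000.]), np.array([1.])))  # [inf]
print(stable_loss(np.array([-1000.]), np.array([1.])))   # [1000.]
```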
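The `pos_weight` branch can be sanity-checked the same way against the explicit weighted loss from the docstring, `labels * -log(sigmoid(logits)) * pos_weight + (1 - labels) * -log(1 - sigmoid(logits))`. Again a standalone NumPy sketch with illustrative names, using moderate logits so the explicit form does not overflow:

```python
# NumPy sketch (illustrative only): the combined pos_weight formula from the
# code comments matches the explicit weighted loss for moderate logits.
import numpy as np


def sigmoid(x):
  return 1. / (1. + np.exp(-x))


def explicit_weighted(x, z, q):
  # q * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
  return q * z * -np.log(sigmoid(x)) + (1. - z) * -np.log(1. - sigmoid(x))


def combined_weighted(x, z, q):
  # (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0)), with l = 1 + (q - 1) * z
  log_weight = 1. + (q - 1.) * z
  return (1. - z) * x + log_weight * (np.log1p(np.exp(-np.abs(x))) + np.maximum(-x, 0.))


x = np.array([-4., -1., 0.5, 3.])
z = np.array([1., 0., 1., 0.])
assert np.allclose(explicit_weighted(x, z, q=2.5), combined_weighted(x, z, q=2.5))
```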
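Finally, a rough usage sketch of the new function in a returnn_common setup. This is not part of the commit; the extern-data construction, the dim tags, the exposure of the function as `nn.binary_cross_entropy`, and the `mark_as_loss` call are assumptions based on typical returnn_common usage and may differ in detail:

```python
# Hypothetical usage sketch, not part of this commit. Assumes a standard
# returnn_common network-construction context; names and the exact extern-data
# / mark_as_loss API details are assumptions for illustration.
from returnn_common import nn

time_dim = nn.SpatialDim("time")
# Binary framewise targets (0 = negative, 1 = positive) and raw model logits.
target = nn.get_extern_data(nn.Data("target", dim_tags=[nn.batch_dim, time_dim]))
logits = nn.get_extern_data(nn.Data("logits", dim_tags=[nn.batch_dim, time_dim]))

loss = nn.binary_cross_entropy(
  target=target,
  pos_estimated=logits, pos_estimated_type="logits",
  pos_weight=2.0)  # optional; > 1 up-weights the positive class (higher recall)
loss.mark_as_loss("bce")  # mark as training loss; exact mark_as_loss signature assumed
```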