
Commit 6027862

binary_cross_entropy (#38)

1 parent 7d14813

File tree

1 file changed: +84 −1 lines changed

nn/loss.py

Lines changed: 84 additions & 1 deletion
@@ -10,7 +10,7 @@
 
 """
 
-from typing import Optional
+from typing import Optional, Union
 from .. import nn
 
 
@@ -57,6 +57,89 @@ def cross_entropy(*, target: nn.Tensor, estimated: nn.Tensor, estimated_type: str
   return -nn.dot(target, log_prob, reduce=axis)
 
 
+@nn.scoped
+def binary_cross_entropy(*,
+                         target: nn.Tensor,
+                         pos_estimated: nn.Tensor, pos_estimated_type: str,
+                         pos_weight: Optional[Union[float, nn.Tensor]] = None):
+  """
+  Binary cross entropy, also called sigmoid cross entropy.
+
+  :param target: (sparse) target labels, 0 (negative) or 1 (positive), i.e. binary.
+  :param pos_estimated: positive class logits. probs = sigmoid(logits).
+  :param pos_estimated_type: "logits" is the only type currently supported
+  :param pos_weight: weight for the positive class.
+
+  Code and documentation partly borrowed from TensorFlow.
+
+  A value `pos_weight > 1` decreases the false negative count, hence increasing
+  the recall.
+  Conversely, setting `pos_weight < 1` decreases the false positive count and
+  increases the precision.
+  This can be seen from the fact that `pos_weight` is introduced as a
+  multiplicative coefficient for the positive labels term
+  in the loss expression:
+
+      labels * -log(sigmoid(logits)) * pos_weight +
+          (1 - labels) * -log(1 - sigmoid(logits))
+
+  For brevity, let `x = logits`, `z = labels`. The logistic loss is
+
+        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
+      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
+      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
+      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
+      = (1 - z) * x + log(1 + exp(-x))
+      = x - x * z + log(1 + exp(-x))
+
+  For x < 0, to avoid overflow in exp(-x), we reformulate the above:
+
+        x - x * z + log(1 + exp(-x))
+      = log(exp(x)) - x * z + log(1 + exp(-x))
+      = -x * z + log(1 + exp(x))
+
+  Hence, to ensure stability and avoid overflow, the implementation uses this
+  equivalent formulation:
+
+      max(x, 0) - x * z + log(1 + exp(-abs(x)))
+  """
+  if pos_estimated_type != "logits":
+    raise NotImplementedError(
+      f"binary_cross_entropy, pos_estimated_type {pos_estimated_type!r}, only 'logits' supported")
+  logits = pos_estimated
+
+  if pos_weight is not None:
+    # Code adapted from tf.nn.weighted_cross_entropy_with_logits.
+    # The logistic loss formula from above, with q = pos_weight, becomes
+    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
+    # For x < 0, a more numerically stable formula is, writing l = 1 + (q - 1) * z:
+    #   (1 - z) * x + l * log(1 + exp(x)) - l * x
+    # To avoid branching, we use the combined version
+    #   (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
+    log_weight = 1 + (pos_weight - 1) * target
+    return (
+      (1 - target) * logits +
+      log_weight * (nn.log1p(nn.exp(-nn.abs(logits))) + nn.relu(-logits))
+    )
+
+  # Code adapted from tf.nn.sigmoid_cross_entropy_with_logits.
+  # The logistic loss formula from above is
+  #   x - x * z + log(1 + exp(-x))
+  # For x < 0, a more numerically stable formula is
+  #   -x * z + log(1 + exp(x))
+  # Note that these two expressions can be combined into the following:
+  #   max(x, 0) - x * z + log(1 + exp(-abs(x)))
+  # To allow computing gradients at zero, we define custom versions of the
+  # max and abs functions.
+  cond = (logits >= 0)
+  relu_logits = nn.where(cond, logits, 0)
+  neg_abs_logits = nn.where(cond, -logits, logits)  # pylint: disable=invalid-unary-operand-type
+  return (
+    relu_logits - logits * target +
+    nn.log1p(nn.exp(neg_abs_logits))
+  )
+
+
 @nn.scoped
 def kl_div(*, target: nn.Tensor, target_type: str,
            estimated: nn.Tensor, estimated_type: str,
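
As a quick sanity check of the derivation in the docstring, the stable formulations can be compared numerically against the naive loss. This is a minimal standalone NumPy sketch, not part of the commit; the helper names (naive_bce, stable_bce, stable_weighted_bce) are hypothetical and used only for illustration:

import numpy as np


def naive_bce(x, z, q=1.0):
  # q * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)); breaks down for large |x|
  sig = 1.0 / (1.0 + np.exp(-x))
  return q * z * -np.log(sig) + (1 - z) * -np.log(1 - sig)


def stable_bce(x, z):
  # max(x, 0) - x * z + log(1 + exp(-abs(x))), the formulation used in the commit
  return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))


def stable_weighted_bce(x, z, q):
  # (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0)), with l = 1 + (q - 1) * z
  l = 1 + (q - 1) * z
  return (1 - z) * x + l * (np.log1p(np.exp(-np.abs(x))) + np.maximum(-x, 0))


x = np.array([-3.0, -0.5, 0.0, 0.5, 3.0])
z = np.array([0.0, 1.0, 1.0, 0.0, 1.0])
assert np.allclose(naive_bce(x, z), stable_bce(x, z))
assert np.allclose(naive_bce(x, z, q=2.0), stable_weighted_bce(x, z, q=2.0))

# The naive form overflows where the stable one stays finite:
print(naive_bce(np.array([100.0]), np.array([0.0])))   # [inf], since log(1 - sigmoid(100)) = log(0)
print(stable_bce(np.array([100.0]), np.array([0.0])))  # [100.]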

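For reference, a possible call site for the new function. This is a hedged sketch, not from the commit: it assumes `target` and `logits` are `nn.Tensor` objects built elsewhere in a returnn_common model, and that the result is registered as a loss in the usual returnn_common way.

# Hypothetical usage; `target` and `logits` are assumed to be existing nn.Tensor objects.
loss = nn.binary_cross_entropy(
  target=target,                # binary labels, 0 (negative) or 1 (positive)
  pos_estimated=logits,         # raw logits; probs = sigmoid(logits)
  pos_estimated_type="logits",  # the only type currently supported
  pos_weight=2.0)               # optional; > 1 trades precision for recall
loss.mark_as_loss()  # assumption: the usual returnn_common way to register a loss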