diff --git a/nn/loss.py b/nn/loss.py
index 1b50a817..17ef3daf 100644
--- a/nn/loss.py
+++ b/nn/loss.py
@@ -55,3 +55,52 @@ def cross_entropy(*, target: nn.Tensor, estimated: nn.Tensor, estimated_type: st
   else:
     raise ValueError("estimated_kind must be 'probs', 'log-probs' or 'logits'")
   return -nn.dot(target, log_prob, reduce=axis)
+
+
+def kl_div(*, target: nn.Tensor, target_type: str,
+           estimated: nn.Tensor, estimated_type: str,
+           axis: Optional[nn.Dim] = None) -> nn.Tensor:
+  """
+  Kullback-Leibler divergence (https://en.wikipedia.org/wiki/Kullback-Leibler_divergence)
+
+  L(target, estimated) = target * log(target / estimated)
+                       = target * (log(target) - log(estimated))
+
+  :param target: probs, log-probs or logits, specified via ``target_type``. Sparse targets are not supported.
+  :param target_type: "probs", "log-probs" or "logits"
+  :param estimated: probs, log-probs or logits, specified via ``estimated_type``
+  :param estimated_type: "probs", "log-probs" or "logits"
+  :param axis: the axis to reduce over
+  :return: KL-div, reduced over ``axis``
+  """
+  if not axis:
+    assert target.feature_dim
+    axis = target.feature_dim
+
+  if target.data.sparse:
+    raise NotImplementedError(f"Sparse target {target} not supported for KL. Use cross entropy instead?")
+  if target_type == "probs":
+    log_target = nn.safe_log(target)
+  elif target_type == "log-probs":
+    log_target = target
+  elif target_type == "logits":
+    log_target = nn.log_softmax(target, axis=axis)
+  else:
+    raise ValueError("target_type must be 'probs', 'log-probs' or 'logits'")
+
+  if estimated_type == "probs":
+    log_est = nn.safe_log(estimated)
+  elif estimated_type == "log-probs":
+    log_est = estimated
+  elif estimated_type == "logits":
+    log_est = nn.log_softmax(estimated, axis=axis)
+  else:
+    raise ValueError("estimated_type must be 'probs', 'log-probs' or 'logits'")
+
+  # Assuming target = softmax(...):
+  # Use nn.exp(log_target) (but not nn.safe_exp!) instead of target itself,
+  # to avoid calculating the softmax twice (efficiency):
+  # nn.safe_log(target) would be log_softmax(...), i.e. a separate softmax calculation.
+  kl = nn.dot(nn.exp(log_target), log_target - log_est, reduce=axis)
+
+  return kl
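
For reference, a small standalone sketch (plain NumPy, not the ``nn`` API from this diff; all names here are illustrative) of the identity the final comment relies on: with ``log_target = log_softmax(...)``, ``exp(log_target)`` recovers the target probabilities, so the KL sum can be formed from the log-softmax alone, without a second softmax pass.

import numpy as np

def log_softmax(x, axis=-1):
  # numerically stable log-softmax
  x = x - np.max(x, axis=axis, keepdims=True)
  return x - np.log(np.sum(np.exp(x), axis=axis, keepdims=True))

rng = np.random.default_rng(42)
target_logits = rng.normal(size=(3, 5))
estimated_logits = rng.normal(size=(3, 5))

log_target = log_softmax(target_logits)
log_est = log_softmax(estimated_logits)

# KL(target || estimated) = sum_i target_i * (log(target_i) - log(estimated_i)),
# with target_i obtained as exp(log_target), so the softmax is only evaluated once.
kl = np.sum(np.exp(log_target) * (log_target - log_est), axis=-1)

# Reference: computing the target probabilities via an explicit softmax gives the same result.
target_probs = np.exp(target_logits - target_logits.max(axis=-1, keepdims=True))
target_probs /= target_probs.sum(axis=-1, keepdims=True)
kl_ref = np.sum(target_probs * (log_target - log_est), axis=-1)
assert np.allclose(kl, kl_ref)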