Skip to content

Commit a6ddea0

Browse files
author
evan.zhang5
committed
fix:Added condition array for filtering 0 values
1 parent 6e24935 commit a6ddea0

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

machine_learning/loss_functions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,11 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
645645
- y_true: True class probabilities
646646
- y_pred: Predicted class probabilities
647647
648+
>>> true_labels = np.array([0, 0.4, 0.6])
649+
>>> predicted_probs = np.array([0.3, 0.3, 0.4])
650+
>>> float(kullback_leibler_divergence(true_labels, predicted_probs))
651+
0.35835189384561095
652+
648653
>>> true_labels = np.array([0.2, 0.3, 0.5])
649654
>>> predicted_probs = np.array([0.3, 0.3, 0.4])
650655
>>> float(kullback_leibler_divergence(true_labels, predicted_probs))
@@ -659,6 +664,9 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
659664
if len(y_true) != len(y_pred):
660665
raise ValueError("Input arrays must have the same length.")
661666

667+
filter_array = y_true != 0
668+
y_true = y_true[filter_array]
669+
y_pred = y_pred[filter_array]
662670
kl_loss = y_true * np.log(y_true / y_pred)
663671
return np.sum(kl_loss)
664672

0 commit comments

Comments
 (0)