diff --git a/contrib_docs/_build/pydocmd/losses.md b/contrib_docs/_build/pydocmd/losses.md new file mode 100644 index 000000000..942beda20 --- /dev/null +++ b/contrib_docs/_build/pydocmd/losses.md @@ -0,0 +1,155 @@ +

keras_contrib.losses

+ + +

DSSIMObjective

+ +```python +DSSIMObjective(self, k1=0.01, k2=0.03, kernel_size=3, max_value=1.0) +``` +Difference of Structural Similarity (DSSIM loss function). +Clipped between 0 and 0.5 + +Note : You should add a regularization term like a l2 loss in addition to this one. +Note : In theano, the `kernel_size` must be a factor of the output size. So 3 could + not be the `kernel_size` for an output of 32. + +__Arguments__ + +- __k1__: Parameter of the SSIM (default 0.01) +- __k2__: Parameter of the SSIM (default 0.03) +- __kernel_size__: Size of the sliding window (default 3) +- __max_value__: Max value of the output (default 1.0) + +

jaccard_distance

+ +```python +jaccard_distance(y_true, y_pred, smooth=100) +``` +Jaccard distance for semantic segmentation. + +Also known as the intersection-over-union loss. + +This loss is useful when you have unbalanced numbers of pixels within an image +because it gives all classes equal weight. However, it is not the defacto +standard for image segmentation. + +For example, assume you are trying to predict if +each pixel is cat, dog, or background. +You have 80% background pixels, 10% dog, and 10% cat. +If the model predicts 100% background +should it be be 80% right (as with categorical cross entropy) +or 30% (with this loss)? + +The loss has been modified to have a smooth gradient as it converges on zero. +This has been shifted so it converges on 0 and is smoothed to avoid exploding +or disappearing gradient. + +Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|) + = sum(|A*B|)/(sum(|A|)+sum(|B|)-sum(|A*B|)) + +__Arguments__ + +- __y_true__: The ground truth tensor. +- __y_pred__: The predicted tensor +- __smooth__: Smoothing factor. Default is 100. + +__Returns__ + + The Jaccard distance between the two tensors. + +__References__ + + - [What is a good evaluation measure for semantic segmentation?]( + http://www.bmva.org/bmvc/2013/Papers/paper0032/paper0032.pdf) + + +

crf_loss

+ +```python +crf_loss(y_true, y_pred) +``` +General CRF loss function depending on the learning mode. + +__Arguments__ + +- __y_true__: tensor with true targets. +- __y_pred__: tensor with predicted targets. + +__Returns__ + + If the CRF layer is being trained in the join mode, returns the negative + log-likelihood. Otherwise returns the categorical crossentropy implemented + by the underlying Keras backend. + +__About GitHub__ + + If you open an issue or a pull request about CRF, please + add `cc @lzfelix` to notify Luiz Felix. + +

crf_nll

+ +```python +crf_nll(y_true, y_pred) +``` +The negative log-likelihood for linear chain Conditional Random Field (CRF). + +This loss function is only used when the `layers.CRF` layer +is trained in the "join" mode. + +__Arguments__ + +- __y_true__: tensor with true targets. +- __y_pred__: tensor with predicted targets. + +__Returns__ + + A scalar representing corresponding to the negative log-likelihood. + +__Raises__ + +- `TypeError`: If CRF is not the last layer. + +__About GitHub__ + + If you open an issue or a pull request about CRF, please + add `cc @lzfelix` to notify Luiz Felix. + +

dice_loss

+ +```python +dice_loss(y_true, y_pred, smooth=1) +``` +Dice similarity coefficient (DSC) loss. + +Essentially 1 minus the Dice similarity coefficient (DSC). Here, the Dice +similarity coefficient is used as a metric to evaluate the performance of +image segmentation by comparing spatial overlap between the true and predicted +spaces. + +A smoothing factor, which is by default 1, is applied to avoid dividing by +zeros. + +Dice loss = 1 - (2 * |X & Y|)/ (X^2 + Y^2) + = 1 - 2 * sum(A*B) / sum(A^2 + B^2) + +__Arguments__ + +- __y_true__: The ground truth tensor. +- __y_pred__: The predicted tensor +- __smooth__: Smoothing factor. Default is 1. + +__Returns__ + + The Dice coefficiet loss between the two tensors. + +__References__ + + - [V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image + Segmentation](https://arxiv.org/pdf/1606.04797.pdf) + +__About GitHub__ + + If you open an issue or a pull request about Dice loss, please + add `cc @alexbmp` to notify Seongmin Choi. + + diff --git a/contrib_docs/pydocmd.yml b/contrib_docs/pydocmd.yml index eca7ba534..67b7524fd 100644 --- a/contrib_docs/pydocmd.yml +++ b/contrib_docs/pydocmd.yml @@ -23,6 +23,7 @@ generate: - keras_contrib.losses.jaccard_distance - keras_contrib.losses.crf_loss - keras_contrib.losses.crf_nll + - keras_contrib.losses.dice_loss - optimizers.md: - keras_contrib.optimizers: - keras_contrib.optimizers.FTML diff --git a/keras_contrib/losses/__init__.py b/keras_contrib/losses/__init__.py index 37a47a804..d90790c8a 100644 --- a/keras_contrib/losses/__init__.py +++ b/keras_contrib/losses/__init__.py @@ -1,3 +1,4 @@ from .dssim import DSSIMObjective from .jaccard import jaccard_distance from .crf_losses import crf_loss, crf_nll +from .dice import dice_loss diff --git a/keras_contrib/losses/dice.py b/keras_contrib/losses/dice.py new file mode 100644 index 000000000..514c8ee39 --- /dev/null +++ b/keras_contrib/losses/dice.py @@ -0,0 +1,40 @@ +from keras import backend as K + + +def dice_loss(y_true, y_pred, smooth=1): + """Dice similarity coefficient (DSC) loss. + + Essentially 1 minus the Dice similarity coefficient (DSC). Here, the Dice + similarity coefficient is used as a metric to evaluate the performance of + image segmentation by comparing spatial overlap between the true and + predicted spaces. + + A smoothing factor, which is by default 1, is applied to avoid dividing by + zeros. + + Dice loss = 1 - (2 * |X & Y|)/ (X^2 + Y^2) + = 1 - 2 * sum(A*B) / sum(A^2 + B^2) + + # Arguments + y_true: The ground truth tensor. + y_pred: The predicted tensor + smooth: Smoothing factor. Default is 1. + + # Returns + The Dice coefficiet loss between the two tensors. + + # References + - [V-Net: Fully Convolutional Neural Networks for Volumetric Medical + Image Segmentation](https://arxiv.org/pdf/1606.04797.pdf) + + # About GitHub + If you open an issue or a pull request about Dice loss, please + add `cc @alexbmp` to notify Seongmin Choi + + """ + y_true_flat, y_pred_flat = K.flatten(y_true), K.flatten(y_pred) + dice_nom = 2 * K.sum(y_true_flat * y_pred_flat) + dice_denom = K.sum(K.square(y_true_flat) + K.square(y_pred_flat)) + dice_coef = (dice_nom + smooth) / (dice_denom + smooth) + + return 1 - dice_coef diff --git a/tests/keras_contrib/losses/dice_test.py b/tests/keras_contrib/losses/dice_test.py new file mode 100644 index 000000000..8b82b5444 --- /dev/null +++ b/tests/keras_contrib/losses/dice_test.py @@ -0,0 +1,38 @@ +import pytest + +from keras_contrib.losses import dice_loss +from keras_contrib.utils.test_utils import is_tf_keras +from keras import backend as K +import numpy as np + + +def test_dice_loss_shapes_scalar(): + y_true = np.random.randn(3, 4) + y_pred = np.random.randn(3, 4) + + L = dice_loss( + K.variable(y_true), + K.variable(y_pred), ) + assert K.is_tensor(L), 'should be a Tensor' + assert L.shape == () + assert K.eval(L).shape == () + + +def test_dice_loss_for_same_array(): + y_true = np.random.randn(3, 4) + y_pred = y_true.copy() + + L = dice_loss( + K.variable(y_true), + K.variable(y_pred), ) + assert K.eval(L) == 0, 'loss should be zero' + + +def test_dice_loss_for_zero_array(): + y_true = np.array([1]) + y_pred = np.array([0]) + + L = dice_loss( + K.variable(y_true), + K.variable(y_pred), ) + assert K.eval(L) == 0.5, 'loss should equal 0.5'