diff --git a/contrib_docs/_build/pydocmd/losses.md b/contrib_docs/_build/pydocmd/losses.md
new file mode 100644
index 000000000..942beda20
--- /dev/null
+++ b/contrib_docs/_build/pydocmd/losses.md
@@ -0,0 +1,155 @@
+
+# keras_contrib.losses
+
+
+## DSSIMObjective
+
+```python
+DSSIMObjective(self, k1=0.01, k2=0.03, kernel_size=3, max_value=1.0)
+```
+Difference of Structural Similarity (DSSIM) loss function.
+The loss value is clipped between 0 and 0.5.
+
+Note: you should add a regularization term, such as an l2 loss, in addition
+to this one.
+Note: in Theano, the `kernel_size` must be a factor of the output size, so 3
+cannot be the `kernel_size` for an output of size 32.
+
+__Arguments__
+
+- __k1__: Parameter of the SSIM (default 0.01)
+- __k2__: Parameter of the SSIM (default 0.03)
+- __kernel_size__: Size of the sliding window (default 3)
+- __max_value__: Max value of the output (default 1.0)
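+
+A minimal usage sketch (the toy model and input shape below are illustrative,
+not part of the API): `DSSIMObjective` instances are callable, so one can be
+passed directly as a Keras loss.
+
+```python
+from keras.models import Sequential
+from keras.layers import Conv2D
+from keras_contrib.losses import DSSIMObjective
+
+# Hypothetical image-to-image model; DSSIM compares batches of images.
+model = Sequential()
+model.add(Conv2D(3, (3, 3), padding='same', input_shape=(32, 32, 3)))
+model.compile(optimizer='adam', loss=DSSIMObjective(kernel_size=3))
+```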
+
+## jaccard_distance
+
+```python
+jaccard_distance(y_true, y_pred, smooth=100)
+```
+Jaccard distance for semantic segmentation.
+
+Also known as the intersection-over-union loss.
+
+This loss is useful when you have unbalanced numbers of pixels within an image
+because it gives all classes equal weight. However, it is not the de facto
+standard for image segmentation.
+
+For example, assume you are trying to predict whether
+each pixel is cat, dog, or background.
+You have 80% background pixels, 10% dog, and 10% cat.
+If the model predicts 100% background,
+should it be 80% right (as with categorical cross entropy)
+or 30% (with this loss)?
+
+The loss has been modified to have a smooth gradient as it converges on zero:
+it is shifted so that it converges on 0 and smoothed to avoid exploding or
+vanishing gradients.
+
+Jaccard = |X & Y| / (|X| + |Y| - |X & Y|)
+        = sum(|A*B|) / (sum(|A|) + sum(|B|) - sum(|A*B|))
+
+__Arguments__
+
+- __y_true__: The ground truth tensor.
+- __y_pred__: The predicted tensor.
+- __smooth__: Smoothing factor. Default is 100.
+
+__Returns__
+
+ The Jaccard distance between the two tensors.
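+
+A quick sanity-check sketch (the toy tensors below are illustrative):
+identical inputs overlap perfectly, so the distance is 0.
+
+```python
+import numpy as np
+from keras import backend as K
+from keras_contrib.losses import jaccard_distance
+
+# Two identical toy "segmentations"; perfect overlap gives a loss of 0.
+y_true = K.variable(np.array([[0., 1., 1.]]))
+y_pred = K.variable(np.array([[0., 1., 1.]]))
+print(K.eval(jaccard_distance(y_true, y_pred)))  # [0.]
+```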
+
+__References__
+
+ - [What is a good evaluation measure for semantic segmentation?](
+ http://www.bmva.org/bmvc/2013/Papers/paper0032/paper0032.pdf)
+
+
+## crf_loss
+
+```python
+crf_loss(y_true, y_pred)
+```
+General CRF loss function depending on the learning mode.
+
+__Arguments__
+
+- __y_true__: tensor with true targets.
+- __y_pred__: tensor with predicted targets.
+
+__Returns__
+
+    If the CRF layer is being trained in the "join" mode, returns the negative
+ log-likelihood. Otherwise returns the categorical crossentropy implemented
+ by the underlying Keras backend.
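+
+A minimal sketch of wiring `crf_loss` to a `CRF` layer (the tagging model
+below is hypothetical):
+
+```python
+from keras.models import Sequential
+from keras.layers import Embedding
+from keras_contrib.layers import CRF
+from keras_contrib.losses import crf_loss
+
+# The CRF layer must be the last layer of the model (see crf_nll below).
+model = Sequential()
+model.add(Embedding(1000, 32))
+model.add(CRF(5, learn_mode='join'))
+model.compile(optimizer='adam', loss=crf_loss)
+```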
+
+__About GitHub__
+
+ If you open an issue or a pull request about CRF, please
+ add `cc @lzfelix` to notify Luiz Felix.
+
+## crf_nll
+
+```python
+crf_nll(y_true, y_pred)
+```
+The negative log-likelihood for linear chain Conditional Random Field (CRF).
+
+This loss function is only used when the `layers.CRF` layer
+is trained in the "join" mode.
+
+__Arguments__
+
+- __y_true__: tensor with true targets.
+- __y_pred__: tensor with predicted targets.
+
+__Returns__
+
+    A scalar representing the negative log-likelihood.
+
+__Raises__
+
+- `TypeError`: If CRF is not the last layer.
+
+__About GitHub__
+
+ If you open an issue or a pull request about CRF, please
+ add `cc @lzfelix` to notify Luiz Felix.
+
+## dice_loss
+
+```python
+dice_loss(y_true, y_pred, smooth=1)
+```
+Dice similarity coefficient (DSC) loss.
+
+Essentially 1 minus the Dice similarity coefficient (DSC). The DSC is used as
+a metric to evaluate the performance of image segmentation by measuring the
+spatial overlap between the true and the predicted segmentations.
+
+A smoothing factor, which is 1 by default, is added to both the numerator and
+the denominator to avoid division by zero.
+
+Dice loss = 1 - 2 * |X & Y| / (|X|^2 + |Y|^2)
+          = 1 - 2 * sum(A*B) / (sum(A^2) + sum(B^2))
+
+__Arguments__
+
+- __y_true__: The ground truth tensor.
+- __y_pred__: The predicted tensor.
+- __smooth__: Smoothing factor. Default is 1.
+
+__Returns__
+
+    The Dice coefficient loss between the two tensors.
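+
+A quick sanity-check sketch (the toy tensors below are illustrative):
+identical inputs overlap perfectly, so the loss is 0.
+
+```python
+import numpy as np
+from keras import backend as K
+from keras_contrib.losses import dice_loss
+
+# Perfect overlap: numerator and denominator are equal, so the loss is 0.
+y_true = K.variable(np.array([0., 1., 1.]))
+y_pred = K.variable(np.array([0., 1., 1.]))
+print(K.eval(dice_loss(y_true, y_pred)))  # 0.0
+```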
+
+__References__
+
+ - [V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image
+ Segmentation](https://arxiv.org/pdf/1606.04797.pdf)
+
+__About GitHub__
+
+ If you open an issue or a pull request about Dice loss, please
+ add `cc @alexbmp` to notify Seongmin Choi.
+
+
diff --git a/contrib_docs/pydocmd.yml b/contrib_docs/pydocmd.yml
index eca7ba534..67b7524fd 100644
--- a/contrib_docs/pydocmd.yml
+++ b/contrib_docs/pydocmd.yml
@@ -23,6 +23,7 @@ generate:
- keras_contrib.losses.jaccard_distance
- keras_contrib.losses.crf_loss
- keras_contrib.losses.crf_nll
+ - keras_contrib.losses.dice_loss
- optimizers.md:
- keras_contrib.optimizers:
- keras_contrib.optimizers.FTML
diff --git a/keras_contrib/losses/__init__.py b/keras_contrib/losses/__init__.py
index 37a47a804..d90790c8a 100644
--- a/keras_contrib/losses/__init__.py
+++ b/keras_contrib/losses/__init__.py
@@ -1,3 +1,4 @@
from .dssim import DSSIMObjective
from .jaccard import jaccard_distance
from .crf_losses import crf_loss, crf_nll
+from .dice import dice_loss
diff --git a/keras_contrib/losses/dice.py b/keras_contrib/losses/dice.py
new file mode 100644
index 000000000..514c8ee39
--- /dev/null
+++ b/keras_contrib/losses/dice.py
@@ -0,0 +1,40 @@
+from keras import backend as K
+
+
+def dice_loss(y_true, y_pred, smooth=1):
+ """Dice similarity coefficient (DSC) loss.
+
+ Essentially 1 minus the Dice similarity coefficient (DSC). Here, the Dice
+ similarity coefficient is used as a metric to evaluate the performance of
+ image segmentation by comparing spatial overlap between the true and
+ predicted spaces.
+
+    A smoothing factor, which is 1 by default, is added to both the numerator
+    and the denominator to avoid division by zero.
+
+    Dice loss = 1 - 2 * |X & Y| / (|X|^2 + |Y|^2)
+              = 1 - 2 * sum(A*B) / (sum(A^2) + sum(B^2))
+
+ # Arguments
+ y_true: The ground truth tensor.
+        y_pred: The predicted tensor.
+ smooth: Smoothing factor. Default is 1.
+
+ # Returns
+        The Dice coefficient loss between the two tensors.
+
+ # References
+ - [V-Net: Fully Convolutional Neural Networks for Volumetric Medical
+ Image Segmentation](https://arxiv.org/pdf/1606.04797.pdf)
+
+ # About GitHub
+ If you open an issue or a pull request about Dice loss, please
+        add `cc @alexbmp` to notify Seongmin Choi.
+
+ """
+ y_true_flat, y_pred_flat = K.flatten(y_true), K.flatten(y_pred)
+ dice_nom = 2 * K.sum(y_true_flat * y_pred_flat)
+ dice_denom = K.sum(K.square(y_true_flat) + K.square(y_pred_flat))
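+    # Adding the smoothing term to both parts avoids division by zero.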
+ dice_coef = (dice_nom + smooth) / (dice_denom + smooth)
+
+ return 1 - dice_coef
diff --git a/tests/keras_contrib/losses/dice_test.py b/tests/keras_contrib/losses/dice_test.py
new file mode 100644
index 000000000..8b82b5444
--- /dev/null
+++ b/tests/keras_contrib/losses/dice_test.py
@@ -0,0 +1,38 @@
+from keras_contrib.losses import dice_loss
+from keras import backend as K
+import numpy as np
+
+
+def test_dice_loss_shapes_scalar():
+ y_true = np.random.randn(3, 4)
+ y_pred = np.random.randn(3, 4)
+
+    L = dice_loss(K.variable(y_true), K.variable(y_pred))
+ assert K.is_tensor(L), 'should be a Tensor'
+ assert L.shape == ()
+ assert K.eval(L).shape == ()
+
+
+def test_dice_loss_for_same_array():
+ y_true = np.random.randn(3, 4)
+ y_pred = y_true.copy()
+
+    L = dice_loss(K.variable(y_true), K.variable(y_pred))
+ assert K.eval(L) == 0, 'loss should be zero'
+
+
+def test_dice_loss_for_zero_array():
+ y_true = np.array([1])
+ y_pred = np.array([0])
+
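+    # With smooth=1: numerator = 0, denominator = 1,
+    # so dice_coef = (0 + 1) / (1 + 1) = 0.5 and the loss is 1 - 0.5 = 0.5.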
+    L = dice_loss(K.variable(y_true), K.variable(y_pred))
+ assert K.eval(L) == 0.5, 'loss should equal 0.5'