Commit cd01127

[Op] Fix nll_loss (awslabs#113)
* fix
* test
* address comments
1 parent 650175b commit cd01127

2 files changed: +18 −9 lines changed

python/raf/_tvm_op/loss.py (+7 −4)

@@ -83,13 +83,16 @@ def smooth_l1_loss_dtrue_compute(attr, inputs, output_type):  # pylint: disable=
 def nll_loss_compute(attrs, inputs, output_type):  # pylint: disable=unused-argument
     true, pred = inputs
     n, c = pred.shape
+    dtype = pred.dtype
+    if dtype == "float16":
+        pred = pred.astype("float32")

     if true.ndim == 1:  # one-host label encoding

         def fcompute_one_hot(i):  # pylint: disable=unused-argument
             return -pred[i, true[i]] / n

-        loss = _tvm.te.compute((n,), fcompute_one_hot)
+        loss = _tvm.te.compute((n,), fcompute_one_hot, tag=_tvm.topi.tag.INJECTIVE)
         loss = _topi.sum(loss, axis=[0], keepdims=True)
     else:  # sparse label encoding

@@ -98,12 +101,12 @@ def fcompute_sparse(x):  # pylint: disable=unused-argument
             redc = _tvm.te.reduce_axis((0, c), name="rc")
             return _tvm.te.sum(-pred[redn, redc] * true[redn, redc] / n, axis=[redc, redn])

-        loss = _tvm.te.compute((1,), fcompute_sparse)
+        loss = _tvm.te.compute((1,), fcompute_sparse, tag=_tvm.topi.tag.COMM_REDUCE)

-    return [loss]
+    return [loss.astype(dtype)]


-_reg.register_injective_schedule("raf.op.tvm.nll_loss")
+_reg.register_reduce_schedule("raf.op.tvm.nll_loss")


 @register_compute("raf.op.tvm.nll_loss_dpred")
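For context, a minimal NumPy sketch of what nll_loss_compute produces under the two label encodings (mean reduction over the batch, matching the "/ n" in the TE expressions above). The function name and shapes are illustrative and not part of this commit; it assumes pred already holds log-probabilities.

import numpy as np

def nll_loss_reference(true, pred):
    """Illustrative reference for the TE compute above (not part of the commit)."""
    n, _ = pred.shape
    if true.ndim == 1:  # integer class indices of shape (n,)
        # Pick each sample's target-class log-probability, negate, and average.
        return np.array([-pred[np.arange(n), true].sum() / n])
    # Dense (n, c) label encoding: weight every entry by its label value.
    return np.array([-(pred * true).sum() / n])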

tests/python/op/tvm/test_tvm_loss.py (+11 −5)

@@ -57,8 +57,12 @@ def forward(self, y_true, y_pred):  # pylint: disable=no-self-use
 @pytest.mark.parametrize("device", get_testable_devices())
 @pytest.mark.parametrize("n", [3, 7])
 @pytest.mark.parametrize("c", [2, 6])
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
 @pytest.mark.parametrize("one_hot_label", [True, False])
-def test_nll_loss(device, n, c, one_hot_label):
+def test_nll_loss(device, n, c, dtype, one_hot_label):
+    if device == "cpu" and dtype == "float16":
+        pytest.skip("PyTorch nll_loss does not support float16 when using CPU.")
+
     class TestModel(raf.Model):
         def build(self):
             pass

@@ -68,10 +72,10 @@ def forward(self, y_true, y_pred):  # pylint: disable=no-self-use
             return raf.nll_loss(y_true=y_true, y_pred=y_pred)

     model = TestModel()
-    m_pred, t_pred = randn_torch((n, c), device=device, requires_grad=True)
+    m_pred, t_pred = randn_torch((n, c), dtype=dtype, device=device, requires_grad=True)
     m_true, np_true = randint((n,), low=0, high=c, device=device, dtype="int64")
     if not one_hot_label:
-        m_true = np.zeros((n, c), dtype="float32")
+        m_true = np.zeros((n, c), dtype=dtype)
         for i in range(n):
             m_true[i, np_true[i]] = 1
         m_true = raf.array(m_true, device=device)

@@ -83,10 +87,12 @@ def forward(self, y_true, y_pred):  # pylint: disable=no-self-use
     check(m_loss, t_loss)
     check(v_loss, t_loss)
     # backward
-    m_dy, t_dy = randn_torch((), device=device)
+    m_dy, t_dy = randn_torch((), device=device, dtype=dtype)
     t_loss.backward(t_dy)
     m_loss.backward(m_dy)
-    check(m_pred.grad, t_pred.grad)
+    rtol = 1e-5 if dtype == "float32" else 1e-3
+    atol = 1e-5 if dtype == "float32" else 1e-3
+    check(m_pred.grad, t_pred.grad, rtol=rtol, atol=atol)


 @pytest.mark.parametrize("device", ["cpu"])
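As a rough cross-check of the new float16 path, the sketch below runs PyTorch's nll_loss in half precision on GPU and compares it against a float32 recomputation with the same relaxed 1e-3 tolerances used above. It assumes a CUDA device is available (the test skips float16 on CPU) and is illustrative only, not part of this commit.

import torch
import torch.nn.functional as F

# Half-precision "log-probabilities" and integer targets on GPU.
pred = torch.randn(4, 6, dtype=torch.float16, device="cuda", requires_grad=True)
true = torch.randint(0, 6, (4,), device="cuda")

loss = F.nll_loss(pred, true)  # forward in float16
loss.backward()                # gradient w.r.t. pred

# Compare against a float32 recomputation using the relaxed half-precision tolerances.
ref = F.nll_loss(pred.detach().float(), true)
torch.testing.assert_close(loss.float(), ref, rtol=1e-3, atol=1e-3)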
