@@ -57,8 +57,12 @@ def forward(self, y_true, y_pred): # pylint: disable=no-self-use
57
57
@pytest .mark .parametrize ("device" , get_testable_devices ())
58
58
@pytest .mark .parametrize ("n" , [3 , 7 ])
59
59
@pytest .mark .parametrize ("c" , [2 , 6 ])
60
+ @pytest .mark .parametrize ("dtype" , ["float32" , "float16" ])
60
61
@pytest .mark .parametrize ("one_hot_label" , [True , False ])
61
- def test_nll_loss (device , n , c , one_hot_label ):
62
+ def test_nll_loss (device , n , c , dtype , one_hot_label ):
63
+ if device == "cpu" and dtype == "float16" :
64
+ pytest .skip ("PyTorch nll_loss does not support float16 when using CPU." )
65
+
62
66
class TestModel (raf .Model ):
63
67
def build (self ):
64
68
pass
@@ -68,10 +72,10 @@ def forward(self, y_true, y_pred): # pylint: disable=no-self-use
68
72
return raf .nll_loss (y_true = y_true , y_pred = y_pred )
69
73
70
74
model = TestModel ()
71
- m_pred , t_pred = randn_torch ((n , c ), device = device , requires_grad = True )
75
+ m_pred , t_pred = randn_torch ((n , c ), dtype = dtype , device = device , requires_grad = True )
72
76
m_true , np_true = randint ((n ,), low = 0 , high = c , device = device , dtype = "int64" )
73
77
if not one_hot_label :
74
- m_true = np .zeros ((n , c ), dtype = "float32" )
78
+ m_true = np .zeros ((n , c ), dtype = dtype )
75
79
for i in range (n ):
76
80
m_true [i , np_true [i ]] = 1
77
81
m_true = raf .array (m_true , device = device )
@@ -83,10 +87,12 @@ def forward(self, y_true, y_pred): # pylint: disable=no-self-use
83
87
check (m_loss , t_loss )
84
88
check (v_loss , t_loss )
85
89
# backward
86
- m_dy , t_dy = randn_torch ((), device = device )
90
+ m_dy , t_dy = randn_torch ((), device = device , dtype = dtype )
87
91
t_loss .backward (t_dy )
88
92
m_loss .backward (m_dy )
89
- check (m_pred .grad , t_pred .grad )
93
+ rtol = 1e-5 if dtype == "float32" else 1e-3
94
+ atol = 1e-5 if dtype == "float32" else 1e-3
95
+ check (m_pred .grad , t_pred .grad , rtol = rtol , atol = atol )
90
96
91
97
92
98
@pytest .mark .parametrize ("device" , ["cpu" ])
0 commit comments