File tree Expand file tree Collapse file tree 2 files changed +6
-1
lines changed Expand file tree Collapse file tree 2 files changed +6
-1
lines changed Original file line number Diff line number Diff line change @@ -29,3 +29,8 @@ def test_scalar_to_float(self) -> None:
2929
3030 valid_ndarray = np .array ([[[float_x ]]])
3131 self .assertAlmostEqual (scalar_to_float (valid_ndarray ), float_x )
32+
33+ def test_scalar_to_float_bf16 (self ) -> None :
34+ float_x = 3.45
35+ valid_tensor = torch .Tensor ([float_x ]).to (torch .bfloat16 )
36+ self .assertAlmostEqual (scalar_to_float (valid_tensor ), float_x , delta = 0.01 )
Original file line number Diff line number Diff line change @@ -20,7 +20,7 @@ def scalar_to_float(scalar: Scalar) -> float:
2020 f"Scalar tensor must contain a single item, { numel } given."
2121 )
2222
23- return float (scalar .cpu ().detach ().numpy ().item ())
23+ return float (scalar .cpu ().detach ().float (). numpy ().item ())
2424 elif isinstance (scalar , ndarray ):
2525 numel = scalar .size
2626 if numel != 1 :
You can’t perform that action at this time.
0 commit comments