Skip to content

Commit 073d12c

Browse files
author
Han Wang
committed
fix bug of device
1 parent 6ea965b commit 073d12c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

source/tests/pt_expt/utils/test_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def test_cross_backend_consistency(self) -> None:
165165
# Test forward pass
166166
rng = np.random.default_rng()
167167
x_np = rng.standard_normal((5, self.in_dim))
168-
x_torch = torch.from_numpy(x_np)
168+
x_torch = torch.from_numpy(x_np).to(env.DEVICE)
169169

170170
out_dp = dp_net.call(x_np)
171171
out_pt = pt_net(x_torch).detach().cpu().numpy()

0 commit comments

Comments
 (0)