We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 6ea965b commit 073d12cCopy full SHA for 073d12c
source/tests/pt_expt/utils/test_network.py
@@ -165,7 +165,7 @@ def test_cross_backend_consistency(self) -> None:
165
# Test forward pass
166
rng = np.random.default_rng()
167
x_np = rng.standard_normal((5, self.in_dim))
168
- x_torch = torch.from_numpy(x_np)
+ x_torch = torch.from_numpy(x_np).to(env.DEVICE)
169
170
out_dp = dp_net.call(x_np)
171
out_pt = pt_net(x_torch).detach().cpu().numpy()
0 commit comments