Skip to content

Commit 09b33f1

Browse files
author
Han Wang
committed
restructure the test folders. add test_common.
1 parent fa03351 commit 09b33f1

File tree

5 files changed

+28
-1
lines changed

5 files changed

+28
-1
lines changed

deepmd/pt_expt/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@ def to_torch_array(array: Any) -> torch.Tensor | None:
3131
if array is None:
3232
return None
3333
if torch.is_tensor(array):
34-
return array
34+
return array.to(device=env.DEVICE)
3535
return torch.as_tensor(array, device=env.DEVICE)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
File renamed without changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import importlib
3+
4+
import numpy as np
5+
6+
from deepmd.pt_expt.common import (
7+
to_torch_array,
8+
)
9+
from deepmd.pt_expt.utils import (
10+
env,
11+
)
12+
13+
torch = importlib.import_module("torch")
14+
15+
16+
def test_to_torch_array_moves_device() -> None:
17+
arr = np.arange(6, dtype=np.float32).reshape(2, 3)
18+
tensor = to_torch_array(arr)
19+
assert torch.is_tensor(tensor)
20+
assert tensor.device == env.DEVICE
21+
22+
input_tensor = torch.as_tensor(arr, device=torch.device("cpu"))
23+
output_tensor = to_torch_array(input_tensor)
24+
assert torch.is_tensor(output_tensor)
25+
assert output_tensor.device == env.DEVICE

0 commit comments

Comments
 (0)