Skip to content

Commit bc3981b

Browse files
committed
set device for random initiated tensors
1 parent e304f15 commit bc3981b

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

chebai_graph/models/dynamic_gni.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(self, config: dict[str, Any], **kwargs: Any):
3737
edge_dim=self.edge_dim,
3838
act=self.activation,
3939
)
40+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
4041

4142
def forward(self, batch: dict[str, Any]) -> Tensor:
4243
"""
@@ -51,10 +52,15 @@ def forward(self, batch: dict[str, Any]) -> Tensor:
5152
graph_data = batch["features"][0]
5253
assert isinstance(graph_data, GraphData), "Expected GraphData instance"
5354

54-
random_x = torch.empty(graph_data.x.shape[0], graph_data.x.shape[1])
55+
random_x = torch.empty(
56+
graph_data.x.shape[0], graph_data.x.shape[1], device=self.device
57+
)
5558
RandomFeatureInitializationReader.random_gni(random_x, self.distribution)
59+
5660
random_edge_attr = torch.empty(
57-
graph_data.edge_attr.shape[0], graph_data.edge_attr.shape[1]
61+
graph_data.edge_attr.shape[0],
62+
graph_data.edge_attr.shape[1],
63+
device=self.device,
5864
)
5965
RandomFeatureInitializationReader.random_gni(
6066
random_edge_attr, self.distribution

0 commit comments

Comments
 (0)