Skip to content

Commit 6aed7a1

Browse files
committed
differentiate augmented and non-augmented gnns
1 parent 6a7aa53 commit 6aed7a1

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

chebifier/prediction_models/gnn_predictor.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ def init_model(self, ckpt_path: str, **kwargs) -> "ResGatedGraphPred":
7575
return model
7676

7777
def read_smiles(self, smiles):
78+
from chebai_graph.preprocessing.datasets.chebi import GraphPropAsPerNodeType
79+
7880
d = self.dataset.READER().to_data(dict(features=smiles, labels=None))
7981
property_data = d
8082
# TODO merge props into base should not be a method of a dataset (or at least static)
@@ -90,9 +92,13 @@ def read_smiles(self, smiles):
9092
if len(encoded_value.shape) == 3:
9193
encoded_value = encoded_value.squeeze(0)
9294
property_data[property.name] = encoded_value
93-
d["features"] = self.dataset._merge_props_into_base(
94-
property_data, max_len_node_properties=self.model.gnn.in_channels
95-
)
95+
# Augmented graphs need an additional argument
96+
if isinstance(self.dataset, GraphPropAsPerNodeType):
97+
d["features"] = self.dataset._merge_props_into_base(
98+
property_data, max_len_node_properties=self.model.gnn.in_channels
99+
)
100+
else:
101+
d["features"] = self.dataset._merge_props_into_base(property_data)
96102
return d
97103

98104

0 commit comments

Comments
 (0)