Skip to content

Commit 0444131

Browse files
committed
minor changes for gni
1 parent a5070e8 commit 0444131

File tree

2 files changed

+19
-27
lines changed

2 files changed

+19
-27
lines changed

chebai_graph/models/dynamic_gni.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@ def __init__(self, config: dict[str, Any], **kwargs: Any):
3333
)
3434
self.distribution = distribution
3535

36-
self.complete_randomness = config.get("complete_randomness", True)
36+
self.complete_randomness = (
37+
str(config.get("complete_randomness", "True")).lower() == "true"
38+
)
39+
40+
print("Using complete randomness: ", self.complete_randomness)
3741

3842
if not self.complete_randomness:
3943
assert (
@@ -44,11 +48,25 @@ def __init__(self, config: dict[str, Any], **kwargs: Any):
4448
if config.get("random_pad_node") is not None
4549
else None
4650
)
51+
if self.random_pad_node is not None:
52+
print(
53+
f"[Info] Node features will be padded with {self.random_pad_node} "
54+
f"new set of random features from distribution {self.distribution} "
55+
f"in each forward pass."
56+
)
57+
4758
self.random_pad_edge = (
4859
int(config["random_pad_edge"])
4960
if config.get("random_pad_edge") is not None
5061
else None
5162
)
63+
if self.random_pad_edge is not None:
64+
print(
65+
f"[Info] Edge features will be padded with {self.random_pad_edge} "
66+
f"new set of random features from distribution {self.distribution} "
67+
f"in each forward pass."
68+
)
69+
5270
assert (
5371
self.random_pad_node > 0 or self.random_pad_edge > 0
5472
), "'random_pad_node' or 'random_pad_edge' must be positive integers"

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -388,32 +388,6 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
388388

389389
return base_df[base_data[0].keys()].to_dict("records")
390390

391-
@property
392-
def processed_file_names_dict(self) -> dict:
393-
"""
394-
Returns a dictionary for the processed and tokenized data files.
395-
396-
Returns:
397-
dict: A dictionary mapping dataset keys to their respective file names.
398-
For example, {"data": "data.pt"}.
399-
"""
400-
if self.n_token_limit is not None:
401-
return {"data": f"data_maxlen{self.n_token_limit}.pt"}
402-
403-
data_pt_filename = "data"
404-
if self.zero_pad_node:
405-
data_pt_filename += f"_zpn{self.zero_pad_node}"
406-
if self.zero_pad_edge:
407-
data_pt_filename += f"_zpe{self.zero_pad_edge}"
408-
if self.random_pad_node:
409-
data_pt_filename += f"_rpn{self.random_pad_node}"
410-
if self.random_pad_edge:
411-
data_pt_filename += f"_rpe{self.random_pad_edge}"
412-
if self.random_pad_node or self.random_pad_edge:
413-
data_pt_filename += f"_D{self.distribution}"
414-
415-
return {"data": data_pt_filename + ".pt"}
416-
417391

418392
class GraphPropAsPerNodeType(DataPropertiesSetter, ABC):
419393
def __init__(self, properties=None, transform=None, **kwargs):

0 commit comments

Comments
 (0)