Skip to content

Commit 74d5639

Browse files
breakanalysisorazve
andcommitted
Expose graphsage training configuration
Co-authored-by: Olga Razvenskaia <[email protected]>
1 parent 8e822be commit 74d5639

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

graphdatascience/gnn/gnn_nc_runner.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,21 @@
66

77

88
class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker):
9+
def make_graph_sage_config(self, graph_sage_config):
10+
GRAPH_SAGE_DEFAULT_CONFIG = {"layer_config": {}, "num_neighbors": [25, 10], "dropout": 0.5,
11+
"hidden_channels": 256}
12+
final_sage_config = GRAPH_SAGE_DEFAULT_CONFIG
13+
if graph_sage_config:
14+
bad_keys = []
15+
for key in graph_sage_config:
16+
if key not in GRAPH_SAGE_DEFAULT_CONFIG:
17+
bad_keys.append(key)
18+
if len(bad_keys) > 0:
19+
raise Exception(f"Argument graph_sage_config contains invalid keys {', '.join(bad_keys)}.")
20+
21+
final_sage_config.update(graph_sage_config)
22+
return final_sage_config
23+
924
def train(
1025
self,
1126
graph_name: str,
@@ -15,13 +30,15 @@ def train(
1530
relationship_types: List[str],
1631
target_node_label: str = None,
1732
node_labels: List[str] = None,
33+
graph_sage_config = None
1834
) -> "Series[Any]": # noqa: F821
1935
mlConfigMap = {
2036
"featureProperties": feature_properties,
2137
"targetProperty": target_property,
2238
"job_type": "train",
2339
"nodeProperties": feature_properties + [target_property],
24-
"relationshipTypes": relationship_types
40+
"relationshipTypes": relationship_types,
41+
"graph_sage_config": self.make_graph_sage_config(graph_sage_config)
2542
}
2643

2744
if target_node_label:

0 commit comments

Comments
 (0)