6
6
7
7
8
8
class GNNNodeClassificationRunner (UncallableNamespace , IllegalAttrChecker ):
9
+ def make_graph_sage_config (self , graph_sage_config ):
10
+ GRAPH_SAGE_DEFAULT_CONFIG = {"layer_config" : {}, "num_neighbors" : [25 , 10 ], "dropout" : 0.5 ,
11
+ "hidden_channels" : 256 }
12
+ final_sage_config = GRAPH_SAGE_DEFAULT_CONFIG
13
+ if graph_sage_config :
14
+ bad_keys = []
15
+ for key in graph_sage_config :
16
+ if key not in GRAPH_SAGE_DEFAULT_CONFIG :
17
+ bad_keys .append (key )
18
+ if len (bad_keys ) > 0 :
19
+ raise Exception (f"Argument graph_sage_config contains invalid keys { ', ' .join (bad_keys )} ." )
20
+
21
+ final_sage_config .update (graph_sage_config )
22
+ return final_sage_config
23
+
9
24
def train (
10
25
self ,
11
26
graph_name : str ,
@@ -15,13 +30,15 @@ def train(
15
30
relationship_types : List [str ],
16
31
target_node_label : str = None ,
17
32
node_labels : List [str ] = None ,
33
+ graph_sage_config = None
18
34
) -> "Series[Any]" : # noqa: F821
19
35
mlConfigMap = {
20
36
"featureProperties" : feature_properties ,
21
37
"targetProperty" : target_property ,
22
38
"job_type" : "train" ,
23
39
"nodeProperties" : feature_properties + [target_property ],
24
- "relationshipTypes" : relationship_types
40
+ "relationshipTypes" : relationship_types ,
41
+ "graph_sage_config" : self .make_graph_sage_config (graph_sage_config )
25
42
}
26
43
27
44
if target_node_label :
0 commit comments