14
14
15
15
16
16
def _fetch_node_dfs(
    gds: GraphDataScience,
    G: Graph,
    node_properties_by_label: dict[str, list[str]],
    node_labels: list[str],
) -> dict[str, pd.DataFrame]:
    """Stream node properties from GDS, issuing one stream call per node label.

    For every label in ``node_labels`` a separate
    ``gds.graph.nodeProperties.stream`` call is made, restricted to that single
    label and requesting only the properties listed for it in
    ``node_properties_by_label``.

    Args:
        gds: GDS client whose ``graph.nodeProperties.stream`` is invoked.
        G: Graph object passed through to the stream call.
        node_properties_by_label: Property names to request, keyed by node
            label. Must contain an entry for every label in ``node_labels``.
        node_labels: Labels to fetch; each becomes a key of the result.

    Returns:
        A dict mapping each label to the result of its stream call.

    Raises:
        KeyError: If a label in ``node_labels`` is missing from
            ``node_properties_by_label``.
    """
    return {
        lbl: gds.graph.nodeProperties.stream(
            G,
            # Properties are looked up per label rather than shared across all labels.
            node_properties=node_properties_by_label[lbl],
            node_labels=[lbl],
            separate_property_columns=True,
        )
        for lbl in node_labels
    }
@@ -79,24 +79,31 @@ def from_gds(
79
79
"""
80
80
node_properties_from_gds = G .node_properties ()
81
81
assert isinstance (node_properties_from_gds , pd .Series )
82
- actual_node_properties = list (chain .from_iterable (node_properties_from_gds .to_dict ().values ()))
82
+ actual_node_properties = node_properties_from_gds .to_dict ()
83
+ all_actual_node_properties = list (chain .from_iterable (actual_node_properties .values ()))
83
84
84
- if size_property is not None and size_property not in actual_node_properties :
85
- raise ValueError (f"There is no node property '{ size_property } ' in graph '{ G .name ()} '" )
85
+ if size_property is not None :
86
+ if size_property not in all_actual_node_properties :
87
+ raise ValueError (f"There is no node property '{ size_property } ' in graph '{ G .name ()} '" )
86
88
87
89
if additional_node_properties is None :
88
- additional_node_properties = actual_node_properties
90
+ node_properties_by_label = { k : set ( v ) for k , v in actual_node_properties . items ()}
89
91
else :
90
92
for prop in additional_node_properties :
91
- if prop not in actual_node_properties :
93
+ if prop not in all_actual_node_properties :
92
94
raise ValueError (f"There is no node property '{ prop } ' in graph '{ G .name ()} '" )
93
95
94
- node_properties = set ()
95
- if additional_node_properties is not None :
96
- node_properties .update (additional_node_properties )
96
+ node_properties_by_label = {}
97
+ for label , props in actual_node_properties .items ():
98
+ node_properties_by_label [label ] = {
99
+ prop for prop in actual_node_properties [label ] if prop in additional_node_properties
100
+ }
101
+
97
102
if size_property is not None :
98
- node_properties .add (size_property )
99
- node_properties = list (node_properties )
103
+ for label , props in node_properties_by_label .items ():
104
+ props .add (size_property )
105
+
106
+ node_properties_by_label = {k : list (v ) for k , v in node_properties_by_label .items ()}
100
107
101
108
node_count = G .node_count ()
102
109
if node_count > max_node_count :
@@ -112,13 +119,14 @@ def from_gds(
112
119
property_name = None
113
120
try :
114
121
# Since GDS does not allow us to only fetch node IDs, we add the degree property
115
- # as a temporary property to ensure that we have at least one property to fetch
116
- if len (actual_node_properties ) == 0 :
122
+ # as a temporary property to ensure that we have at least one property for each label to fetch
123
+ if sum ([ len (props ) == 0 for props in node_properties_by_label . values ()]) > 0 :
117
124
property_name = f"neo4j-viz_property_{ uuid4 ()} "
118
125
gds .degree .mutate (G_fetched , mutateProperty = property_name )
119
- node_properties = [property_name ]
126
+ for props in node_properties_by_label .values ():
127
+ props .append (property_name )
120
128
121
- node_dfs = _fetch_node_dfs (gds , G_fetched , node_properties , G_fetched .node_labels ())
129
+ node_dfs = _fetch_node_dfs (gds , G_fetched , node_properties_by_label , G_fetched .node_labels ())
122
130
if property_name is not None :
123
131
for df in node_dfs .values ():
124
132
df .drop (columns = [property_name ], inplace = True )
@@ -131,35 +139,35 @@ def from_gds(
131
139
gds .graph .nodeProperties .drop (G_fetched , node_properties = [property_name ])
132
140
133
141
for df in node_dfs .values ():
134
- df .rename (columns = {"nodeId" : "id" }, inplace = True )
135
142
if property_name is not None and property_name in df .columns :
136
143
df .drop (columns = [property_name ], inplace = True )
137
- rel_df .rename (columns = {"sourceNodeId" : "source" , "targetNodeId" : "target" }, inplace = True )
138
144
139
145
node_props_df = pd .concat (node_dfs .values (), ignore_index = True , axis = 0 ).drop_duplicates ()
140
146
if size_property is not None :
141
- if "size" in actual_node_properties and size_property != "size" :
147
+ if "size" in all_actual_node_properties and size_property != "size" :
142
148
node_props_df .rename (columns = {"size" : "__size" }, inplace = True )
143
149
node_props_df .rename (columns = {size_property : "size" }, inplace = True )
144
150
145
151
for lbl , df in node_dfs .items ():
146
- if "labels" in actual_node_properties :
152
+ if "labels" in all_actual_node_properties :
147
153
df .rename (columns = {"labels" : "__labels" }, inplace = True )
148
154
df ["labels" ] = lbl
149
155
150
- node_labels_df = pd .concat ([df [["id " , "labels" ]] for df in node_dfs .values ()], ignore_index = True , axis = 0 )
151
- node_labels_df = node_labels_df .groupby ("id " ).agg ({"labels" : list })
156
+ node_labels_df = pd .concat ([df [["nodeId " , "labels" ]] for df in node_dfs .values ()], ignore_index = True , axis = 0 )
157
+ node_labels_df = node_labels_df .groupby ("nodeId " ).agg ({"labels" : list })
152
158
153
- node_df = node_props_df .merge (node_labels_df , on = "id " )
159
+ node_df = node_props_df .merge (node_labels_df , on = "nodeId " )
154
160
155
- if "caption" not in actual_node_properties :
161
+ if "caption" not in all_actual_node_properties :
156
162
node_df ["caption" ] = node_df ["labels" ].astype (str )
157
163
158
164
if "caption" not in rel_df .columns :
159
165
rel_df ["caption" ] = rel_df ["relationshipType" ]
160
166
161
167
try :
162
- return _from_dfs (node_df , rel_df , node_radius_min_max = node_radius_min_max , rename_properties = {"__size" : "size" })
168
+ return _from_dfs (
169
+ node_df , rel_df , node_radius_min_max = node_radius_min_max , rename_properties = {"__size" : "size" }, dropna = True
170
+ )
163
171
except ValueError as e :
164
172
err_msg = str (e )
165
173
if "column" in err_msg :
0 commit comments