@@ -235,7 +235,12 @@ def __repr__(self) -> str:
235
235
return f"{ self .__class__ .__name__ } ({ self ._graph_info (yields = yield_fields ).to_dict ()} )"
236
236
237
237
def visualize (
238
- self , node_count : int = 100 , center_nodes : Optional [List [int ]] = None , include_node_properties : List [str ] = None
238
+ self ,
239
+ notebook : bool = True ,
240
+ node_count : int = 100 ,
241
+ center_nodes : Optional [List [int ]] = None ,
242
+ include_node_properties : List [str ] = None ,
243
+ color_property : Optional [str ] = None ,
239
244
) -> Any :
240
245
visual_graph = self ._name
241
246
if self .node_count () > node_count :
@@ -262,6 +267,12 @@ def visualize(
262
267
if include_node_properties is not None :
263
268
node_properties .extend (include_node_properties )
264
269
270
+ if color_property is not None :
271
+ node_properties .append (color_property )
272
+
273
+ # Remove possible duplicates
274
+ node_properties = list (set (node_properties ))
275
+
265
276
result = self ._query_runner .call_procedure (
266
277
endpoint = "gds.graph.nodeProperties.stream" ,
267
278
params = CallParameters (
@@ -307,15 +318,20 @@ def visualize(
307
318
from pyvis .network import Network
308
319
309
320
net = Network (
310
- notebook = True ,
311
- cdn_resources = "remote" ,
321
+ notebook = True if notebook else False ,
322
+ cdn_resources = "remote" if notebook else "local" ,
312
323
bgcolor = "#222222" , # Dark background
313
324
font_color = "white" ,
314
325
height = "750px" , # Modify according to your screen size
315
326
width = "100%" ,
316
327
)
317
328
318
- label_to_color = {label : self ._random_bright_color () for label in self .node_labels ()}
329
+ if color_property is None :
330
+ color_map = {label : self ._random_bright_color () for label in self .node_labels ()}
331
+ else :
332
+ color_map = {
333
+ prop_val : self ._random_bright_color () for prop_val in node_properties_df [color_property ].unique ()
334
+ }
319
335
320
336
for _ , node in node_properties_df .iterrows ():
321
337
title = f"Node ID: { node ['nodeId' ]} \n Labels: { node ['nodeLabels' ]} "
@@ -324,10 +340,15 @@ def visualize(
324
340
for prop in include_node_properties :
325
341
title += f"\n { prop } = { node [prop ]} "
326
342
343
+ if color_property is None :
344
+ color = color_map [node ["nodeLabels" ][0 ]]
345
+ else :
346
+ color = color_map [node [color_property ]]
347
+
327
348
net .add_node (
328
349
int (node ["nodeId" ]),
329
350
value = node [pr_prop ],
330
- color = label_to_color [ node [ "nodeLabels" ][ 0 ]] ,
351
+ color = color ,
331
352
title = title ,
332
353
)
333
354
0 commit comments