@@ -82,7 +82,6 @@ def node_count(self) -> int:
82
82
"""
83
83
Returns:
84
84
the number of nodes in the graph
85
-
86
85
"""
87
86
return self ._graph_info (["nodeCount" ]) # type: ignore
88
87
@@ -191,7 +190,6 @@ def drop(self, failIfMissing: bool = False) -> "Series[str]":
191
190
192
191
Returns:
193
192
the result of the drop operation
194
-
195
193
"""
196
194
result = self ._query_runner .call_procedure (
197
195
endpoint = "gds.graph.drop" ,
@@ -205,7 +203,6 @@ def creation_time(self) -> Any: # neo4j.time.DateTime not exported
205
203
"""
206
204
Returns:
207
205
the creation time of the graph
208
-
209
206
"""
210
207
return self ._graph_info (["creationTime" ])
211
208
@@ -236,12 +233,36 @@ def __repr__(self) -> str:
236
233
237
234
def visualize (
238
235
self ,
239
- notebook : bool = True ,
240
236
node_count : int = 100 ,
237
+ directed : bool = True ,
241
238
center_nodes : Optional [List [int ]] = None ,
242
- include_node_properties : List [str ] = None ,
239
+ include_node_properties : Optional [ List [str ] ] = None ,
243
240
color_property : Optional [str ] = None ,
241
+ size_property : Optional [str ] = None ,
242
+ rel_weight_property : Optional [str ] = None ,
243
+ notebook : bool = True ,
244
+ px_height : int = 750 ,
245
+ theme : str = "dark" ,
244
246
) -> Any :
247
+ """
248
+ Visualize the `Graph` in an interactive graphical interface.
249
+ The graph will be sampled down to specified `node_count` to limit computationally expensive rendering.
250
+
251
+ Args:
252
+ node_count: number of nodes in the graph to be visualized
253
+ directed: whether or not to display relationships as directed
254
+ center_nodes: nodes around subgraph will be sampled, if sampling is necessary
255
+ include_node_properties: node properties to include for mouse-over inspection
256
+ color_property: node property that determines node categories for coloring. Default is to use node labels
257
+ size_property: node property that determines the size of nodes. Default is to compute a page rank for this
258
+ rel_weight_property: relationship property that determines width of relationships
259
+ notebook: whether or not the code is run in a notebook
260
+ px_height: the height of the graphic containing output the visualization
261
+ theme: coloring theme for the visualization. "light" or "dark"
262
+
263
+ Returns:
264
+ an interactive graphical visualization of the specified graph
265
+ """
245
266
visual_graph = self ._name
246
267
if self .node_count () > node_count :
247
268
visual_graph = str (uuid4 ())
@@ -256,14 +277,19 @@ def visualize(
256
277
custom_error = False ,
257
278
)
258
279
259
- pr_prop = str (uuid4 ())
260
- self ._query_runner .call_procedure (
261
- endpoint = "gds.pageRank.mutate" ,
262
- params = CallParameters (graph_name = visual_graph , config = dict (mutateProperty = pr_prop )),
263
- custom_error = False ,
264
- )
280
+ # Make sure we always have at least a size property so that we can run `gds.graph.nodeProperties.stream`
281
+ if size_property is None :
282
+ size_property = str (uuid4 ())
283
+ self ._query_runner .call_procedure (
284
+ endpoint = "gds.pageRank.mutate" ,
285
+ params = CallParameters (graph_name = visual_graph , config = dict (mutateProperty = size_property )),
286
+ custom_error = False ,
287
+ )
288
+ clean_up_size_prop = True
289
+ else :
290
+ clean_up_size_prop = False
265
291
266
- node_properties = [pr_prop ]
292
+ node_properties = [size_property ]
267
293
if include_node_properties is not None :
268
294
node_properties .extend (include_node_properties )
269
295
@@ -295,11 +321,18 @@ def visualize(
295
321
result .columns .name = None
296
322
node_properties_df = result
297
323
298
- relationships_df = self ._query_runner .call_procedure (
299
- endpoint = "gds.graph.relationships.stream" ,
300
- params = CallParameters (graph_name = visual_graph ),
301
- custom_error = False ,
302
- )
324
+ if rel_weight_property is None :
325
+ relationships_df = self ._query_runner .call_procedure (
326
+ endpoint = "gds.graph.relationships.stream" ,
327
+ params = CallParameters (graph_name = visual_graph ),
328
+ custom_error = False ,
329
+ )
330
+ else :
331
+ relationships_df = self ._query_runner .call_procedure (
332
+ endpoint = "gds.graph.relationshipProperty.stream" ,
333
+ params = CallParameters (graph_name = visual_graph , properties = rel_weight_property ),
334
+ custom_error = False ,
335
+ )
303
336
304
337
# Clean up
305
338
if visual_graph != self ._name :
@@ -308,10 +341,10 @@ def visualize(
308
341
params = CallParameters (graph_name = visual_graph ),
309
342
custom_error = False ,
310
343
)
311
- else :
344
+ elif clean_up_size_prop :
312
345
self ._query_runner .call_procedure (
313
346
endpoint = "gds.graph.nodeProperties.drop" ,
314
- params = CallParameters (graph_name = visual_graph , nodeProperties = pr_prop ),
347
+ params = CallParameters (graph_name = visual_graph , nodeProperties = size_property ),
315
348
custom_error = False ,
316
349
)
317
350
@@ -320,19 +353,21 @@ def visualize(
320
353
net = Network (
321
354
notebook = True if notebook else False ,
322
355
cdn_resources = "remote" if notebook else "local" ,
323
- bgcolor = "#222222" , # Dark background
324
- font_color = "white" ,
325
- height = "750px" , # Modify according to your screen size
356
+ directed = directed ,
357
+ bgcolor = "#222222" if theme == "dark" else "#FDFDFD" ,
358
+ font_color = "white" if theme == "dark" else "black" ,
359
+ height = f"{ px_height } px" ,
326
360
width = "100%" ,
327
361
)
328
362
329
363
if color_property is None :
330
- color_map = {label : self ._random_bright_color () for label in self .node_labels ()}
364
+ color_map = {label : self ._random_bright_color (theme ) for label in self .node_labels ()}
331
365
else :
332
366
color_map = {
333
- prop_val : self ._random_bright_color () for prop_val in node_properties_df [color_property ].unique ()
367
+ prop_val : self ._random_bright_color (theme ) for prop_val in node_properties_df [color_property ].unique ()
334
368
}
335
369
370
+ # Add all the nodes
336
371
for _ , node in node_properties_df .iterrows ():
337
372
title = f"Node ID: { node ['nodeId' ]} \n Labels: { node ['nodeLabels' ]} "
338
373
if include_node_properties is not None :
@@ -347,17 +382,22 @@ def visualize(
347
382
348
383
net .add_node (
349
384
int (node ["nodeId" ]),
350
- value = node [pr_prop ],
385
+ value = node [size_property ],
351
386
color = color ,
352
387
title = title ,
353
388
)
354
389
355
390
# Add all the relationships
356
- net .add_edges (zip (relationships_df ["sourceNodeId" ], relationships_df ["targetNodeId" ]))
391
+ for _ , rel in relationships_df .iterrows ():
392
+ if rel_weight_property is None :
393
+ net .add_edge (rel ["sourceNodeId" ], rel ["targetNodeId" ], title = f"Type: { rel ['relationshipType' ]} " )
394
+ else :
395
+ title = f"Type: { rel ['relationshipType' ]} \n { rel_weight_property } = { rel ['rel_weight_property' ]} "
396
+ net .add_edge (rel ["sourceNodeId" ], rel ["targetNodeId" ], title = title , value = rel [rel_weight_property ])
357
397
358
398
return net .show (f"{ self ._name } .html" )
359
399
360
400
@staticmethod
361
- def _random_bright_color () -> str :
362
- h = random . randint ( 0 , 255 ) / 255.0
363
- return "#%02X%02X%02X" % tuple (map (lambda x : int (x * 255 ), colorsys .hls_to_rgb (h , 0.7 , 1.0 )))
401
+ def _random_bright_color (theme ) -> str :
402
+ l = 0.7 if theme == "dark" else 0.4
403
+ return "#%02X%02X%02X" % tuple (map (lambda x : int (x * 255 ), colorsys .hls_to_rgb (random . random (), l , 1.0 )))
0 commit comments