1
1
from __future__ import annotations
2
2
3
+ import colorsys
4
+ import random
3
5
from types import TracebackType
4
6
from typing import Any , List , Optional , Type , Union
5
7
from uuid import uuid4
@@ -232,16 +234,18 @@ def __repr__(self) -> str:
232
234
]
233
235
return f"{ self .__class__ .__name__ } ({ self ._graph_info (yields = yield_fields ).to_dict ()} )"
234
236
235
- def visualize (self , node_count : int = 100 ) :
237
+ def visualize (self , node_count : int = 100 , center_nodes : Optional [ List [ int ]] = None ) -> Any :
236
238
visual_graph = self ._name
237
239
if self .node_count () > node_count :
238
- ratio = float (node_count ) / self .node_count ()
239
240
visual_graph = str (uuid4 ())
241
+ config = dict (samplingRatio = float (node_count ) / self .node_count ())
242
+
243
+ if center_nodes is not None :
244
+ config ["startNodes" ] = center_nodes
245
+
240
246
self ._query_runner .call_procedure (
241
247
endpoint = "gds.graph.sample.rwr" ,
242
- params = CallParameters (
243
- graph_name = visual_graph , fromGraphName = self ._name , config = dict (samplingRatio = ratio )
244
- ),
248
+ params = CallParameters (graph_name = visual_graph , fromGraphName = self ._name , config = config ),
245
249
custom_error = False ,
246
250
)
247
251
@@ -254,13 +258,22 @@ def visualize(self, node_count: int = 100):
254
258
255
259
result = self ._query_runner .call_procedure (
256
260
endpoint = "gds.graph.nodeProperties.stream" ,
257
- params = CallParameters (graph_name = visual_graph , properties = [pr_prop ]),
261
+ params = CallParameters (
262
+ graph_name = visual_graph ,
263
+ properties = [pr_prop ],
264
+ nodeLabels = self .node_labels (),
265
+ config = dict (listNodeLabels = True ),
266
+ ),
258
267
custom_error = False ,
259
268
)
260
269
261
270
# new format was requested, but the query was run via Cypher
262
271
if "propertyValue" in result .keys ():
263
272
wide_result = result .pivot (index = ["nodeId" ], columns = ["nodeProperty" ], values = "propertyValue" )
273
+ # nodeLabels cannot be an index column of the pivot as its not hashable
274
+ # so we need to manually join it back in
275
+ labels_df = result [["nodeId" , "nodeLabels" ]].set_index ("nodeId" )
276
+ wide_result = wide_result .join (labels_df , on = "nodeId" )
264
277
result = wide_result .reset_index ()
265
278
result .columns .name = None
266
279
node_properties_df = result
@@ -271,6 +284,7 @@ def visualize(self, node_count: int = 100):
271
284
custom_error = False ,
272
285
)
273
286
287
+ # Clean up
274
288
if visual_graph != self ._name :
275
289
self ._query_runner .call_procedure (
276
290
endpoint = "gds.graph.drop" ,
@@ -289,16 +303,28 @@ def visualize(self, node_count: int = 100):
289
303
net = Network (
290
304
notebook = True ,
291
305
cdn_resources = "remote" ,
292
- bgcolor = "#222222" ,
306
+ bgcolor = "#222222" , # Dark background
293
307
font_color = "white" ,
294
308
height = "750px" , # Modify according to your screen size
295
309
width = "100%" ,
296
310
)
297
311
312
+ label_to_color = {label : self ._random_bright_color () for label in self .node_labels ()}
313
+
298
314
for _ , node in node_properties_df .iterrows ():
299
- net .add_node (int (node ["nodeId" ]), value = node [pr_prop ])
315
+ net .add_node (
316
+ int (node ["nodeId" ]),
317
+ value = node [pr_prop ],
318
+ color = label_to_color [node ["nodeLabels" ][0 ]],
319
+ title = str (node ["nodeId" ]),
320
+ )
300
321
301
322
# Add all the relationships
302
323
net .add_edges (zip (relationships_df ["sourceNodeId" ], relationships_df ["targetNodeId" ]))
303
324
304
325
return net .show (f"{ self ._name } .html" )
326
+
327
+ @staticmethod
328
+ def _random_bright_color () -> str :
329
+ h = random .randint (0 , 255 ) / 255.0
330
+ return "#%02X%02X%02X" % tuple (map (lambda x : int (x * 255 ), colorsys .hls_to_rgb (h , 0.7 , 1.0 )))
0 commit comments