Skip to content

Commit 5575cf3

Browse files
committed
Add node coloring by label for G.visualize
1 parent 3982faa commit 5575cf3

File tree

1 file changed

+34
-8
lines changed

1 file changed

+34
-8
lines changed

graphdatascience/graph/graph_object.py

+34-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import colorsys
4+
import random
35
from types import TracebackType
46
from typing import Any, List, Optional, Type, Union
57
from uuid import uuid4
@@ -232,16 +234,18 @@ def __repr__(self) -> str:
232234
]
233235
return f"{self.__class__.__name__}({self._graph_info(yields=yield_fields).to_dict()})"
234236

235-
def visualize(self, node_count: int = 100):
237+
def visualize(self, node_count: int = 100, center_nodes: Optional[List[int]] = None) -> Any:
236238
visual_graph = self._name
237239
if self.node_count() > node_count:
238-
ratio = float(node_count) / self.node_count()
239240
visual_graph = str(uuid4())
241+
config = dict(samplingRatio=float(node_count) / self.node_count())
242+
243+
if center_nodes is not None:
244+
config["startNodes"] = center_nodes
245+
240246
self._query_runner.call_procedure(
241247
endpoint="gds.graph.sample.rwr",
242-
params=CallParameters(
243-
graph_name=visual_graph, fromGraphName=self._name, config=dict(samplingRatio=ratio)
244-
),
248+
params=CallParameters(graph_name=visual_graph, fromGraphName=self._name, config=config),
245249
custom_error=False,
246250
)
247251

@@ -254,13 +258,22 @@ def visualize(self, node_count: int = 100):
254258

255259
result = self._query_runner.call_procedure(
256260
endpoint="gds.graph.nodeProperties.stream",
257-
params=CallParameters(graph_name=visual_graph, properties=[pr_prop]),
261+
params=CallParameters(
262+
graph_name=visual_graph,
263+
properties=[pr_prop],
264+
nodeLabels=self.node_labels(),
265+
config=dict(listNodeLabels=True),
266+
),
258267
custom_error=False,
259268
)
260269

261270
# new format was requested, but the query was run via Cypher
262271
if "propertyValue" in result.keys():
263272
wide_result = result.pivot(index=["nodeId"], columns=["nodeProperty"], values="propertyValue")
273+
# nodeLabels cannot be an index column of the pivot as its not hashable
274+
# so we need to manually join it back in
275+
labels_df = result[["nodeId", "nodeLabels"]].set_index("nodeId")
276+
wide_result = wide_result.join(labels_df, on="nodeId")
264277
result = wide_result.reset_index()
265278
result.columns.name = None
266279
node_properties_df = result
@@ -271,6 +284,7 @@ def visualize(self, node_count: int = 100):
271284
custom_error=False,
272285
)
273286

287+
# Clean up
274288
if visual_graph != self._name:
275289
self._query_runner.call_procedure(
276290
endpoint="gds.graph.drop",
@@ -289,16 +303,28 @@ def visualize(self, node_count: int = 100):
289303
net = Network(
290304
notebook=True,
291305
cdn_resources="remote",
292-
bgcolor="#222222",
306+
bgcolor="#222222", # Dark background
293307
font_color="white",
294308
height="750px", # Modify according to your screen size
295309
width="100%",
296310
)
297311

312+
label_to_color = {label: self._random_bright_color() for label in self.node_labels()}
313+
298314
for _, node in node_properties_df.iterrows():
299-
net.add_node(int(node["nodeId"]), value=node[pr_prop])
315+
net.add_node(
316+
int(node["nodeId"]),
317+
value=node[pr_prop],
318+
color=label_to_color[node["nodeLabels"][0]],
319+
title=str(node["nodeId"]),
320+
)
300321

301322
# Add all the relationships
302323
net.add_edges(zip(relationships_df["sourceNodeId"], relationships_df["targetNodeId"]))
303324

304325
return net.show(f"{self._name}.html")
326+
327+
@staticmethod
328+
def _random_bright_color() -> str:
329+
h = random.randint(0, 255) / 255.0
330+
return "#%02X%02X%02X" % tuple(map(lambda x: int(x * 255), colorsys.hls_to_rgb(h, 0.7, 1.0)))

0 commit comments

Comments
 (0)