Skip to content

Commit 41f5f1f

Browse files
committed
Add more features to G.visualize
1 parent 1498df5 commit 41f5f1f

File tree

1 file changed

+26
-5
lines changed

1 file changed

+26
-5
lines changed

graphdatascience/graph/graph_object.py

+26-5
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,12 @@ def __repr__(self) -> str:
235235
return f"{self.__class__.__name__}({self._graph_info(yields=yield_fields).to_dict()})"
236236

237237
def visualize(
238-
self, node_count: int = 100, center_nodes: Optional[List[int]] = None, include_node_properties: List[str] = None
238+
self,
239+
notebook: bool = True,
240+
node_count: int = 100,
241+
center_nodes: Optional[List[int]] = None,
242+
include_node_properties: List[str] = None,
243+
color_property: Optional[str] = None,
239244
) -> Any:
240245
visual_graph = self._name
241246
if self.node_count() > node_count:
@@ -262,6 +267,12 @@ def visualize(
262267
if include_node_properties is not None:
263268
node_properties.extend(include_node_properties)
264269

270+
if color_property is not None:
271+
node_properties.append(color_property)
272+
273+
# Remove possible duplicates
274+
node_properties = list(set(node_properties))
275+
265276
result = self._query_runner.call_procedure(
266277
endpoint="gds.graph.nodeProperties.stream",
267278
params=CallParameters(
@@ -307,15 +318,20 @@ def visualize(
307318
from pyvis.network import Network
308319

309320
net = Network(
310-
notebook=True,
311-
cdn_resources="remote",
321+
notebook=True if notebook else False,
322+
cdn_resources="remote" if notebook else "local",
312323
bgcolor="#222222", # Dark background
313324
font_color="white",
314325
height="750px", # Modify according to your screen size
315326
width="100%",
316327
)
317328

318-
label_to_color = {label: self._random_bright_color() for label in self.node_labels()}
329+
if color_property is None:
330+
color_map = {label: self._random_bright_color() for label in self.node_labels()}
331+
else:
332+
color_map = {
333+
prop_val: self._random_bright_color() for prop_val in node_properties_df[color_property].unique()
334+
}
319335

320336
for _, node in node_properties_df.iterrows():
321337
title = f"Node ID: {node['nodeId']}\nLabels: {node['nodeLabels']}"
@@ -324,10 +340,15 @@ def visualize(
324340
for prop in include_node_properties:
325341
title += f"\n{prop} = {node[prop]}"
326342

343+
if color_property is None:
344+
color = color_map[node["nodeLabels"][0]]
345+
else:
346+
color = color_map[node[color_property]]
347+
327348
net.add_node(
328349
int(node["nodeId"]),
329350
value=node[pr_prop],
330-
color=label_to_color[node["nodeLabels"][0]],
351+
color=color,
331352
title=title,
332353
)
333354

0 commit comments

Comments
 (0)