Skip to content

Commit d121e3c

Browse files
committed
Even more features for G.visualize
1 parent 41f5f1f commit d121e3c

File tree

1 file changed

+69
-29
lines changed

1 file changed

+69
-29
lines changed

graphdatascience/graph/graph_object.py

+69-29
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ def node_count(self) -> int:
8282
"""
8383
Returns:
8484
the number of nodes in the graph
85-
8685
"""
8786
return self._graph_info(["nodeCount"]) # type: ignore
8887

@@ -191,7 +190,6 @@ def drop(self, failIfMissing: bool = False) -> "Series[str]":
191190
192191
Returns:
193192
the result of the drop operation
194-
195193
"""
196194
result = self._query_runner.call_procedure(
197195
endpoint="gds.graph.drop",
@@ -205,7 +203,6 @@ def creation_time(self) -> Any: # neo4j.time.DateTime not exported
205203
"""
206204
Returns:
207205
the creation time of the graph
208-
209206
"""
210207
return self._graph_info(["creationTime"])
211208

@@ -236,12 +233,36 @@ def __repr__(self) -> str:
236233

237234
def visualize(
238235
self,
239-
notebook: bool = True,
240236
node_count: int = 100,
237+
directed: bool = True,
241238
center_nodes: Optional[List[int]] = None,
242-
include_node_properties: List[str] = None,
239+
include_node_properties: Optional[List[str]] = None,
243240
color_property: Optional[str] = None,
241+
size_property: Optional[str] = None,
242+
rel_weight_property: Optional[str] = None,
243+
notebook: bool = True,
244+
px_height: int = 750,
245+
theme: str = "dark",
244246
) -> Any:
247+
"""
248+
Visualize the `Graph` in an interactive graphical interface.
249+
The graph will be sampled down to specified `node_count` to limit computationally expensive rendering.
250+
251+
Args:
252+
node_count: number of nodes in the graph to be visualized
253+
directed: whether or not to display relationships as directed
254+
center_nodes: nodes around subgraph will be sampled, if sampling is necessary
255+
include_node_properties: node properties to include for mouse-over inspection
256+
color_property: node property that determines node categories for coloring. Default is to use node labels
257+
size_property: node property that determines the size of nodes. Default is to compute a page rank for this
258+
rel_weight_property: relationship property that determines width of relationships
259+
notebook: whether or not the code is run in a notebook
260+
px_height: the height of the graphic containing output the visualization
261+
theme: coloring theme for the visualization. "light" or "dark"
262+
263+
Returns:
264+
an interactive graphical visualization of the specified graph
265+
"""
245266
visual_graph = self._name
246267
if self.node_count() > node_count:
247268
visual_graph = str(uuid4())
@@ -256,14 +277,19 @@ def visualize(
256277
custom_error=False,
257278
)
258279

259-
pr_prop = str(uuid4())
260-
self._query_runner.call_procedure(
261-
endpoint="gds.pageRank.mutate",
262-
params=CallParameters(graph_name=visual_graph, config=dict(mutateProperty=pr_prop)),
263-
custom_error=False,
264-
)
280+
# Make sure we always have at least a size property so that we can run `gds.graph.nodeProperties.stream`
281+
if size_property is None:
282+
size_property = str(uuid4())
283+
self._query_runner.call_procedure(
284+
endpoint="gds.pageRank.mutate",
285+
params=CallParameters(graph_name=visual_graph, config=dict(mutateProperty=size_property)),
286+
custom_error=False,
287+
)
288+
clean_up_size_prop = True
289+
else:
290+
clean_up_size_prop = False
265291

266-
node_properties = [pr_prop]
292+
node_properties = [size_property]
267293
if include_node_properties is not None:
268294
node_properties.extend(include_node_properties)
269295

@@ -295,11 +321,18 @@ def visualize(
295321
result.columns.name = None
296322
node_properties_df = result
297323

298-
relationships_df = self._query_runner.call_procedure(
299-
endpoint="gds.graph.relationships.stream",
300-
params=CallParameters(graph_name=visual_graph),
301-
custom_error=False,
302-
)
324+
if rel_weight_property is None:
325+
relationships_df = self._query_runner.call_procedure(
326+
endpoint="gds.graph.relationships.stream",
327+
params=CallParameters(graph_name=visual_graph),
328+
custom_error=False,
329+
)
330+
else:
331+
relationships_df = self._query_runner.call_procedure(
332+
endpoint="gds.graph.relationshipProperty.stream",
333+
params=CallParameters(graph_name=visual_graph, properties=rel_weight_property),
334+
custom_error=False,
335+
)
303336

304337
# Clean up
305338
if visual_graph != self._name:
@@ -308,10 +341,10 @@ def visualize(
308341
params=CallParameters(graph_name=visual_graph),
309342
custom_error=False,
310343
)
311-
else:
344+
elif clean_up_size_prop:
312345
self._query_runner.call_procedure(
313346
endpoint="gds.graph.nodeProperties.drop",
314-
params=CallParameters(graph_name=visual_graph, nodeProperties=pr_prop),
347+
params=CallParameters(graph_name=visual_graph, nodeProperties=size_property),
315348
custom_error=False,
316349
)
317350

@@ -320,19 +353,21 @@ def visualize(
320353
net = Network(
321354
notebook=True if notebook else False,
322355
cdn_resources="remote" if notebook else "local",
323-
bgcolor="#222222", # Dark background
324-
font_color="white",
325-
height="750px", # Modify according to your screen size
356+
directed=directed,
357+
bgcolor="#222222" if theme == "dark" else "#FDFDFD",
358+
font_color="white" if theme == "dark" else "black",
359+
height=f"{px_height}px",
326360
width="100%",
327361
)
328362

329363
if color_property is None:
330-
color_map = {label: self._random_bright_color() for label in self.node_labels()}
364+
color_map = {label: self._random_bright_color(theme) for label in self.node_labels()}
331365
else:
332366
color_map = {
333-
prop_val: self._random_bright_color() for prop_val in node_properties_df[color_property].unique()
367+
prop_val: self._random_bright_color(theme) for prop_val in node_properties_df[color_property].unique()
334368
}
335369

370+
# Add all the nodes
336371
for _, node in node_properties_df.iterrows():
337372
title = f"Node ID: {node['nodeId']}\nLabels: {node['nodeLabels']}"
338373
if include_node_properties is not None:
@@ -347,17 +382,22 @@ def visualize(
347382

348383
net.add_node(
349384
int(node["nodeId"]),
350-
value=node[pr_prop],
385+
value=node[size_property],
351386
color=color,
352387
title=title,
353388
)
354389

355390
# Add all the relationships
356-
net.add_edges(zip(relationships_df["sourceNodeId"], relationships_df["targetNodeId"]))
391+
for _, rel in relationships_df.iterrows():
392+
if rel_weight_property is None:
393+
net.add_edge(rel["sourceNodeId"], rel["targetNodeId"], title=f"Type: {rel['relationshipType']}")
394+
else:
395+
title = f"Type: {rel['relationshipType']}\n{rel_weight_property} = {rel['rel_weight_property']}"
396+
net.add_edge(rel["sourceNodeId"], rel["targetNodeId"], title=title, value=rel[rel_weight_property])
357397

358398
return net.show(f"{self._name}.html")
359399

360400
@staticmethod
361-
def _random_bright_color() -> str:
362-
h = random.randint(0, 255) / 255.0
363-
return "#%02X%02X%02X" % tuple(map(lambda x: int(x * 255), colorsys.hls_to_rgb(h, 0.7, 1.0)))
401+
def _random_bright_color(theme) -> str:
402+
l = 0.7 if theme == "dark" else 0.4
403+
return "#%02X%02X%02X" % tuple(map(lambda x: int(x * 255), colorsys.hls_to_rgb(random.random(), l, 1.0)))

0 commit comments

Comments
 (0)