Skip to content

Commit 6ef135b

Browse files
authored
Allow passing graph level attributes to graphviz
1 parent 5db3779 commit 6ef135b

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

pymc/model_graph.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ def make_graph(
439439
figsize=None,
440440
dpi=300,
441441
node_formatters: NodeTypeFormatterMapping | None = None,
442+
graph_attr: dict[str, Any] | None = None,
442443
create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length,
443444
):
444445
"""Make graphviz Digraph of PyMC model.
@@ -459,7 +460,7 @@ def make_graph(
459460
node_formatters = node_formatters or {}
460461
node_formatters = update_node_formatters(node_formatters)
461462

462-
graph = graphviz.Digraph(name)
463+
graph = graphviz.Digraph(name, graph_attr=graph_attr)
463464
for plate in plates:
464465
if plate.dim_info:
465466
# must be preceded by 'cluster' to get a box around it
@@ -676,6 +677,7 @@ def model_to_graphviz(
676677
figsize: tuple[int, int] | None = None,
677678
dpi: int = 300,
678679
node_formatters: NodeTypeFormatterMapping | None = None,
680+
graph_attr: dict[str, Any] | None = None,
679681
include_dim_lengths: bool = True,
680682
):
681683
"""Produce a graphviz Digraph from a PyMC model.
@@ -704,6 +706,10 @@ def model_to_graphviz(
704706
the size of the saved figure.
705707
dpi : int, optional
706708
Dots per inch. It only affects the resolution of the saved figure. The default is 300.
709+
graph_attr : dict, optional
710+
A dictionary of top-level layout attributes for graphviz
711+
Check out graphviz documentation for more information on available attributes
712+
https://graphviz.org/doc/info/attrs.html
707713
node_formatters : dict, optional
708714
A dictionary mapping node types to functions that return a dictionary of node attributes.
709715
Check out graphviz documentation for more information on available
@@ -773,6 +779,7 @@ def model_to_graphviz(
773779
save=save,
774780
figsize=figsize,
775781
dpi=dpi,
782+
graph_attr=graph_attr,
776783
node_formatters=node_formatters,
777784
create_plate_label=create_plate_label_with_dim_length
778785
if include_dim_lengths

0 commit comments

Comments
 (0)