@@ -439,6 +439,7 @@ def make_graph(
439
439
figsize = None ,
440
440
dpi = 300 ,
441
441
node_formatters : NodeTypeFormatterMapping | None = None ,
442
+ graph_attr : dict [str , Any ] | None = None ,
442
443
create_plate_label : PlateLabelFunc = create_plate_label_with_dim_length ,
443
444
):
444
445
"""Make graphviz Digraph of PyMC model.
@@ -459,7 +460,7 @@ def make_graph(
459
460
node_formatters = node_formatters or {}
460
461
node_formatters = update_node_formatters (node_formatters )
461
462
462
- graph = graphviz .Digraph (name )
463
+ graph = graphviz .Digraph (name , graph_attr = graph_attr )
463
464
for plate in plates :
464
465
if plate .dim_info :
465
466
# must be preceded by 'cluster' to get a box around it
@@ -676,6 +677,7 @@ def model_to_graphviz(
676
677
figsize : tuple [int , int ] | None = None ,
677
678
dpi : int = 300 ,
678
679
node_formatters : NodeTypeFormatterMapping | None = None ,
680
+ graph_attr : dict [str , Any ] | None = None ,
679
681
include_dim_lengths : bool = True ,
680
682
):
681
683
"""Produce a graphviz Digraph from a PyMC model.
@@ -704,6 +706,10 @@ def model_to_graphviz(
704
706
the size of the saved figure.
705
707
dpi : int, optional
706
708
Dots per inch. It only affects the resolution of the saved figure. The default is 300.
709
+ graph_attr : dict, optional
710
+ A dictionary of top-level layout attributes for graphviz
711
+ Check out graphviz documentation for more information on available attributes
712
+ https://graphviz.org/doc/info/attrs.html
707
713
node_formatters : dict, optional
708
714
A dictionary mapping node types to functions that return a dictionary of node attributes.
709
715
Check out graphviz documentation for more information on available
@@ -773,6 +779,7 @@ def model_to_graphviz(
773
779
save = save ,
774
780
figsize = figsize ,
775
781
dpi = dpi ,
782
+ graph_attr = graph_attr ,
776
783
node_formatters = node_formatters ,
777
784
create_plate_label = create_plate_label_with_dim_length
778
785
if include_dim_lengths
0 commit comments