@@ -439,6 +439,7 @@ def make_graph(
439439 figsize = None ,
440440 dpi = 300 ,
441441 node_formatters : NodeTypeFormatterMapping | None = None ,
442+ graph_attr : dict [str , Any ] | None = None ,
442443 create_plate_label : PlateLabelFunc = create_plate_label_with_dim_length ,
443444):
444445 """Make graphviz Digraph of PyMC model.
@@ -459,7 +460,7 @@ def make_graph(
459460 node_formatters = node_formatters or {}
460461 node_formatters = update_node_formatters (node_formatters )
461462
462- graph = graphviz .Digraph (name )
463+ graph = graphviz .Digraph (name , graph_attr = graph_attr )
463464 for plate in plates :
464465 if plate .dim_info :
465466 # must be preceded by 'cluster' to get a box around it
@@ -676,6 +677,7 @@ def model_to_graphviz(
676677 figsize : tuple [int , int ] | None = None ,
677678 dpi : int = 300 ,
678679 node_formatters : NodeTypeFormatterMapping | None = None ,
680+ graph_attr : dict [str , Any ] | None = None ,
679681 include_dim_lengths : bool = True ,
680682):
681683 """Produce a graphviz Digraph from a PyMC model.
@@ -704,6 +706,10 @@ def model_to_graphviz(
704706 the size of the saved figure.
705707 dpi : int, optional
706708 Dots per inch. It only affects the resolution of the saved figure. The default is 300.
709+ graph_attr : dict, optional
710+ A dictionary of top-level layout attributes for graphviz
711+ Check out graphviz documentation for more information on available attributes
712+ https://graphviz.org/doc/info/attrs.html
707713 node_formatters : dict, optional
708714 A dictionary mapping node types to functions that return a dictionary of node attributes.
709715 Check out graphviz documentation for more information on available
@@ -773,6 +779,7 @@ def model_to_graphviz(
773779 save = save ,
774780 figsize = figsize ,
775781 dpi = dpi ,
782+ graph_attr = graph_attr ,
776783 node_formatters = node_formatters ,
777784 create_plate_label = create_plate_label_with_dim_length
778785 if include_dim_lengths
0 commit comments