@@ -938,7 +938,7 @@ def f_pmapped(*args, **kwargs):
938
938
msg = ("soft_pmap mapped axis size must be divisble by the number of "
939
939
"XLA devices (or be less than or equal to that number), but got "
940
940
"an axis size of {} with {} devices." )
941
- raise ValueError (msg .format (axis_size , pxla .pxla . unmapped_device_count ()))
941
+ raise ValueError (msg .format (axis_size , pxla .unmapped_device_count ()))
942
942
num_chunks = axis_size // chunk_size
943
943
944
944
reshaped_args = [_reshape_split (num_chunks , x ) for x in args_flat ]
@@ -1922,15 +1922,12 @@ def jaxpr_to_graphviz(jaxpr, consts):
1922
1922
fragment .extend (map (constant_node , jaxpr .constvars , consts ))
1923
1923
1924
1924
for eqn in jaxpr .eqns :
1925
- if eqn .destructure :
1926
- id_name = next (id_names )
1927
- fragment .append (function_node (id_name , eqn .primitive .name ))
1928
- fragment .extend (edge (invar , id_name ) for invar in eqn .invars )
1929
- fragment .extend (edge (id_name , outvar ) for outvar in eqn .outvars )
1930
- else :
1931
- fragment .append (function_node (eqn .outvars [0 ], eqn .primitive .name ))
1932
- fragment .extend (edge (invar , eqn .outvars [0 ]) for invar in eqn .invars )
1933
- fragment .append (outvar_node (jaxpr .outvar , "out" ))
1925
+ id_name = next (id_names )
1926
+ fragment .append (function_node (id_name , eqn .primitive .name ))
1927
+ fragment .extend (edge (invar , id_name ) for invar in eqn .invars )
1928
+ fragment .extend (edge (id_name , outvar ) for outvar in eqn .outvars )
1929
+ for ov in jaxpr .outvars :
1930
+ fragment .append (outvar_node (ov , "out" ))
1934
1931
return graph ('' .join (fragment ))
1935
1932
1936
1933
edge = '{} -> {} [color=gray30];\n ' .format
@@ -1944,8 +1941,8 @@ def jaxpr_to_graphviz(jaxpr, consts):
1944
1941
@wraps (fun )
1945
1942
def graphviz_maker (* args , ** kwargs ):
1946
1943
wrapped = lu .wrap_init (fun , kwargs )
1947
- jax_args , in_trees = unzip2 ( map ( pytree_to_jaxtupletree , args ))
1948
- jaxtree_fun , out_tree = pytree_fun_to_jaxtupletree_fun (wrapped , in_trees )
1944
+ jax_args , in_tree = tree_flatten (( args , kwargs ))
1945
+ jaxtree_fun , out_tree = flatten_fun (wrapped , in_tree )
1949
1946
pvals = map (pv_like , jax_args )
1950
1947
jaxpr , _ , consts = pe .trace_to_jaxpr (jaxtree_fun , pvals )
1951
1948
return jaxpr_to_graphviz (jaxpr , consts )
0 commit comments