Skip to content

Commit b1f573b

Browse files
author
Flax Authors
committed
Merge pull request #5083 from dan-zheng:docs
PiperOrigin-RevId: 830991905
2 parents cd37bc9 + 1f67c8b commit b1f573b

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

flax/nnx/graph.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2634,7 +2634,8 @@ def to_arrays(
26342634
>>> frozen_node = nnx.to_arrays(node)
26352635
>>> assert isinstance(frozen_node[0], jax.Array)
26362636
2637-
If the structure contains duplicate array refs, a ValueError is raised::
2637+
If ``allow_duplicates`` is ``False`` and the structure contains duplicate
2638+
array refs, raises ``ValueError``::
26382639
26392640
>>> shared_array = jax.new_ref(jnp.array(1.0))
26402641
>>> node = [shared_array, shared_array]
@@ -2660,8 +2661,12 @@ def to_arrays(
26602661
Args:
26612662
node: A structure potentially containing array refs.
26622663
only: A Filter to specify which array refs to freeze.
2664+
allow_duplicates: If True, allow duplicate array refs.
26632665
Returns:
26642666
A structure with the frozen arrays.
2667+
Raises:
2668+
ValueError: If duplicate array refs are found and `allow_duplicates` is
2669+
False.
26652670
"""
26662671
if not allow_duplicates and (
26672672
all_duplicates := find_duplicates(node, only=only)

0 commit comments

Comments
 (0)