Skip to content

Commit e8e6572

Browse files
committed
allow split tuples in Rngs.fork
1 parent 1d30dd6 commit e8e6572

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

flax/nnx/rnglib.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,7 @@ def fork(
623623
*,
624624
split: tp.Mapping[filterlib.Filter, int | tuple[int, ...]]
625625
| int
626+
| tuple[int, ...]
626627
| None = None,
627628
):
628629
"""Returns a new Rngs object with new unique RNG keys.
@@ -662,6 +663,8 @@ def fork(
662663
split = {}
663664
elif isinstance(split, int):
664665
split = {...: split}
666+
elif isinstance(split, tuple):
667+
split = {...: split}
665668

666669
split_predicates = {filterlib.to_predicate(k): v for k, v in split.items()}
667670
keys: dict[str, RngStream] = {}

0 commit comments

Comments
 (0)