Skip to content

Commit 641a1fe

Browse files
committed
Add split and fork methods to RngStream
1 parent 5a2da6d commit 641a1fe

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

flax/nnx/rnglib.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,12 @@ def __call__(self) -> jax.Array:
124124
self.count[...] += 1
125125
return key
126126

127+
def key(self) -> jax.Array:
128+
return self()
129+
130+
def split(self, k: int):
131+
return self.fork(split=k)
132+
127133
def fork(self, *, split: int | tuple[int, ...] | None = None):
128134
key = self()
129135
if split is not None:

0 commit comments

Comments
 (0)