Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 67 additions & 40 deletions docs_nnx/guides/checkpointing.ipynb

Large diffs are not rendered by default.

57 changes: 28 additions & 29 deletions docs_nnx/guides/checkpointing.md

Large diffs are not rendered by default.

39 changes: 39 additions & 0 deletions flax/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from jax import core, lax
from jax.extend import linear_util as lu
from jax.interpreters import partial_eval as pe
from typing import Any


def _pmap_device_order():
Expand Down Expand Up @@ -316,3 +317,41 @@ def unpad(x):
return out if static_return else jax.tree_util.tree_map(unpad, out)

return pad_shard_unpad_wrapper


class _DictOrList(dict):
"""Dictionary that should be converted to a list."""
is_list: bool = False

def _to_pytree(a):
if not isinstance(a, _DictOrList):
return a
if a.is_list:
return [_to_pytree(v) for k, v in sorted(a.items())]
else:
return {k: _to_pytree(v) for k, v in a.items()}

def _path_ix(a):
return a.key if isinstance(a, jax.tree_util.DictKey) else a.idx

def build_tree_from_paths(paths_and_leaves: list[tuple[jax.tree_util.KeyPath, Any]]):
"""
Inverse of ``jax.tree.leaves_with_path``. Builds a PyTree from a list of (path, leaf) pairs.
"""
root = _DictOrList()
for path, leaf in paths_and_leaves:
if not path: continue
current = root

# Navigate/create structure following the path
for key_entry in path[:-1]:
k = _path_ix(key_entry)
if k not in current:
current[k] = _DictOrList()
current.is_list = isinstance(key_entry, jax.tree_util.SequenceKey)
current = current[k]

# Set the leaf value
current.is_list = isinstance(path[-1], jax.tree_util.SequenceKey)
current[_path_ix(path[-1])] = leaf
return _to_pytree(root)
90 changes: 51 additions & 39 deletions flax/nnx/rnglib.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,24 @@ def __init__(

count = jnp.zeros(key.shape, dtype=jnp.uint32)
self.tag = tag
self.key = RngKey(key, tag=tag)
self.base_key = RngKey(key, tag=tag)
self.count = RngCount(count, tag=tag)

def __call__(self) -> jax.Array:
if not self.count.has_ref and not self.count._trace_state.is_valid():
raise errors.TraceContextError(
f'Cannot mutate {type(self).__name__} from a different trace level'
)
key = random.fold_in(self.key[...], self.count[...])
key = random.fold_in(self.base_key[...], self.count[...])
self.count[...] += 1
return key

def key(self) -> jax.Array:
return self()

def split(self, k: int):
return self.fork(split=k)

def fork(self, *, split: int | tuple[int, ...] | None = None):
key = self()
if split is not None:
Expand Down Expand Up @@ -319,7 +325,7 @@ class Rngs(Pytree):
``counter``. Every time a key is requested, the counter is incremented and the key is
generated from the seed key and the counter by using ``jax.random.fold_in``.

To create an ``Rngs`` pass in an integer or ``jax.random.key`` to the
To create an ``Rngs`` pass in an integer or ``jax.random.base_key`` to the
constructor as a keyword argument with the name of the stream. The key will be used as the
starting seed for the stream, and the counter will be initialized to zero. Then call the
stream to get a key::
Expand Down Expand Up @@ -372,7 +378,7 @@ def __init__(
Args:
default: the starting seed for the ``default`` stream, defaults to None.
**rngs: keyword arguments specifying the starting seed for each stream.
The key can be an integer or a ``jax.random.key``.
The key can be an integer or a ``jax.random.base_key``.
"""
if default is not None:
if isinstance(default, tp.Mapping):
Expand All @@ -382,7 +388,7 @@ def __init__(

for tag, key in rngs.items():
if isinstance(key, RngStream):
key = key.key[...]
key = key.base_key[...]
stream = RngStream(
key=key,
tag=tag,
Expand All @@ -409,6 +415,9 @@ def __getattr__(self, name: str):
def __call__(self):
return self.default()

def key(self):
return self.default()

def __iter__(self) -> tp.Iterator[str]:
for name, stream in vars(self).items():
if isinstance(stream, RngStream):
Expand All @@ -427,6 +436,9 @@ def items(self):
if isinstance(stream, RngStream):
yield name, stream

def split(self, splits: int):
return self.fork(split=splits)

def fork(
self,
/,
Expand All @@ -451,8 +463,8 @@ def fork(
>>> rngs = nnx.Rngs(params=1, dropout=2)
>>> new_rngs = rngs.fork(split=5)
...
>>> assert new_rngs.params.key.shape == (5,)
>>> assert new_rngs.dropout.key.shape == (5,)
>>> assert new_rngs.params.base_key.shape == (5,)
>>> assert new_rngs.dropout.base_key.shape == (5,)

``split`` also accepts a mapping of
`Filters <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__ to
Expand All @@ -465,9 +477,9 @@ def fork(
... ...: (2, 5), # split anything else into 2x5 keys
... })
...
>>> assert new_rngs.params.key.shape == (5,)
>>> assert new_rngs.dropout.key.shape == ()
>>> assert new_rngs.noise.key.shape == (2, 5)
>>> assert new_rngs.params.base_key.shape == (5,)
>>> assert new_rngs.dropout.base_key.shape == ()
>>> assert new_rngs.noise.base_key.shape == (2, 5)
"""
if split is None:
split = {}
Expand Down Expand Up @@ -728,18 +740,18 @@ def split_rngs(
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.split_rngs(rngs, splits=5)
>>> rngs.params.key.shape, rngs.dropout.key.shape
>>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
((5,), (5,))

>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.split_rngs(rngs, splits=(2, 5))
>>> rngs.params.key.shape, rngs.dropout.key.shape
>>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
((2, 5), (2, 5))


>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.split_rngs(rngs, splits=5, only='params')
>>> rngs.params.key.shape, rngs.dropout.key.shape
>>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
((5,), ())

Once split, random state can be used with transforms like :func:`nnx.vmap`::
Expand All @@ -759,7 +771,7 @@ def split_rngs(
... return Model(rngs)
...
>>> model = create_model(rngs)
>>> model.dropout.rngs.key.shape
>>> model.dropout.rngs.base_key.shape
()

``split_rngs`` returns a SplitBackups object that can be used to restore the
Expand All @@ -772,7 +784,7 @@ def split_rngs(
>>> model = create_model(rngs)
>>> nnx.restore_rngs(backups)
...
>>> model.dropout.rngs.key.shape
>>> model.dropout.rngs.base_key.shape
()

SplitBackups can also be used as a context manager to automatically restore
Expand All @@ -783,7 +795,7 @@ def split_rngs(
>>> with nnx.split_rngs(rngs, splits=5, only='params'):
... model = create_model(rngs)
...
>>> model.dropout.rngs.key.shape
>>> model.dropout.rngs.base_key.shape
()

>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
Expand All @@ -795,7 +807,7 @@ def split_rngs(
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> model = create_model(rngs)
>>> model.dropout.rngs.key.shape
>>> model.dropout.rngs.base_key.shape
()


Expand All @@ -822,18 +834,18 @@ def split_rngs_wrapper(*args, **kwargs):
for path, stream in graph.iter_graph(node):
if (
isinstance(stream, RngStream)
and predicate((*path, 'key'), stream.key)
and predicate((*path, 'key'), stream.base_key)
and predicate((*path, 'count'), stream.count)
):
key = stream()
backups.append((stream, stream.key.raw_value, stream.count.raw_value))
backups.append((stream, stream.base_key.raw_value, stream.count.raw_value))
key = random.split(key, splits)
if squeeze:
key = key[0]
if variablelib.is_array_ref(stream.key.raw_value):
stream.key.raw_value = variablelib.new_ref(key) # type: ignore[assignment]
if variablelib.is_array_ref(stream.base_key.raw_value):
stream.base_key.raw_value = variablelib.new_ref(key) # type: ignore[assignment]
else:
stream.key.value = key
stream.base_key.value = key
if squeeze:
counts_shape = stream.count.shape
elif isinstance(splits, int):
Expand Down Expand Up @@ -892,18 +904,18 @@ def fork_rngs(
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.fork_rngs(rngs, split=5)
>>> rngs.params.key.shape, rngs.dropout.key.shape
>>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
((5,), (5,))

>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.fork_rngs(rngs, split=(2, 5))
>>> rngs.params.key.shape, rngs.dropout.key.shape
>>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
((2, 5), (2, 5))


>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.fork_rngs(rngs, split={'params': 5})
>>> rngs.params.key.shape, rngs.dropout.key.shape
>>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
((5,), ())

Once forked, random state can be used with transforms like :func:`nnx.vmap`::
Expand All @@ -923,7 +935,7 @@ def fork_rngs(
... return Model(rngs)
...
>>> model = create_model(rngs)
>>> model.dropout.rngs.key.shape
>>> model.dropout.rngs.base_key.shape
()

``fork_rngs`` returns a SplitBackups object that can be used to restore the
Expand All @@ -936,7 +948,7 @@ def fork_rngs(
>>> model = create_model(rngs)
>>> nnx.restore_rngs(backups)
...
>>> model.dropout.rngs.key.shape
>>> model.dropout.rngs.base_key.shape
()

SplitBackups can also be used as a context manager to automatically restore
Expand All @@ -947,7 +959,7 @@ def fork_rngs(
>>> with nnx.fork_rngs(rngs, split={'params': 5}):
... model = create_model(rngs)
...
>>> model.dropout.rngs.key.shape
>>> model.dropout.rngs.base_key.shape
()

>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
Expand All @@ -959,7 +971,7 @@ def fork_rngs(
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> model = create_model(rngs)
>>> model.dropout.rngs.key.shape
>>> model.dropout.rngs.base_key.shape
()
"""
if isinstance(node, Missing):
Expand Down Expand Up @@ -987,14 +999,14 @@ def fork_rngs_wrapper(*args, **kwargs):
for predicate, splits in predicate_splits.items():
if (
isinstance(stream, RngStream)
and predicate((*path, 'key'), stream.key)
and predicate((*path, 'key'), stream.base_key)
and predicate((*path, 'count'), stream.count)
):
forked_stream = stream.fork(split=splits)
# backup the original stream state
backups.append((stream, stream.key.raw_value, stream.count.raw_value))
backups.append((stream, stream.base_key.raw_value, stream.count.raw_value))
# apply the forked key and count to the original stream
stream.key.raw_value = forked_stream.key.raw_value
stream.base_key.raw_value = forked_stream.base_key.raw_value
stream.count.raw_value = forked_stream.count.raw_value

return SplitBackups(backups)
Expand All @@ -1004,7 +1016,7 @@ def backup_keys(node: tp.Any, /):
backups: list[StreamBackup] = []
for _, stream in graph.iter_graph(node):
if isinstance(stream, RngStream):
backups.append((stream, stream.key.raw_value))
backups.append((stream, stream.base_key.raw_value))
return backups

def _scalars_only(
Expand Down Expand Up @@ -1049,7 +1061,7 @@ def reseed(
of the form ``(path, scalar_key, target_shape) -> new_key`` can be passed to
define a custom reseeding policy.
**stream_keys: a mapping of stream names to new keys. The keys can be
either integers or ``jax.random.key``.
either integers or ``jax.random.base_key``.

Example::

Expand Down Expand Up @@ -1087,16 +1099,16 @@ def reseed(
rngs = Rngs(**stream_keys)
for path, stream in graph.iter_graph(node):
if isinstance(stream, RngStream):
if stream.key.tag in stream_keys:
key = rngs[stream.key.tag]()
key = policy(path, key, stream.key.shape)
stream.key.value = key
if stream.base_key.tag in stream_keys:
key = rngs[stream.base_key.tag]()
key = policy(path, key, stream.base_key.shape)
stream.base_key.value = key
stream.count.value = jnp.zeros(key.shape, dtype=jnp.uint32)


def restore_rngs(backups: tp.Iterable[StreamBackup], /):
for backup in backups:
stream = backup[0]
stream.key.raw_value = backup[1]
stream.base_key.raw_value = backup[1]
if len(backup) == 3:
stream.count.raw_value = backup[2] # count
3 changes: 1 addition & 2 deletions tests/nnx/bridge/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __call__(self):
scope = bar.apply({}, rngs=1)
self.assertIsNone(bar.scope)

self.assertEqual(scope.rngs.default.key[...], jax.random.key(1))
self.assertEqual(scope.rngs.default.base_key[...], jax.random.key(1))
self.assertEqual(scope.rngs.default.count[...], 0)

class Baz(bridge.Module):
Expand Down Expand Up @@ -514,4 +514,3 @@ def __call__(self, x):

if __name__ == '__main__':
absltest.main()

2 changes: 1 addition & 1 deletion tests/nnx/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def test_create_abstract(self):
def test_create_abstract_stateful(self):
linear = nnx.eval_shape(lambda: nnx.Dropout(0.5, rngs=nnx.Rngs(0)))

assert linear.rngs.key.value == jax.ShapeDtypeStruct(
assert linear.rngs.base_key.value == jax.ShapeDtypeStruct(
(), jax.random.key(0).dtype
)

Expand Down
4 changes: 2 additions & 2 deletions tests/nnx/mutable_array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,15 +654,15 @@ def test_rngs_create(self):
paths[0],
(
jax.tree_util.GetAttrKey('default'),
jax.tree_util.GetAttrKey('count'),
jax.tree_util.GetAttrKey('base_key'),
jax.tree_util.GetAttrKey('value'),
),
)
self.assertEqual(
paths[1],
(
jax.tree_util.GetAttrKey('default'),
jax.tree_util.GetAttrKey('key'),
jax.tree_util.GetAttrKey('count'),
jax.tree_util.GetAttrKey('value'),
),
)
Expand Down
2 changes: 1 addition & 1 deletion tests/nnx/nn/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_keep_rngs(self, keep_rngs):
if keep_rngs:
_, _, nondiff = nnx.split(module, nnx.Param, ...)
assert isinstance(nondiff['rngs']['count'], nnx.RngCount)
assert isinstance(nondiff['rngs']['key'], nnx.RngKey)
assert isinstance(nondiff['rngs']['base_key'], nnx.RngKey)
else:
nnx.split(module, nnx.Param)

Expand Down
Loading
Loading