Skip to content

Commit 57c2cc2

Browse files
author
Flax Authors
committed
Merge pull request #5037 from google:expose-restore_int_paths
PiperOrigin-RevId: 822201398
2 parents 4d1cfbb + 2ea9f3b commit 57c2cc2

File tree

3 files changed

+88
-9
lines changed

3 files changed

+88
-9
lines changed

docs_nnx/api_reference/flax.nnx/state.rst

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,17 @@ state
66

77

88
.. autoclass:: State
9-
:members:
9+
:members:
10+
11+
.. autoclass:: FlatState
12+
:members:
13+
14+
.. autofunction:: filter_state
15+
.. autofunction:: from_flat_state
16+
.. autofunction:: map_state
17+
.. autofunction:: merge_state
18+
.. autofunction:: replace_by_pure_dict
19+
.. autofunction:: restore_int_paths
20+
.. autofunction:: to_flat_state
21+
.. autofunction:: to_pure_dict
22+
.. autofunction:: split_state

flax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,13 @@
142142
from .spmd import get_named_sharding as get_named_sharding
143143
from .spmd import with_partitioning as with_partitioning
144144
from .spmd import get_abstract_model as get_abstract_model
145+
from .statelib import FlatState as FlatState
145146
from .statelib import State as State
146147
from .statelib import to_flat_state as to_flat_state
147148
from .statelib import from_flat_state as from_flat_state
148149
from .statelib import to_pure_dict as to_pure_dict
149150
from .statelib import replace_by_pure_dict as replace_by_pure_dict
151+
from .statelib import restore_int_paths as restore_int_paths
150152
from .statelib import filter_state as filter_state
151153
from .statelib import merge_state as merge_state
152154
from .statelib import split_state as split_state

flax/nnx/statelib.py

Lines changed: 72 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,14 @@ def _state_unflatten(
465465
)
466466

467467
def map_state(f: tp.Callable[[tuple, tp.Any], tp.Any], state: State) -> State:
468+
"""Map ``f`` over :class:`State` object.
469+
470+
Arguments:
471+
f: A function called as ``f(path, value)`` on every entry of the flattened state.
472+
state: A :class:`State` object.
473+
Returns:
474+
A new :class:`State` object.
475+
"""
468476
flat_state = to_flat_state(state)
469477
result = [
470478
(path, f(path, variable_state)) for path, variable_state in flat_state
@@ -473,22 +481,44 @@ def map_state(f: tp.Callable[[tuple, tp.Any], tp.Any], state: State) -> State:
473481

474482

475483
def to_flat_state(state: State) -> FlatState:
484+
"""Convert state into flat state
485+
486+
Arguments:
487+
state: A :class:`State` object.
488+
Returns:
489+
A :class:`FlatState` object.
490+
"""
476491
return FlatState(traversals.flatten_to_sequence(state._mapping), sort=True)
477492

478493

479494
def from_flat_state(
480495
flat_state: tp.Mapping[PathParts, V] | tp.Iterable[tuple[PathParts, V]],
481496
*, cls = State, # for compatibility with State subclasses
482497
) -> State:
498+
"""Convert flat state object into :class:`State` object.
499+
500+
Arguments:
501+
flat_state: A :class:`FlatState` object, or a mapping / iterable of ``(path, value)`` pairs.
502+
Returns:
503+
A :class:`State` object.
504+
"""
483505
if not isinstance(flat_state, tp.Mapping):
484506
flat_state = dict(flat_state)
485507
nested_state = traversals.unflatten_mapping(flat_state)
486508
return cls(nested_state)
487509

488510

489511
def to_pure_dict(
490-
state, extract_fn: ExtractValueFn | None = None
512+
state: State, extract_fn: ExtractValueFn | None = None
491513
) -> dict[str, tp.Any]:
514+
"""Convert :class:`State` object into pure dictionary state.
515+
516+
Arguments:
517+
state: A :class:`State` object.
518+
extract_fn: optional extraction function; defaults to returning ``x.value`` for :class:`Variable` instances and ``x`` otherwise.
519+
Returns:
520+
Pure dictionary.
521+
"""
492522
# Works for nnx.Variable
493523
if extract_fn is None:
494524
extract_fn = lambda x: x.value if isinstance(x, variablelib.Variable) else x
@@ -497,6 +527,33 @@ def to_pure_dict(
497527

498528

499529
def restore_int_paths(pure_dict: dict[str, tp.Any]):
530+
"""Restore integer paths from string value in the dict.
531+
This function can be helpful when restoring the state from a checkpoint as a
532+
pure dictionary:
533+
534+
Example::
535+
536+
>>> from flax import nnx
537+
>>> import orbax.checkpoint as ocp
538+
...
539+
>>> model = nnx.List([nnx.Linear(10, 10, rngs=nnx.Rngs(0)) for _ in range(2)])
540+
>>> pure_dict_state = nnx.to_pure_dict(nnx.state(model))
541+
>>> list(pure_dict_state.keys())
542+
[0, 1]
543+
>>> checkpointer = ocp.StandardCheckpointer()
544+
>>> checkpointer.save('/tmp/checkpoint/pure_dict', pure_dict_state)
545+
>>> restored_pure_dict = checkpointer.restore('/tmp/checkpoint/pure_dict')
546+
>>> list(restored_pure_dict.keys())
547+
['0', '1']
548+
>>> restored_pure_dict = nnx.restore_int_paths(restored_pure_dict)
549+
>>> list(restored_pure_dict.keys())
550+
[0, 1]
551+
552+
Arguments:
553+
pure_dict: state as pure dictionary
554+
Returns:
555+
state as pure dictionary with restored integer paths
556+
"""
500557
def try_convert_int(x):
501558
try:
502559
return int(x)
@@ -510,8 +567,15 @@ def try_convert_int(x):
510567
return traversals.unflatten_mapping(fixed)
511568

512569
def replace_by_pure_dict(
513-
state, pure_dict: dict[str, tp.Any], replace_fn: SetValueFn | None = None
570+
state: State, pure_dict: dict[str, tp.Any], replace_fn: SetValueFn | None = None
514571
):
572+
"""Replace input ``state`` values with ``pure_dict`` values.
573+
574+
Arguments:
575+
state: A :class:`State` object.
576+
pure_dict: pure dictionary with values to be used for replacement.
577+
replace_fn: optional replace function.
578+
"""
515579
def try_convert_int(x):
516580
try:
517581
return int(x)
@@ -559,7 +623,7 @@ def split_state(
559623
def split_state( # type: ignore[misc]
560624
state: State, first: filterlib.Filter, /, *filters: filterlib.Filter
561625
) -> tp.Union[State, tuple[State, ...]]:
562-
"""Split a ``State`` into one or more ``State``'s. The
626+
"""Split a :class:`State` into one or more :class:`State`'s. The
563627
user must pass at least one ``Filter`` (i.e. :class:`Variable`),
564628
and the filters must be exhaustive (i.e. they must cover all
565629
:class:`Variable` types in the ``State``).
@@ -675,8 +739,8 @@ def merge_state(state: tp.Mapping, /, *states: tp.Mapping,
675739
) -> State:
676740
"""The inverse of :meth:`split() <flax.nnx.State.state.split>`.
677741
678-
``merge`` takes one or more ``State``'s and creates
679-
a new ``State``.
742+
``merge`` takes one or more :class:`State`'s and creates
743+
a new :class:`State`.
680744
681745
Example usage::
682746
@@ -699,10 +763,10 @@ def merge_state(state: tp.Mapping, /, *states: tp.Mapping,
699763
>>> assert (model.linear.bias[...] == jnp.array([1, 1, 1])).all()
700764
701765
Args:
702-
state: A ``State`` object.
703-
*states: Additional ``State`` objects.
766+
state: A :class:`State` object.
767+
*states: Additional :class:`State` objects.
704768
Returns:
705-
The merged ``State``.
769+
The merged :class:`State`.
706770
"""
707771
if not states:
708772
if isinstance(state, cls):

0 commit comments

Comments
 (0)