@@ -465,6 +465,14 @@ def _state_unflatten(
465465)
466466
467467def map_state (f : tp .Callable [[tuple , tp .Any ], tp .Any ], state : State ) -> State :
468+ """Map ``f`` over :class:`State` object.
469+
470+ Arguments:
471+ f: A function to be mapped
472+ state: A :class:`State` object.
473+ Returns:
474+ New state :class:`State`.
475+ """
468476 flat_state = to_flat_state (state )
469477 result = [
470478 (path , f (path , variable_state )) for path , variable_state in flat_state
@@ -473,22 +481,44 @@ def map_state(f: tp.Callable[[tuple, tp.Any], tp.Any], state: State) -> State:
473481
474482
475483def to_flat_state (state : State ) -> FlatState :
484+ """Convert state into flat state
485+
486+ Arguments:
487+ state: A :class:`State` object.
488+ Returns:
489+ Flat state :class:`FlatState`
490+ """
476491 return FlatState (traversals .flatten_to_sequence (state ._mapping ), sort = True )
477492
478493
479494def from_flat_state (
480495 flat_state : tp .Mapping [PathParts , V ] | tp .Iterable [tuple [PathParts , V ]],
481496 * , cls = State , # for compatibility with State subclasses
482497) -> State :
498+ """Convert flat state object into :class:`State` object.
499+
500+ Arguments:
501+ flat_state: A :class:`FlatState` object.
502+ Returns:
503+ State :class:`State` object.
504+ """
483505 if not isinstance (flat_state , tp .Mapping ):
484506 flat_state = dict (flat_state )
485507 nested_state = traversals .unflatten_mapping (flat_state )
486508 return cls (nested_state )
487509
488510
489511def to_pure_dict (
490- state , extract_fn : ExtractValueFn | None = None
512+ state : State , extract_fn : ExtractValueFn | None = None
491513) -> dict [str , tp .Any ]:
514+ """Convert :class:`State` object into pure dictionary state.
515+
516+ Arguments:
517+ state: A :class:`State` object.
518+ extract_fn: optional extraction function.
519+ Returns:
520+ Pure dictionary.
521+ """
492522 # Works for nnx.Variable
493523 if extract_fn is None :
494524 extract_fn = lambda x : x .value if isinstance (x , variablelib .Variable ) else x
@@ -497,6 +527,33 @@ def to_pure_dict(
497527
498528
499529def restore_int_paths (pure_dict : dict [str , tp .Any ]):
530+ """Restore integer paths from string value in the dict.
531+ This method can be helpful when restoring the state from a checkpoint as
532+ pure dictionary:
533+
534+ Example::
535+
536+ >>> from flax import nnx
537+ >>> import orbax.checkpoint as ocp
538+ ...
539+ >>> model = nnx.List([nnx.Linear(10, 10, rngs=nnx.Rngs(0)) for _ in range(2)])
540+ >>> pure_dict_state = nnx.to_pure_dict(nnx.state(model))
541+ >>> list(pure_dict_state.keys())
542+ [0, 1]
543+ >>> checkpointer = ocp.StandardCheckpointer()
544+ >>> checkpointer.save('/tmp/checkpoint/pure_dict', pure_dict_state)
545+ >>> restored_pure_dict = checkpointer.restore('/tmp/checkpoint/pure_dict')
546+ >>> list(restored_pure_dict.keys())
547+ ['0', '1']
548+ >>> restored_pure_dict = nnx.restore_int_paths(restored_pure_dict)
549+ >>> list(restored_pure_dict.keys())
550+ [0, 1]
551+
552+ Arguments:
553+ pure_dict: state as pure dictionary
554+ Returns:
555+ state as pure dictionary with restored integers paths
556+ """
500557 def try_convert_int (x ):
501558 try :
502559 return int (x )
@@ -510,8 +567,15 @@ def try_convert_int(x):
510567 return traversals .unflatten_mapping (fixed )
511568
512569def replace_by_pure_dict (
513- state , pure_dict : dict [str , tp .Any ], replace_fn : SetValueFn | None = None
570+ state : State , pure_dict : dict [str , tp .Any ], replace_fn : SetValueFn | None = None
514571):
572+ """Replace input ``state`` values with ``pure_dict`` values.
573+
574+ Arguments:
575+ state: A :class:`State` object.
576+ pure_dict: pure dictionary with values to be used for replacement.
577+ replace_fn: optional replace function.
578+ """
515579 def try_convert_int (x ):
516580 try :
517581 return int (x )
@@ -559,7 +623,7 @@ def split_state(
559623def split_state ( # type: ignore[misc]
560624 state : State , first : filterlib .Filter , / , * filters : filterlib .Filter
561625) -> tp .Union [State , tuple [State , ...]]:
562- """Split a `` State`` into one or more `` State` `'s. The
626+ """Split a :class:` State` into one or more :class:` State`'s. The
563627 user must pass at least one ``Filter`` (i.e. :class:`Variable`),
564628 and the filters must be exhaustive (i.e. they must cover all
565629 :class:`Variable` types in the ``State``).
@@ -675,8 +739,8 @@ def merge_state(state: tp.Mapping, /, *states: tp.Mapping,
675739 ) -> State :
676740 """The inverse of :meth:`split() <flax.nnx.State.state.split>`.
677741
678- ``merge`` takes one or more `` State` `'s and creates
679- a new `` State` `.
742+ ``merge`` takes one or more :class:` State`'s and creates
743+ a new :class:` State`.
680744
681745 Example usage::
682746
@@ -699,10 +763,10 @@ def merge_state(state: tp.Mapping, /, *states: tp.Mapping,
699763 >>> assert (model.linear.bias[...] == jnp.array([1, 1, 1])).all()
700764
701765 Args:
702- state: A `` State` ` object.
703- *states: Additional `` State` ` objects.
766+ state: A :class:` State` object.
767+ *states: Additional :class:` State` objects.
704768 Returns:
705- The merged `` State` `.
769+ The merged :class:` State`.
706770 """
707771 if not states :
708772 if isinstance (state , cls ):
0 commit comments