@@ -112,15 +112,15 @@ def __init__(
112112
113113 count = jnp .zeros (key .shape , dtype = jnp .uint32 )
114114 self .tag = tag
115- self .key = RngKey (key , tag = tag )
115+ self .key_ = RngKey (key , tag = tag )
116116 self .count = RngCount (count , tag = tag )
117117
118118 def __call__ (self ) -> jax .Array :
119119 if not self .count .has_ref and not self .count ._trace_state .is_valid ():
120120 raise errors .TraceContextError (
121121 f'Cannot mutate { type (self ).__name__ } from a different trace level'
122122 )
123- key = random .fold_in (self .key [...], self .count [...])
123+ key = random .fold_in (self .key_ [...], self .count [...])
124124 self .count [...] += 1
125125 return key
126126
@@ -325,7 +325,7 @@ class Rngs(Pytree):
325325 ``counter``. Every time a key is requested, the counter is incremented and the key is
326326 generated from the seed key and the counter by using ``jax.random.fold_in``.
327327
328- To create an ``Rngs`` pass in an integer or ``jax.random.key `` to the
328+ To create an ``Rngs`` pass in an integer or ``jax.random.key_ `` to the
329329 constructor as a keyword argument with the name of the stream. The key will be used as the
330330 starting seed for the stream, and the counter will be initialized to zero. Then call the
331331 stream to get a key::
@@ -378,7 +378,7 @@ def __init__(
378378 Args:
379379 default: the starting seed for the ``default`` stream, defaults to None.
380380 **rngs: keyword arguments specifying the starting seed for each stream.
381- The key can be an integer or a ``jax.random.key ``.
381+ The key can be an integer or a ``jax.random.key_ ``.
382382 """
383383 if default is not None :
384384 if isinstance (default , tp .Mapping ):
@@ -388,7 +388,7 @@ def __init__(
388388
389389 for tag , key in rngs .items ():
390390 if isinstance (key , RngStream ):
391- key = key .key [...]
391+ key = key .key_ [...]
392392 stream = RngStream (
393393 key = key ,
394394 tag = tag ,
@@ -457,8 +457,8 @@ def fork(
457457 >>> rngs = nnx.Rngs(params=1, dropout=2)
458458 >>> new_rngs = rngs.fork(split=5)
459459 ...
460- >>> assert new_rngs.params.key .shape == (5,)
461- >>> assert new_rngs.dropout.key .shape == (5,)
460+ >>> assert new_rngs.params.key_ .shape == (5,)
461+ >>> assert new_rngs.dropout.key_ .shape == (5,)
462462
463463 ``split`` also accepts a mapping of
464464 `Filters <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__ to
@@ -471,9 +471,9 @@ def fork(
471471 ... ...: (2, 5), # split anything else into 2x5 keys
472472 ... })
473473 ...
474- >>> assert new_rngs.params.key .shape == (5,)
475- >>> assert new_rngs.dropout.key .shape == ()
476- >>> assert new_rngs.noise.key .shape == (2, 5)
474+ >>> assert new_rngs.params.key_ .shape == (5,)
475+ >>> assert new_rngs.dropout.key_ .shape == ()
476+ >>> assert new_rngs.noise.key_ .shape == (2, 5)
477477 """
478478 if split is None :
479479 split = {}
@@ -734,18 +734,18 @@ def split_rngs(
734734 ...
735735 >>> rngs = nnx.Rngs(params=0, dropout=1)
736736 >>> _ = nnx.split_rngs(rngs, splits=5)
737- >>> rngs.params.key .shape, rngs.dropout.key .shape
737+ >>> rngs.params.key_ .shape, rngs.dropout.key_ .shape
738738 ((5,), (5,))
739739
740740 >>> rngs = nnx.Rngs(params=0, dropout=1)
741741 >>> _ = nnx.split_rngs(rngs, splits=(2, 5))
742- >>> rngs.params.key .shape, rngs.dropout.key .shape
742+ >>> rngs.params.key_ .shape, rngs.dropout.key_ .shape
743743 ((2, 5), (2, 5))
744744
745745
746746 >>> rngs = nnx.Rngs(params=0, dropout=1)
747747 >>> _ = nnx.split_rngs(rngs, splits=5, only='params')
748- >>> rngs.params.key .shape, rngs.dropout.key .shape
748+ >>> rngs.params.key_ .shape, rngs.dropout.key_ .shape
749749 ((5,), ())
750750
751751 Once split, random state can be used with transforms like :func:`nnx.vmap`::
@@ -765,7 +765,7 @@ def split_rngs(
765765 ... return Model(rngs)
766766 ...
767767 >>> model = create_model(rngs)
768- >>> model.dropout.rngs.key .shape
768+ >>> model.dropout.rngs.key_ .shape
769769 ()
770770
771771 ``split_rngs`` returns a SplitBackups object that can be used to restore the
@@ -778,7 +778,7 @@ def split_rngs(
778778 >>> model = create_model(rngs)
779779 >>> nnx.restore_rngs(backups)
780780 ...
781- >>> model.dropout.rngs.key .shape
781+ >>> model.dropout.rngs.key_ .shape
782782 ()
783783
784784 SplitBackups can also be used as a context manager to automatically restore
@@ -789,7 +789,7 @@ def split_rngs(
789789 >>> with nnx.split_rngs(rngs, splits=5, only='params'):
790790 ... model = create_model(rngs)
791791 ...
792- >>> model.dropout.rngs.key .shape
792+ >>> model.dropout.rngs.key_ .shape
793793 ()
794794
795795 >>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
@@ -801,7 +801,7 @@ def split_rngs(
801801 ...
802802 >>> rngs = nnx.Rngs(params=0, dropout=1)
803803 >>> model = create_model(rngs)
804- >>> model.dropout.rngs.key .shape
804+ >>> model.dropout.rngs.key_ .shape
805805 ()
806806
807807
@@ -828,18 +828,18 @@ def split_rngs_wrapper(*args, **kwargs):
828828 for path , stream in graph .iter_graph (node ):
829829 if (
830830 isinstance (stream , RngStream )
831- and predicate ((* path , 'key' ), stream .key )
831+ and predicate ((* path , 'key' ), stream .key_ )
832832 and predicate ((* path , 'count' ), stream .count )
833833 ):
834834 key = stream ()
835- backups .append ((stream , stream .key .raw_value , stream .count .raw_value ))
835+ backups .append ((stream , stream .key_ .raw_value , stream .count .raw_value ))
836836 key = random .split (key , splits )
837837 if squeeze :
838838 key = key [0 ]
839- if variablelib .is_array_ref (stream .key .raw_value ):
840- stream .key .raw_value = variablelib .new_ref (key ) # type: ignore[assignment]
839+ if variablelib .is_array_ref (stream .key_ .raw_value ):
840+ stream .key_ .raw_value = variablelib .new_ref (key ) # type: ignore[assignment]
841841 else :
842- stream .key .value = key
842+ stream .key_ .value = key
843843 if squeeze :
844844 counts_shape = stream .count .shape
845845 elif isinstance (splits , int ):
@@ -898,18 +898,18 @@ def fork_rngs(
898898 ...
899899 >>> rngs = nnx.Rngs(params=0, dropout=1)
900900 >>> _ = nnx.fork_rngs(rngs, split=5)
901- >>> rngs.params.key .shape, rngs.dropout.key .shape
901+ >>> rngs.params.key_ .shape, rngs.dropout.key_ .shape
902902 ((5,), (5,))
903903
904904 >>> rngs = nnx.Rngs(params=0, dropout=1)
905905 >>> _ = nnx.fork_rngs(rngs, split=(2, 5))
906- >>> rngs.params.key .shape, rngs.dropout.key .shape
906+ >>> rngs.params.key_ .shape, rngs.dropout.key_ .shape
907907 ((2, 5), (2, 5))
908908
909909
910910 >>> rngs = nnx.Rngs(params=0, dropout=1)
911911 >>> _ = nnx.fork_rngs(rngs, split={'params': 5})
912- >>> rngs.params.key .shape, rngs.dropout.key .shape
912+ >>> rngs.params.key_ .shape, rngs.dropout.key_ .shape
913913 ((5,), ())
914914
915915 Once forked, random state can be used with transforms like :func:`nnx.vmap`::
@@ -929,7 +929,7 @@ def fork_rngs(
929929 ... return Model(rngs)
930930 ...
931931 >>> model = create_model(rngs)
932- >>> model.dropout.rngs.key .shape
932+ >>> model.dropout.rngs.key_ .shape
933933 ()
934934
935935 ``fork_rngs`` returns a SplitBackups object that can be used to restore the
@@ -942,7 +942,7 @@ def fork_rngs(
942942 >>> model = create_model(rngs)
943943 >>> nnx.restore_rngs(backups)
944944 ...
945- >>> model.dropout.rngs.key .shape
945+ >>> model.dropout.rngs.key_ .shape
946946 ()
947947
948948 SplitBackups can also be used as a context manager to automatically restore
@@ -953,7 +953,7 @@ def fork_rngs(
953953 >>> with nnx.fork_rngs(rngs, split={'params': 5}):
954954 ... model = create_model(rngs)
955955 ...
956- >>> model.dropout.rngs.key .shape
956+ >>> model.dropout.rngs.key_ .shape
957957 ()
958958
959959 >>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
@@ -965,7 +965,7 @@ def fork_rngs(
965965 ...
966966 >>> rngs = nnx.Rngs(params=0, dropout=1)
967967 >>> model = create_model(rngs)
968- >>> model.dropout.rngs.key .shape
968+ >>> model.dropout.rngs.key_ .shape
969969 ()
970970 """
971971 if isinstance (node , Missing ):
@@ -993,14 +993,14 @@ def fork_rngs_wrapper(*args, **kwargs):
993993 for predicate , splits in predicate_splits .items ():
994994 if (
995995 isinstance (stream , RngStream )
996- and predicate ((* path , 'key' ), stream .key )
996+ and predicate ((* path , 'key' ), stream .key_ )
997997 and predicate ((* path , 'count' ), stream .count )
998998 ):
999999 forked_stream = stream .fork (split = splits )
10001000 # backup the original stream state
1001- backups .append ((stream , stream .key .raw_value , stream .count .raw_value ))
1001+ backups .append ((stream , stream .key_ .raw_value , stream .count .raw_value ))
10021002 # apply the forked key and count to the original stream
1003- stream .key .raw_value = forked_stream .key .raw_value
1003+ stream .key_ .raw_value = forked_stream .key_ .raw_value
10041004 stream .count .raw_value = forked_stream .count .raw_value
10051005
10061006 return SplitBackups (backups )
@@ -1010,7 +1010,7 @@ def backup_keys(node: tp.Any, /):
10101010 backups : list [StreamBackup ] = []
10111011 for _ , stream in graph .iter_graph (node ):
10121012 if isinstance (stream , RngStream ):
1013- backups .append ((stream , stream .key .raw_value ))
1013+ backups .append ((stream , stream .key_ .raw_value ))
10141014 return backups
10151015
10161016def _scalars_only (
@@ -1055,7 +1055,7 @@ def reseed(
10551055 of the form ``(path, scalar_key, target_shape) -> new_key`` can be passed to
10561056 define a custom reseeding policy.
10571057 **stream_keys: a mapping of stream names to new keys. The keys can be
1058- either integers or ``jax.random.key ``.
1058+ either integers or ``jax.random.key_ ``.
10591059
10601060 Example::
10611061
@@ -1093,16 +1093,16 @@ def reseed(
10931093 rngs = Rngs (** stream_keys )
10941094 for path , stream in graph .iter_graph (node ):
10951095 if isinstance (stream , RngStream ):
1096- if stream .key .tag in stream_keys :
1097- key = rngs [stream .key .tag ]()
1098- key = policy (path , key , stream .key .shape )
1099- stream .key .value = key
1096+ if stream .key_ .tag in stream_keys :
1097+ key = rngs [stream .key_ .tag ]()
1098+ key = policy (path , key , stream .key_ .shape )
1099+ stream .key_ .value = key
11001100 stream .count .value = jnp .zeros (key .shape , dtype = jnp .uint32 )
11011101
11021102
11031103def restore_rngs (backups : tp .Iterable [StreamBackup ], / ):
11041104 for backup in backups :
11051105 stream = backup [0 ]
1106- stream .key .raw_value = backup [1 ]
1106+ stream .key_ .raw_value = backup [1 ]
11071107 if len (backup ) == 3 :
11081108 stream .count .raw_value = backup [2 ] # count
0 commit comments