Skip to content

Commit aa6ce6a

Browse files
committed
Replace key property with key_
1 parent 641a1fe commit aa6ce6a

File tree

7 files changed

+63
-64
lines changed

7 files changed

+63
-64
lines changed

flax/nnx/rnglib.py

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,15 @@ def __init__(
112112

113113
count = jnp.zeros(key.shape, dtype=jnp.uint32)
114114
self.tag = tag
115-
self.key = RngKey(key, tag=tag)
115+
self.key_ = RngKey(key, tag=tag)
116116
self.count = RngCount(count, tag=tag)
117117

118118
def __call__(self) -> jax.Array:
119119
if not self.count.has_ref and not self.count._trace_state.is_valid():
120120
raise errors.TraceContextError(
121121
f'Cannot mutate {type(self).__name__} from a different trace level'
122122
)
123-
key = random.fold_in(self.key[...], self.count[...])
123+
key = random.fold_in(self.key_[...], self.count[...])
124124
self.count[...] += 1
125125
return key
126126

@@ -325,7 +325,7 @@ class Rngs(Pytree):
325325
``counter``. Every time a key is requested, the counter is incremented and the key is
326326
generated from the seed key and the counter by using ``jax.random.fold_in``.
327327
328-
To create an ``Rngs`` pass in an integer or ``jax.random.key`` to the
328+
To create an ``Rngs`` pass in an integer or ``jax.random.key_`` to the
329329
constructor as a keyword argument with the name of the stream. The key will be used as the
330330
starting seed for the stream, and the counter will be initialized to zero. Then call the
331331
stream to get a key::
@@ -378,7 +378,7 @@ def __init__(
378378
Args:
379379
default: the starting seed for the ``default`` stream, defaults to None.
380380
**rngs: keyword arguments specifying the starting seed for each stream.
381-
The key can be an integer or a ``jax.random.key``.
381+
The key can be an integer or a ``jax.random.key_``.
382382
"""
383383
if default is not None:
384384
if isinstance(default, tp.Mapping):
@@ -388,7 +388,7 @@ def __init__(
388388

389389
for tag, key in rngs.items():
390390
if isinstance(key, RngStream):
391-
key = key.key[...]
391+
key = key.key_[...]
392392
stream = RngStream(
393393
key=key,
394394
tag=tag,
@@ -457,8 +457,8 @@ def fork(
457457
>>> rngs = nnx.Rngs(params=1, dropout=2)
458458
>>> new_rngs = rngs.fork(split=5)
459459
...
460-
>>> assert new_rngs.params.key.shape == (5,)
461-
>>> assert new_rngs.dropout.key.shape == (5,)
460+
>>> assert new_rngs.params.key_.shape == (5,)
461+
>>> assert new_rngs.dropout.key_.shape == (5,)
462462
463463
``split`` also accepts a mapping of
464464
`Filters <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__ to
@@ -471,9 +471,9 @@ def fork(
471471
... ...: (2, 5), # split anything else into 2x5 keys
472472
... })
473473
...
474-
>>> assert new_rngs.params.key.shape == (5,)
475-
>>> assert new_rngs.dropout.key.shape == ()
476-
>>> assert new_rngs.noise.key.shape == (2, 5)
474+
>>> assert new_rngs.params.key_.shape == (5,)
475+
>>> assert new_rngs.dropout.key_.shape == ()
476+
>>> assert new_rngs.noise.key_.shape == (2, 5)
477477
"""
478478
if split is None:
479479
split = {}
@@ -734,18 +734,18 @@ def split_rngs(
734734
...
735735
>>> rngs = nnx.Rngs(params=0, dropout=1)
736736
>>> _ = nnx.split_rngs(rngs, splits=5)
737-
>>> rngs.params.key.shape, rngs.dropout.key.shape
737+
>>> rngs.params.key_.shape, rngs.dropout.key_.shape
738738
((5,), (5,))
739739
740740
>>> rngs = nnx.Rngs(params=0, dropout=1)
741741
>>> _ = nnx.split_rngs(rngs, splits=(2, 5))
742-
>>> rngs.params.key.shape, rngs.dropout.key.shape
742+
>>> rngs.params.key_.shape, rngs.dropout.key_.shape
743743
((2, 5), (2, 5))
744744
745745
746746
>>> rngs = nnx.Rngs(params=0, dropout=1)
747747
>>> _ = nnx.split_rngs(rngs, splits=5, only='params')
748-
>>> rngs.params.key.shape, rngs.dropout.key.shape
748+
>>> rngs.params.key_.shape, rngs.dropout.key_.shape
749749
((5,), ())
750750
751751
Once split, random state can be used with transforms like :func:`nnx.vmap`::
@@ -765,7 +765,7 @@ def split_rngs(
765765
... return Model(rngs)
766766
...
767767
>>> model = create_model(rngs)
768-
>>> model.dropout.rngs.key.shape
768+
>>> model.dropout.rngs.key_.shape
769769
()
770770
771771
``split_rngs`` returns a SplitBackups object that can be used to restore the
@@ -778,7 +778,7 @@ def split_rngs(
778778
>>> model = create_model(rngs)
779779
>>> nnx.restore_rngs(backups)
780780
...
781-
>>> model.dropout.rngs.key.shape
781+
>>> model.dropout.rngs.key_.shape
782782
()
783783
784784
SplitBackups can also be used as a context manager to automatically restore
@@ -789,7 +789,7 @@ def split_rngs(
789789
>>> with nnx.split_rngs(rngs, splits=5, only='params'):
790790
... model = create_model(rngs)
791791
...
792-
>>> model.dropout.rngs.key.shape
792+
>>> model.dropout.rngs.key_.shape
793793
()
794794
795795
>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
@@ -801,7 +801,7 @@ def split_rngs(
801801
...
802802
>>> rngs = nnx.Rngs(params=0, dropout=1)
803803
>>> model = create_model(rngs)
804-
>>> model.dropout.rngs.key.shape
804+
>>> model.dropout.rngs.key_.shape
805805
()
806806
807807
@@ -828,18 +828,18 @@ def split_rngs_wrapper(*args, **kwargs):
828828
for path, stream in graph.iter_graph(node):
829829
if (
830830
isinstance(stream, RngStream)
831-
and predicate((*path, 'key'), stream.key)
831+
and predicate((*path, 'key'), stream.key_)
832832
and predicate((*path, 'count'), stream.count)
833833
):
834834
key = stream()
835-
backups.append((stream, stream.key.raw_value, stream.count.raw_value))
835+
backups.append((stream, stream.key_.raw_value, stream.count.raw_value))
836836
key = random.split(key, splits)
837837
if squeeze:
838838
key = key[0]
839-
if variablelib.is_array_ref(stream.key.raw_value):
840-
stream.key.raw_value = variablelib.new_ref(key) # type: ignore[assignment]
839+
if variablelib.is_array_ref(stream.key_.raw_value):
840+
stream.key_.raw_value = variablelib.new_ref(key) # type: ignore[assignment]
841841
else:
842-
stream.key.value = key
842+
stream.key_.value = key
843843
if squeeze:
844844
counts_shape = stream.count.shape
845845
elif isinstance(splits, int):
@@ -898,18 +898,18 @@ def fork_rngs(
898898
...
899899
>>> rngs = nnx.Rngs(params=0, dropout=1)
900900
>>> _ = nnx.fork_rngs(rngs, split=5)
901-
>>> rngs.params.key.shape, rngs.dropout.key.shape
901+
>>> rngs.params.key_.shape, rngs.dropout.key_.shape
902902
((5,), (5,))
903903
904904
>>> rngs = nnx.Rngs(params=0, dropout=1)
905905
>>> _ = nnx.fork_rngs(rngs, split=(2, 5))
906-
>>> rngs.params.key.shape, rngs.dropout.key.shape
906+
>>> rngs.params.key_.shape, rngs.dropout.key_.shape
907907
((2, 5), (2, 5))
908908
909909
910910
>>> rngs = nnx.Rngs(params=0, dropout=1)
911911
>>> _ = nnx.fork_rngs(rngs, split={'params': 5})
912-
>>> rngs.params.key.shape, rngs.dropout.key.shape
912+
>>> rngs.params.key_.shape, rngs.dropout.key_.shape
913913
((5,), ())
914914
915915
Once forked, random state can be used with transforms like :func:`nnx.vmap`::
@@ -929,7 +929,7 @@ def fork_rngs(
929929
... return Model(rngs)
930930
...
931931
>>> model = create_model(rngs)
932-
>>> model.dropout.rngs.key.shape
932+
>>> model.dropout.rngs.key_.shape
933933
()
934934
935935
``fork_rngs`` returns a SplitBackups object that can be used to restore the
@@ -942,7 +942,7 @@ def fork_rngs(
942942
>>> model = create_model(rngs)
943943
>>> nnx.restore_rngs(backups)
944944
...
945-
>>> model.dropout.rngs.key.shape
945+
>>> model.dropout.rngs.key_.shape
946946
()
947947
948948
SplitBackups can also be used as a context manager to automatically restore
@@ -953,7 +953,7 @@ def fork_rngs(
953953
>>> with nnx.fork_rngs(rngs, split={'params': 5}):
954954
... model = create_model(rngs)
955955
...
956-
>>> model.dropout.rngs.key.shape
956+
>>> model.dropout.rngs.key_.shape
957957
()
958958
959959
>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
@@ -965,7 +965,7 @@ def fork_rngs(
965965
...
966966
>>> rngs = nnx.Rngs(params=0, dropout=1)
967967
>>> model = create_model(rngs)
968-
>>> model.dropout.rngs.key.shape
968+
>>> model.dropout.rngs.key_.shape
969969
()
970970
"""
971971
if isinstance(node, Missing):
@@ -993,14 +993,14 @@ def fork_rngs_wrapper(*args, **kwargs):
993993
for predicate, splits in predicate_splits.items():
994994
if (
995995
isinstance(stream, RngStream)
996-
and predicate((*path, 'key'), stream.key)
996+
and predicate((*path, 'key'), stream.key_)
997997
and predicate((*path, 'count'), stream.count)
998998
):
999999
forked_stream = stream.fork(split=splits)
10001000
# backup the original stream state
1001-
backups.append((stream, stream.key.raw_value, stream.count.raw_value))
1001+
backups.append((stream, stream.key_.raw_value, stream.count.raw_value))
10021002
# apply the forked key and count to the original stream
1003-
stream.key.raw_value = forked_stream.key.raw_value
1003+
stream.key_.raw_value = forked_stream.key_.raw_value
10041004
stream.count.raw_value = forked_stream.count.raw_value
10051005

10061006
return SplitBackups(backups)
@@ -1010,7 +1010,7 @@ def backup_keys(node: tp.Any, /):
10101010
backups: list[StreamBackup] = []
10111011
for _, stream in graph.iter_graph(node):
10121012
if isinstance(stream, RngStream):
1013-
backups.append((stream, stream.key.raw_value))
1013+
backups.append((stream, stream.key_.raw_value))
10141014
return backups
10151015

10161016
def _scalars_only(
@@ -1055,7 +1055,7 @@ def reseed(
10551055
of the form ``(path, scalar_key, target_shape) -> new_key`` can be passed to
10561056
define a custom reseeding policy.
10571057
**stream_keys: a mapping of stream names to new keys. The keys can be
1058-
either integers or ``jax.random.key``.
1058+
either integers or ``jax.random.key_``.
10591059
10601060
Example::
10611061
@@ -1093,16 +1093,16 @@ def reseed(
10931093
rngs = Rngs(**stream_keys)
10941094
for path, stream in graph.iter_graph(node):
10951095
if isinstance(stream, RngStream):
1096-
if stream.key.tag in stream_keys:
1097-
key = rngs[stream.key.tag]()
1098-
key = policy(path, key, stream.key.shape)
1099-
stream.key.value = key
1096+
if stream.key_.tag in stream_keys:
1097+
key = rngs[stream.key_.tag]()
1098+
key = policy(path, key, stream.key_.shape)
1099+
stream.key_.value = key
11001100
stream.count.value = jnp.zeros(key.shape, dtype=jnp.uint32)
11011101

11021102

11031103
def restore_rngs(backups: tp.Iterable[StreamBackup], /):
11041104
for backup in backups:
11051105
stream = backup[0]
1106-
stream.key.raw_value = backup[1]
1106+
stream.key_.raw_value = backup[1]
11071107
if len(backup) == 3:
11081108
stream.count.raw_value = backup[2] # count

tests/nnx/bridge/module_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def __call__(self):
149149
scope = bar.apply({}, rngs=1)
150150
self.assertIsNone(bar.scope)
151151

152-
self.assertEqual(scope.rngs.default.key[...], jax.random.key(1))
152+
self.assertEqual(scope.rngs.default.key_[...], jax.random.key(1))
153153
self.assertEqual(scope.rngs.default.count[...], 0)
154154

155155
class Baz(bridge.Module):
@@ -514,4 +514,3 @@ def __call__(self, x):
514514

515515
if __name__ == '__main__':
516516
absltest.main()
517-

tests/nnx/module_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ def test_create_abstract(self):
556556
def test_create_abstract_stateful(self):
557557
linear = nnx.eval_shape(lambda: nnx.Dropout(0.5, rngs=nnx.Rngs(0)))
558558

559-
assert linear.rngs.key.value == jax.ShapeDtypeStruct(
559+
assert linear.rngs.key_.value == jax.ShapeDtypeStruct(
560560
(), jax.random.key(0).dtype
561561
)
562562

tests/nnx/mutable_array_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ def test_rngs_create(self):
662662
paths[1],
663663
(
664664
jax.tree_util.GetAttrKey('default'),
665-
jax.tree_util.GetAttrKey('key'),
665+
jax.tree_util.GetAttrKey('key_'),
666666
jax.tree_util.GetAttrKey('value'),
667667
),
668668
)

tests/nnx/nn/attention_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def test_keep_rngs(self, keep_rngs):
128128
if keep_rngs:
129129
_, _, nondiff = nnx.split(module, nnx.Param, ...)
130130
assert isinstance(nondiff['rngs']['count'], nnx.RngCount)
131-
assert isinstance(nondiff['rngs']['key'], nnx.RngKey)
131+
assert isinstance(nondiff['rngs']['key_'], nnx.RngKey)
132132
else:
133133
nnx.split(module, nnx.Param)
134134

tests/nnx/rngs_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ def test_rng_stream(self):
4545

4646
key1 = rngs.params()
4747
self.assertEqual(rngs.params.count[...], 1)
48-
self.assertIs(rngs.params.key[...], key0)
48+
self.assertIs(rngs.params.key_[...], key0)
4949
self.assertFalse(jnp.allclose(key0, key1))
5050

5151
key2 = rngs.params()
5252
self.assertEqual(rngs.params.count[...], 2)
53-
self.assertIs(rngs.params.key[...], key0)
53+
self.assertIs(rngs.params.key_[...], key0)
5454
self.assertFalse(jnp.allclose(key1, key2))
5555

5656
def test_rng_trace_level_constraints(self):

0 commit comments

Comments
 (0)