Skip to content

Commit b5791b4

Browse files
committed
hijax Variable
1 parent e8e6572 commit b5791b4

File tree

8 files changed

+766
-280
lines changed

8 files changed

+766
-280
lines changed

flax/nnx/extract.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import jax
1919

2020
from flax import struct
21+
from flax import typing
2122
from flax.nnx.pytreelib import Pytree
2223
from flax.typing import Missing, PathParts
2324
from flax.nnx import graph, variablelib
@@ -35,7 +36,7 @@ class PrefixMapping(abc.ABC):
3536
@abc.abstractmethod
3637
def map_prefix(
3738
self,
38-
path: variablelib.PathParts,
39+
path: typing.PathParts,
3940
variable: variablelib.Variable,
4041
/,
4142
) -> tp.Any: ...

flax/nnx/rnglib.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@ def __init__(
7070
self.count = RngCount(count, tag=tag)
7171

7272
def __call__(self) -> jax.Array:
73-
if not self.count.has_ref and not self.count._trace_state.is_valid():
73+
if (
74+
not self.count.has_ref
75+
and self.count.trace_state
76+
and not self.count.trace_state.is_valid()
77+
):
7478
raise errors.TraceContextError(
7579
f'Cannot mutate {type(self).__name__} from a different trace level'
7680
)

flax/nnx/transforms/iteration.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import typing as tp
2020

2121
from flax import struct
22+
from flax import typing
2223
from flax.core.frozen_dict import FrozenDict
2324
from flax.nnx import extract, filterlib, graph, spmd, variablelib
2425
from flax.nnx import statelib
@@ -89,7 +90,7 @@ def axes(self) -> tuple[Index | type[Carry] | None, ...]:
8990
return self._axes
9091

9192
def map_prefix(
92-
self, path: variablelib.PathParts, variable: variablelib.Variable
93+
self, path: typing.PathParts, variable: variablelib.Variable
9394
) -> tp.Any:
9495
for filter, axis in zip(self.filters, self.axes):
9596
predicate = filterlib.to_predicate(filter)

0 commit comments

Comments
 (0)