Skip to content

Commit b200679

Browse files
committed
Merge branch 'remove-variable-ref' into hijax-variable
2 parents dff0ff6 + 579025b commit b200679

File tree

18 files changed

+455
-512
lines changed

18 files changed

+455
-512
lines changed

docs_nnx/guides/array_ref.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@
123123
}
124124
],
125125
"source": [
126-
"variable = nnx.Variable(jnp.array([1, 2, 3]), use_ref=True)\n",
127-
"print(f\"{variable.has_ref = }\\n\")\n",
126+
"variable = nnx.Variable(jnp.array([1, 2, 3]), is_hijax=True)\n",
127+
"print(f\"{variable.is_hijax = }\\n\")\n",
128128
"\n",
129129
"print(\"[1] =\", variable); increment(variable); print(\"[2] =\", variable)"
130130
]
@@ -147,7 +147,7 @@
147147
"with nnx.use_refs(True):\n",
148148
" variable = nnx.Variable(jnp.array([1, 2, 3]))\n",
149149
"\n",
150-
"print(f\"{variable.has_ref = }\")"
150+
"print(f\"{variable.is_hijax = }\")"
151151
]
152152
},
153153
{

docs_nnx/guides/array_ref.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ print(increment.lower(a_ref).as_text())
4545
### Variables Refs
4646

4747
```{code-cell} ipython3
48-
variable = nnx.Variable(jnp.array([1, 2, 3]), use_ref=True)
49-
print(f"{variable.has_ref = }\n")
48+
variable = nnx.Variable(jnp.array([1, 2, 3]), is_hijax=True)
49+
print(f"{variable.is_hijax = }\n")
5050
5151
print("[1] =", variable); increment(variable); print("[2] =", variable)
5252
```
@@ -55,7 +55,7 @@ print("[1] =", variable); increment(variable); print("[2] =", variable)
5555
with nnx.use_refs(True):
5656
variable = nnx.Variable(jnp.array([1, 2, 3]))
5757
58-
print(f"{variable.has_ref = }")
58+
print(f"{variable.is_hijax = }")
5959
```
6060

6161
Mention `nnx.use_refs` can be used as global flag

flax/nnx/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,6 @@
193193
from .variablelib import variable_type_from_name as variable_type_from_name
194194
from .variablelib import variable_name_from_type as variable_name_from_type
195195
from .variablelib import register_variable_name as register_variable_name
196-
from .variablelib import use_refs as use_refs
197-
from .variablelib import using_refs as using_refs
198196
from .variablelib import use_hijax as use_hijax
199197
from .variablelib import using_hijax as using_hijax
200198
from .visualization import display as display

flax/nnx/bridge/module.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -388,10 +388,9 @@ def _get_variables(self) -> tp.Mapping:
388388
if collection not in _variables:
389389
_variables[collection] = {}
390390

391-
if (
392-
isinstance(variable, variablelib.Variable)
393-
and not variable.get_metadata()
394-
):
391+
if isinstance(
392+
variable, variablelib.Variable
393+
) and bridge_variables.is_vanilla_variable(variable):
395394
leaf = variable.value
396395
else:
397396
leaf = bridge_variables.to_linen_var(variable)

flax/nnx/bridge/variables.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,11 @@ def is_vanilla_variable(vs: variablelib.Variable) -> bool:
7878
Returns False only if it has non-empty hooks or any non-built-in attribute.
7979
"""
8080
for key, value in vs.get_metadata().items():
81-
if key.endswith('_hooks'):
82-
if value != ():
83-
return False
84-
else:
85-
return False
81+
if key in ('is_hijax', 'eager_sharding'):
82+
continue
83+
if key.endswith('_hooks') and value == ():
84+
continue
85+
return False
8686
return True
8787

8888

flax/nnx/graph.py

Lines changed: 93 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import threading
2121
import typing as tp
2222

23-
import jax.experimental
23+
import jax.core
2424

2525
from flax import config
2626
from flax.nnx import filterlib, reprlib, traversals, variablelib
@@ -373,17 +373,13 @@ class VariableDef(reprlib.Representable, tp.Generic[Node]):
373373
index: int
374374
outer_index: int | None
375375
metadata: HashableMapping[str, tp.Any]
376-
array_refdef: ArrayRefDef | NodeRef | None
377376

378377
def with_no_outer_index(self) -> VariableDef:
379378
return VariableDef(
380379
type=self.type,
381380
index=self.index,
382381
outer_index=None,
383382
metadata=self.metadata,
384-
array_refdef=self.array_refdef.with_no_outer_index()
385-
if isinstance(self.array_refdef, ArrayRefDef)
386-
else self.array_refdef,
387383
)
388384

389385
def with_same_outer_index(self) -> VariableDef:
@@ -392,9 +388,6 @@ def with_same_outer_index(self) -> VariableDef:
392388
index=self.index,
393389
outer_index=self.index,
394390
metadata=self.metadata,
395-
array_refdef=self.array_refdef.with_same_outer_index()
396-
if isinstance(self.array_refdef, ArrayRefDef)
397-
else self.array_refdef,
398391
)
399392

400393
def __nnx_repr__(self):
@@ -761,32 +754,23 @@ def make_mutable_arraydef(value: variablelib.Ref):
761754
if is_variable:
762755
assert isinstance(node, Variable)
763756
assert index is not None
764-
prev_inner_value = node.raw_value
765-
if variablelib.is_array_ref(prev_inner_value):
766-
array_refdef, inner_value = make_mutable_arraydef(prev_inner_value)
767-
else:
768-
array_refdef = None
769-
inner_value = prev_inner_value
770757
if path is None:
771-
leaf = inner_value
758+
leaf = node.raw_value
772759
else:
773760
leaf = node # type: ignore[assignment]
774-
if inner_value is not prev_inner_value:
775-
leaf.raw_value = inner_value
776761

777762
variabledef = VariableDef(
778-
type=type(node),
763+
type=jax.typeof(node)._var_type # type: ignore
764+
if isinstance(node, jax.core.Tracer)
765+
else type(node),
779766
index=index,
780767
outer_index=ref_outer_index.get(node, None) if ref_outer_index else None,
781768
metadata=HashableMapping(node.get_metadata()),
782-
array_refdef=array_refdef,
783769
)
784-
if type(inner_value) is not Repeated:
785-
assert not isinstance(leaf, Repeated)
786-
leaves.append(leaf)
787-
if path is not None:
788-
assert paths is not None
789-
paths.append(tuple(path))
770+
leaves.append(leaf)
771+
if path is not None:
772+
assert paths is not None
773+
paths.append(tuple(path))
790774
nodes.append(variabledef)
791775
return
792776
elif is_array_ref:
@@ -1200,7 +1184,7 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf):
12001184
f"Expected a ArrayRefOutput type but got '{leaf.value}.'"
12011185
)
12021186
elif type(leaf) is ArrayRefOutput:
1203-
array_ref = variablelib.new_ref(leaf.value)
1187+
array_ref = jax.new_ref(leaf.value)
12041188
elif variablelib.is_array_ref(leaf):
12051189
array_ref = leaf
12061190
else:
@@ -1217,26 +1201,9 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf):
12171201
variabledef = tp.cast(VariableDef[Variable], nodedef)
12181202
# its a unseen variable, create a new one
12191203

1220-
if variabledef.array_refdef is not None:
1221-
if type(variabledef.array_refdef) is NodeRef:
1222-
value = index_ref[variabledef.array_refdef.index]
1223-
else:
1224-
value = next(leaves_iter)
1225-
assert type(variabledef.array_refdef) is ArrayRefDef
1226-
if isinstance(value, Variable):
1227-
value = value.copy() if copy_variables else value
1228-
inner_value = value.raw_value
1229-
array_ref = get_mutable_array(variabledef.array_refdef, inner_value)
1230-
if array_ref is not inner_value:
1231-
value.raw_value = array_ref
1232-
else:
1233-
# if value is an array or array ref, we need call get_mutable_array
1234-
# to register it in the index_ref
1235-
value = get_mutable_array(variabledef.array_refdef, value)
1236-
else:
1237-
value = next(leaves_iter)
1238-
if isinstance(value, Variable) and copy_variables:
1239-
value = value.copy()
1204+
value = next(leaves_iter)
1205+
if isinstance(value, Variable) and copy_variables:
1206+
value = value.copy()
12401207

12411208
# when idxmap is present, check if the Varable exists there
12421209
# and update existing variables if it does
@@ -1442,7 +1409,7 @@ def _update_variable(node: Variable, value):
14421409
):
14431410
node[...] = value[...]
14441411
else:
1445-
node.raw_value = value
1412+
node.set_raw_value(value)
14461413

14471414
if isinstance(node, Variable):
14481415
_update_variable(node, state)
@@ -2616,7 +2583,7 @@ def clone(node: Node, variables: bool = True) -> Node:
26162583

26172584

26182585
def _mutable_like(path, x):
2619-
return (isinstance(x, Variable) and x.has_ref) or variablelib.is_array_ref(x)
2586+
return variablelib.is_array_ref(x)
26202587

26212588

26222589
def to_arrays(
@@ -2669,7 +2636,9 @@ def to_arrays(
26692636
Returns:
26702637
A structure with the frozen arrays.
26712638
"""
2672-
if not allow_duplicates and (all_duplicates := find_duplicates(node, only=only)):
2639+
if not allow_duplicates and (
2640+
all_duplicates := find_duplicates(node, only=only)
2641+
):
26732642
duplicates_strs = '\n ---'
26742643
for node_duplicates in all_duplicates:
26752644
for path in node_duplicates:
@@ -2685,7 +2654,7 @@ def to_arrays(
26852654

26862655

26872656
def _array_like(path, x):
2688-
return (isinstance(x, Variable) and not x.has_ref) or isinstance(x, jax.Array)
2657+
return isinstance(x, jax.Array)
26892658

26902659

26912660
def to_refs(node: A, /, only: filterlib.Filter = _array_like) -> A:
@@ -2741,13 +2710,13 @@ def to_refs(node: A, /, only: filterlib.Filter = _array_like) -> A:
27412710
raise ValueError(f'Found duplicate at paths:{duplicates_strs}')
27422711

27432712
graphdef, frozen_state, rest = split(node, only, ...) # type: ignore[misc]
2744-
mutable_state = jax.tree.map(variablelib.new_ref, frozen_state)
2713+
mutable_state = jax.tree.map(jax.new_ref, frozen_state)
27452714
node = merge(graphdef, mutable_state, rest)
27462715
return node
27472716

27482717
def _is_lojax_variable(path, x):
27492718
return isinstance(x, variablelib.Variable) and not isinstance(
2750-
x, variablelib.MutableHijaxVariable
2719+
x, variablelib.HijaxVariable
27512720
)
27522721

27532722

@@ -2789,7 +2758,7 @@ def _to_stateful(x):
27892758

27902759

27912760
def _is_hijax_variable(path, x):
2792-
return isinstance(x, variablelib.MutableHijaxVariable)
2761+
return isinstance(x, variablelib.HijaxVariable)
27932762

27942763

27952764
def to_lojax(node: A, /, only: filterlib.Filter = ...) -> A:
@@ -2808,7 +2777,7 @@ def to_lojax(node: A, /, only: filterlib.Filter = ...) -> A:
28082777
def _to_stateless(x):
28092778
if variablelib.is_array_ref(x):
28102779
return x[...]
2811-
elif isinstance(x, variablelib.MutableHijaxVariable):
2780+
elif isinstance(x, variablelib.HijaxVariable):
28122781
return variablelib._get_mutable_hijax_state(x)
28132782
return x
28142783

@@ -2820,6 +2789,75 @@ def _to_stateless(x):
28202789
return node
28212790

28222791

2792+
def _is_lojax_variable(path, x):
2793+
return isinstance(x, variablelib.Variable) and not isinstance(
2794+
x, variablelib.HijaxVariable
2795+
)
2796+
2797+
2798+
def to_hijax(
2799+
node: A, /, *, only: filterlib.Filter = ..., mutable: bool = True
2800+
) -> A:
2801+
""" """
2802+
if not mutable:
2803+
raise ValueError('to_hijax only supports mutable=True at the moment.')
2804+
2805+
only = filterlib.All(_is_lojax_variable, only)
2806+
predicate = filterlib.to_predicate(only)
2807+
2808+
if all_duplicates := find_duplicates(node, only=only):
2809+
duplicates_strs = '\n ---'
2810+
for node_duplicates in all_duplicates:
2811+
for path in node_duplicates:
2812+
path_str = '/'.join(map(str, path))
2813+
duplicates_strs += f'\n {path_str}'
2814+
duplicates_strs += '\n ---'
2815+
raise ValueError(f'Found duplicate at paths:{duplicates_strs}')
2816+
2817+
def _to_hijax(jax_path, x):
2818+
if predicate(to_nnx_path(jax_path), x):
2819+
assert isinstance(x, variablelib.Variable)
2820+
x = x.copy()
2821+
x._var_metadata['is_hijax'] = True
2822+
return variablelib._new_mutable_hijax_from_variable(x)
2823+
return x
2824+
2825+
node = jax.tree.map_with_path(
2826+
_to_hijax, node, is_leaf=lambda x: isinstance(x, variablelib.Variable)
2827+
)
2828+
return node
2829+
2830+
2831+
def _is_hijax_variable(path, x):
2832+
return isinstance(x, variablelib.HijaxVariable)
2833+
2834+
def to_lojax(node: A, /, only: filterlib.Filter = ...) -> A:
2835+
""" """
2836+
only = filterlib.All(_is_hijax_variable, only)
2837+
predicate = filterlib.to_predicate(only)
2838+
2839+
if all_duplicates := find_duplicates(node, only=only):
2840+
duplicates_strs = '\n ---'
2841+
for node_duplicates in all_duplicates:
2842+
for path in node_duplicates:
2843+
path_str = '/'.join(map(str, path))
2844+
duplicates_strs += f'\n {path_str}'
2845+
duplicates_strs += '\n ---'
2846+
raise ValueError(f'Found duplicate at paths:{duplicates_strs}')
2847+
2848+
def _to_lojax(jax_path, x):
2849+
if predicate(to_nnx_path(jax_path), x):
2850+
variable = variablelib._get_mutable_hijax_state(x)
2851+
variable._var_metadata['is_hijax'] = False
2852+
return variable
2853+
return x
2854+
2855+
node = jax.tree.map_with_path(
2856+
_to_lojax, node, is_leaf=lambda x: isinstance(x, variablelib.Variable)
2857+
)
2858+
return node
2859+
2860+
28232861
def pure(tree: A) -> A:
28242862
"""Returns a new tree with all ``Variable`` objects replaced with inner values.
28252863
@@ -3144,7 +3182,7 @@ def _key_path_to_key(key: tp.Any) -> Key:
31443182
return str(key)
31453183

31463184

3147-
def jax_to_nnx_path(jax_path: tuple, /):
3185+
def to_nnx_path(jax_path: tuple, /):
31483186
return tuple(_key_path_to_key(part) for part in jax_path)
31493187

31503188

flax/nnx/nn/normalization.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def _compute_stats(
5454
use_mean: bool = True,
5555
use_fast_variance: bool = True,
5656
mask: tp.Optional[Array] = None,
57-
):
57+
) -> tuple[Array, Array]:
5858
"""Computes mean and variance statistics.
5959
6060
This implementation takes care of a few important details:
@@ -357,6 +357,8 @@ def __call__(
357357
feature_axes = _canonicalize_axes(x.ndim, self.axis)
358358
reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
359359

360+
mean: jax.Array
361+
var: jax.Array
360362
if use_running_average:
361363
mean, var = self.mean.value, self.var.value
362364
else:
@@ -370,7 +372,7 @@ def __call__(
370372
mask=mask,
371373
)
372374
# stop_gradient only for flax_array_ref
373-
if self.mean.has_ref or self.var.has_ref:
375+
if self.mean.is_hijax or self.var.is_hijax:
374376
stop_gradient = jax.lax.stop_gradient
375377
else:
376378
stop_gradient = lambda x: x

0 commit comments

Comments
 (0)