2020import threading
2121import typing as tp
2222
23- import jax .experimental
23+ import jax .core
2424
2525from flax import config
2626from flax .nnx import filterlib , reprlib , traversals , variablelib
@@ -373,17 +373,13 @@ class VariableDef(reprlib.Representable, tp.Generic[Node]):
373373 index : int
374374 outer_index : int | None
375375 metadata : HashableMapping [str , tp .Any ]
376- array_refdef : ArrayRefDef | NodeRef | None
377376
378377 def with_no_outer_index (self ) -> VariableDef :
379378 return VariableDef (
380379 type = self .type ,
381380 index = self .index ,
382381 outer_index = None ,
383382 metadata = self .metadata ,
384- array_refdef = self .array_refdef .with_no_outer_index ()
385- if isinstance (self .array_refdef , ArrayRefDef )
386- else self .array_refdef ,
387383 )
388384
389385 def with_same_outer_index (self ) -> VariableDef :
@@ -392,9 +388,6 @@ def with_same_outer_index(self) -> VariableDef:
392388 index = self .index ,
393389 outer_index = self .index ,
394390 metadata = self .metadata ,
395- array_refdef = self .array_refdef .with_same_outer_index ()
396- if isinstance (self .array_refdef , ArrayRefDef )
397- else self .array_refdef ,
398391 )
399392
400393 def __nnx_repr__ (self ):
@@ -761,32 +754,23 @@ def make_mutable_arraydef(value: variablelib.Ref):
761754 if is_variable :
762755 assert isinstance (node , Variable )
763756 assert index is not None
764- prev_inner_value = node .raw_value
765- if variablelib .is_array_ref (prev_inner_value ):
766- array_refdef , inner_value = make_mutable_arraydef (prev_inner_value )
767- else :
768- array_refdef = None
769- inner_value = prev_inner_value
770757 if path is None :
771- leaf = inner_value
758+ leaf = node . raw_value
772759 else :
773760 leaf = node # type: ignore[assignment]
774- if inner_value is not prev_inner_value :
775- leaf .raw_value = inner_value
776761
777762 variabledef = VariableDef (
778- type = type (node ),
763+ type = jax .typeof (node )._var_type # type: ignore
764+ if isinstance (node , jax .core .Tracer )
765+ else type (node ),
779766 index = index ,
780767 outer_index = ref_outer_index .get (node , None ) if ref_outer_index else None ,
781768 metadata = HashableMapping (node .get_metadata ()),
782- array_refdef = array_refdef ,
783769 )
784- if type (inner_value ) is not Repeated :
785- assert not isinstance (leaf , Repeated )
786- leaves .append (leaf )
787- if path is not None :
788- assert paths is not None
789- paths .append (tuple (path ))
770+ leaves .append (leaf )
771+ if path is not None :
772+ assert paths is not None
773+ paths .append (tuple (path ))
790774 nodes .append (variabledef )
791775 return
792776 elif is_array_ref :
@@ -1200,7 +1184,7 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf):
12001184 f"Expected a ArrayRefOutput type but got '{ leaf .value } .'"
12011185 )
12021186 elif type (leaf ) is ArrayRefOutput :
1203- array_ref = variablelib .new_ref (leaf .value )
1187+ array_ref = jax .new_ref (leaf .value )
12041188 elif variablelib .is_array_ref (leaf ):
12051189 array_ref = leaf
12061190 else :
@@ -1217,26 +1201,9 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf):
12171201 variabledef = tp .cast (VariableDef [Variable ], nodedef )
12181202 # its a unseen variable, create a new one
12191203
1220- if variabledef .array_refdef is not None :
1221- if type (variabledef .array_refdef ) is NodeRef :
1222- value = index_ref [variabledef .array_refdef .index ]
1223- else :
1224- value = next (leaves_iter )
1225- assert type (variabledef .array_refdef ) is ArrayRefDef
1226- if isinstance (value , Variable ):
1227- value = value .copy () if copy_variables else value
1228- inner_value = value .raw_value
1229- array_ref = get_mutable_array (variabledef .array_refdef , inner_value )
1230- if array_ref is not inner_value :
1231- value .raw_value = array_ref
1232- else :
1233- # if value is an array or array ref, we need call get_mutable_array
1234- # to register it in the index_ref
1235- value = get_mutable_array (variabledef .array_refdef , value )
1236- else :
1237- value = next (leaves_iter )
1238- if isinstance (value , Variable ) and copy_variables :
1239- value = value .copy ()
1204+ value = next (leaves_iter )
1205+ if isinstance (value , Variable ) and copy_variables :
1206+ value = value .copy ()
12401207
12411208 # when idxmap is present, check if the Varable exists there
12421209 # and update existing variables if it does
@@ -1442,7 +1409,7 @@ def _update_variable(node: Variable, value):
14421409 ):
14431410 node [...] = value [...]
14441411 else :
1445- node .raw_value = value
1412+ node .set_raw_value ( value )
14461413
14471414 if isinstance (node , Variable ):
14481415 _update_variable (node , state )
@@ -2616,7 +2583,7 @@ def clone(node: Node, variables: bool = True) -> Node:
26162583
26172584
26182585def _mutable_like (path , x ):
2619- return ( isinstance ( x , Variable ) and x . has_ref ) or variablelib .is_array_ref (x )
2586+ return variablelib .is_array_ref (x )
26202587
26212588
26222589def to_arrays (
@@ -2669,7 +2636,9 @@ def to_arrays(
26692636 Returns:
26702637 A structure with the frozen arrays.
26712638 """
2672- if not allow_duplicates and (all_duplicates := find_duplicates (node , only = only )):
2639+ if not allow_duplicates and (
2640+ all_duplicates := find_duplicates (node , only = only )
2641+ ):
26732642 duplicates_strs = '\n ---'
26742643 for node_duplicates in all_duplicates :
26752644 for path in node_duplicates :
@@ -2685,7 +2654,7 @@ def to_arrays(
26852654
26862655
26872656def _array_like (path , x ):
2688- return ( isinstance ( x , Variable ) and not x . has_ref ) or isinstance (x , jax .Array )
2657+ return isinstance (x , jax .Array )
26892658
26902659
26912660def to_refs (node : A , / , only : filterlib .Filter = _array_like ) -> A :
@@ -2741,13 +2710,13 @@ def to_refs(node: A, /, only: filterlib.Filter = _array_like) -> A:
27412710 raise ValueError (f'Found duplicate at paths:{ duplicates_strs } ' )
27422711
27432712 graphdef , frozen_state , rest = split (node , only , ...) # type: ignore[misc]
2744- mutable_state = jax .tree .map (variablelib .new_ref , frozen_state )
2713+ mutable_state = jax .tree .map (jax .new_ref , frozen_state )
27452714 node = merge (graphdef , mutable_state , rest )
27462715 return node
27472716
27482717def _is_lojax_variable (path , x ):
27492718 return isinstance (x , variablelib .Variable ) and not isinstance (
2750- x , variablelib .MutableHijaxVariable
2719+ x , variablelib .HijaxVariable
27512720 )
27522721
27532722
@@ -2789,7 +2758,7 @@ def _to_stateful(x):
27892758
27902759
27912760def _is_hijax_variable (path , x ):
2792- return isinstance (x , variablelib .MutableHijaxVariable )
2761+ return isinstance (x , variablelib .HijaxVariable )
27932762
27942763
27952764def to_lojax (node : A , / , only : filterlib .Filter = ...) -> A :
@@ -2808,7 +2777,7 @@ def to_lojax(node: A, /, only: filterlib.Filter = ...) -> A:
28082777 def _to_stateless (x ):
28092778 if variablelib .is_array_ref (x ):
28102779 return x [...]
2811- elif isinstance (x , variablelib .MutableHijaxVariable ):
2780+ elif isinstance (x , variablelib .HijaxVariable ):
28122781 return variablelib ._get_mutable_hijax_state (x )
28132782 return x
28142783
@@ -2820,6 +2789,75 @@ def _to_stateless(x):
28202789 return node
28212790
28222791
2792+ def _is_lojax_variable (path , x ):
2793+ return isinstance (x , variablelib .Variable ) and not isinstance (
2794+ x , variablelib .HijaxVariable
2795+ )
2796+
2797+
2798+ def to_hijax (
2799+ node : A , / , * , only : filterlib .Filter = ..., mutable : bool = True
2800+ ) -> A :
2801+ """ """
2802+ if not mutable :
2803+ raise ValueError ('to_hijax only supports mutable=True at the moment.' )
2804+
2805+ only = filterlib .All (_is_lojax_variable , only )
2806+ predicate = filterlib .to_predicate (only )
2807+
2808+ if all_duplicates := find_duplicates (node , only = only ):
2809+ duplicates_strs = '\n ---'
2810+ for node_duplicates in all_duplicates :
2811+ for path in node_duplicates :
2812+ path_str = '/' .join (map (str , path ))
2813+ duplicates_strs += f'\n { path_str } '
2814+ duplicates_strs += '\n ---'
2815+ raise ValueError (f'Found duplicate at paths:{ duplicates_strs } ' )
2816+
2817+ def _to_hijax (jax_path , x ):
2818+ if predicate (to_nnx_path (jax_path ), x ):
2819+ assert isinstance (x , variablelib .Variable )
2820+ x = x .copy ()
2821+ x ._var_metadata ['is_hijax' ] = True
2822+ return variablelib ._new_mutable_hijax_from_variable (x )
2823+ return x
2824+
2825+ node = jax .tree .map_with_path (
2826+ _to_hijax , node , is_leaf = lambda x : isinstance (x , variablelib .Variable )
2827+ )
2828+ return node
2829+
2830+
2831+ def _is_hijax_variable (path , x ):
2832+ return isinstance (x , variablelib .HijaxVariable )
2833+
2834+ def to_lojax (node : A , / , only : filterlib .Filter = ...) -> A :
2835+ """ """
2836+ only = filterlib .All (_is_hijax_variable , only )
2837+ predicate = filterlib .to_predicate (only )
2838+
2839+ if all_duplicates := find_duplicates (node , only = only ):
2840+ duplicates_strs = '\n ---'
2841+ for node_duplicates in all_duplicates :
2842+ for path in node_duplicates :
2843+ path_str = '/' .join (map (str , path ))
2844+ duplicates_strs += f'\n { path_str } '
2845+ duplicates_strs += '\n ---'
2846+ raise ValueError (f'Found duplicate at paths:{ duplicates_strs } ' )
2847+
2848+ def _to_lojax (jax_path , x ):
2849+ if predicate (to_nnx_path (jax_path ), x ):
2850+ variable = variablelib ._get_mutable_hijax_state (x )
2851+ variable ._var_metadata ['is_hijax' ] = False
2852+ return variable
2853+ return x
2854+
2855+ node = jax .tree .map_with_path (
2856+ _to_lojax , node , is_leaf = lambda x : isinstance (x , variablelib .Variable )
2857+ )
2858+ return node
2859+
2860+
28232861def pure (tree : A ) -> A :
28242862 """Returns a new tree with all ``Variable`` objects replaced with inner values.
28252863
@@ -3144,7 +3182,7 @@ def _key_path_to_key(key: tp.Any) -> Key:
31443182 return str (key )
31453183
31463184
3147- def jax_to_nnx_path (jax_path : tuple , / ):
3185+ def to_nnx_path (jax_path : tuple , / ):
31483186 return tuple (_key_path_to_key (part ) for part in jax_path )
31493187
31503188
0 commit comments