
Commit 0551fd4

deprecate Variable.value
1 parent 1522229 commit 0551fd4

30 files changed: +178 -122 lines
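Every change in this commit follows the same migration: reads of `variable.value` become `variable[...]` (or go through attributes such as `.shape` and `.dtype`, which forward to the underlying array), and writes of `variable.value = x` become `variable.set_value(x)`. A minimal before/after sketch of that pattern, using an illustrative `nnx.Param` rather than any module from this repo:

import jax.numpy as jnp
from flax import nnx

# Illustrative parameter; other nnx.Variable subclasses behave the same way here.
w = nnx.Param(jnp.ones((2, 3)))

# Deprecated style being removed in this commit:
#   y = x @ w.value
#   w.value = jnp.zeros((2, 3))

# Style used throughout the diff below:
x = jnp.ones((4, 2))
y = x @ w[...]                    # read the stored array
w.set_value(jnp.zeros((2, 3)))    # replace the stored array
print(y.shape, w.shape)           # attributes like .shape forward to the array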

docs_nnx/index.rst

Lines changed: 2 additions & 6 deletions
@@ -107,15 +107,11 @@ Basic usage
   model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
   optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

-  @nnx.jit  # automatic state management for JAX transforms
+  @nnx.jit  # automatic state propagation
   def train_step(model, optimizer, x, y):
-    def loss_fn(model):
-      y_pred = model(x)  # call methods directly
-      return ((y_pred - y) ** 2).mean()
-
+    loss_fn = lambda model: ((model(x) - y) ** 2).mean()
     loss, grads = nnx.value_and_grad(loss_fn)(model)
     optimizer.update(model, grads)  # in-place updates
-
     return loss
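For reference, a self-contained version of the updated snippet; the `Model` definition below is a stand-in for the one defined earlier on the docs page, and the data is dummy:

import jax.numpy as jnp
import optax
from flax import nnx

class Model(nnx.Module):
  # Stand-in model, not the one from the docs page.
  def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    return self.linear2(nnx.relu(self.linear1(x)))

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

@nnx.jit  # automatic state propagation
def train_step(model, optimizer, x, y):
  loss_fn = lambda model: ((model(x) - y) ** 2).mean()
  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(model, grads)  # in-place updates
  return loss

x, y = jnp.ones((8, 2)), jnp.zeros((8, 3))
print(train_step(model, optimizer, x, y))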

examples/gemma/helpers.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ def assign_val_fn(
     mapped_path: tuple[str | int, ...],
     val: Any,
   ) -> dict[tuple[str, ...], Any]:
-    state[mapped_path].value = val
+    state[mapped_path].set_value(val)
     return state

   mdl: M = nnx.eval_shape(module_factory)

examples/gemma/helpers_test.py

Lines changed: 2 additions & 2 deletions
@@ -137,11 +137,11 @@ def _map_key_fn(key: tuple[str, ...]) -> tuple[str | int, ...]:
     np.testing.assert_array_equal(output, linen_output)
     for i in range(len(num_features)):
       np.testing.assert_array_equal(
-        mdl.layers[i].layers[0].mean.value,
+        mdl.layers[i].layers[0].mean[...],
         linen_vars['batch_stats'][f'layers_{i}']['layers_0']['mean'],
       )
       np.testing.assert_array_equal(
-        mdl.layers[i].layers[0].var.value,
+        mdl.layers[i].layers[0].var[...],
         linen_vars['batch_stats'][f'layers_{i}']['layers_0']['var'],
       )

examples/gemma/layers.py

Lines changed: 4 additions & 4 deletions
@@ -44,11 +44,11 @@ def __init__(
     self.w = nnx.Param(kernel_init(rngs.params(), shape, dtype))

   def __call__(self, x: ArrayLike) -> Array:
-    return jnp.einsum(self.einsum_str, x, self.w.value)
+    return jnp.einsum(self.einsum_str, x, self.w[...])

   @property
   def shape(self) -> Shape:
-    return self.w.value.shape
+    return self.w.shape


 class RMSNorm(nnx.Module):
@@ -65,12 +65,12 @@ def __init__(
     self.scale = nnx.Param(scale_init(rngs.params(), dim, dtype))

   def __call__(self, x: Array) -> Array:
-    dtype = self.scale.value.dtype
+    dtype = self.scale.dtype
     var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
     normed_inputs = jnp.asarray(x * jax.lax.rsqrt(var + 1e-06), dtype=dtype)
     # normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is
     # a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to
     # a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs.
-    scale = jnp.expand_dims(self.scale.value, axis=range(len(x.shape) - 1))
+    scale = jnp.expand_dims(self.scale, axis=range(len(x.shape) - 1))
     normed_inputs = normed_inputs * (1 + scale)
     return normed_inputs

examples/gemma/modules.py

Lines changed: 3 additions & 3 deletions
@@ -63,15 +63,15 @@ def encode(self, x: ArrayLike) -> Array:
     return x

   def decode(self, x: ArrayLike) -> Array:
-    return jnp.dot(x, self.input_embedding.value.T)
+    return jnp.dot(x, self.input_embedding.T)

   @property
   def embed_dim(self):
-    return self.input_embedding.value.shape[1]
+    return self.input_embedding.shape[1]

   @property
   def num_embed(self):
-    return self.input_embedding.value.shape[0]
+    return self.input_embedding.shape[0]


 class Attention(nnx.Module):

examples/gemma/sampler_test.py

Lines changed: 2 additions & 2 deletions
@@ -232,9 +232,9 @@ def test_forbidden_tokens(self):
       transformer_config, rngs=nnx.Rngs(params=0)
     )
     # Pre-cook the embedding matrix so that the output is deterministic.
-    transformer.embedder.input_embedding.value = jnp.eye(
+    transformer.embedder.input_embedding.set_value(jnp.eye(
       vocab.GetPieceSize(), 32
-    )
+    ))
     sampler = sampler_lib.Sampler(
       transformer=transformer,
       vocab=vocab,

examples/gemma/sow_lib.py

Lines changed: 4 additions & 6 deletions
@@ -49,13 +49,11 @@ def merge(self, decoding_step, layer: nnx.Module):
       if field.name.startswith('attn_'):
         step_value = getattr(
           layer.attn, field.name.replace('attn_', '')
-        ).value[0]
+        )[0]
       elif field.name.startswith('mlp_'):
-        step_value = getattr(layer.mlp, field.name.replace('mlp_', '')).value[
-          0
-        ]
+        step_value = getattr(layer.mlp, field.name.replace('mlp_', ''))[0]
       else:
-        step_value = getattr(layer, field.name).value[0]
+        step_value = getattr(layer, field.name)[0]
     except AttributeError as exc:
       raise ValueError(
         f'Intermediate {field.name} is not in the step intermediates.'
@@ -93,7 +91,7 @@ def merge(self, decoding_step, transformer: nnx.Module):
     if self.embeddings is not None:
       try:
         self.embeddings = self.embeddings.at[:, decoding_step + 1, ...].set(
-          transformer.embeddings.value[0][:, 0, ...]
+          transformer.embeddings[0][:, 0, ...]
         )
       except AttributeError as exc:
         raise ValueError(

examples/gemma/transformer.py

Lines changed: 3 additions & 3 deletions
@@ -487,10 +487,10 @@ def _assign_linen_params_to_nnx_state(
     if 'gate_proj' in mapped_path:
       if transpose_gating_einsum:
         val = jnp.swapaxes(val, 1, 2)
-      state[mapped_path].value = val[0]
-      state[mapped_path[:-2] + ('up_proj', 'kernel')].value = val[1]
+      state[mapped_path].set_value(val[0])
+      state[mapped_path[:-2] + ('up_proj', 'kernel')].set_value(val[1])
     else:
-      state[mapped_path].value = val
+      state[mapped_path].set_value(val)
   return state

examples/gemma/transformer_test.py

Lines changed: 7 additions & 7 deletions
@@ -461,7 +461,7 @@ def test_sow_intermediates(self, sow_config):

     if sow_config.embeddings:
       self.assertTrue(hasattr(transformer, 'embeddings'))
-      embeddings = transformer.embeddings.value[0]
+      embeddings = transformer.embeddings[0]
       self.assertEqual(
         embeddings.shape,
         (batch_size, sequence_length, config.embed_dim),
@@ -472,7 +472,7 @@ def test_sow_intermediates(self, sow_config):
     for layer in transformer.layers:
       if sow_config.rs_after_attention:
         self.assertTrue(hasattr(layer, 'rs_after_attention'))
-        rs_after_attention = layer.rs_after_attention.value[0]
+        rs_after_attention = layer.rs_after_attention[0]
         self.assertIsNotNone(rs_after_attention)
         self.assertEqual(
           rs_after_attention.shape,
@@ -482,7 +482,7 @@ def test_sow_intermediates(self, sow_config):
         self.assertFalse(hasattr(layer, 'rs_after_attention'))
       if sow_config.rs_after_ffw:
         self.assertTrue(hasattr(layer, 'rs_after_ffw'))
-        rs_after_ffw = layer.rs_after_ffw.value[0]
+        rs_after_ffw = layer.rs_after_ffw[0]
         self.assertIsNotNone(rs_after_ffw)
         self.assertEqual(
           rs_after_ffw.shape,
@@ -492,7 +492,7 @@ def test_sow_intermediates(self, sow_config):
         self.assertFalse(hasattr(layer, 'rs_after_ffw'))
       if sow_config.attn_logits_topk:
         self.assertTrue(hasattr(layer.attn, 'logits_topk_values'))
-        attn_logits_topk_values = layer.attn.logits_topk_values.value[0]
+        attn_logits_topk_values = layer.attn.logits_topk_values[0]
         self.assertIsNotNone(attn_logits_topk_values)
         self.assertEqual(
           attn_logits_topk_values.shape,
@@ -504,7 +504,7 @@ def test_sow_intermediates(self, sow_config):
          ),
        )
        self.assertTrue(hasattr(layer.attn, 'logits_topk_indices'))
-        attn_logits_topk_indices = layer.attn.logits_topk_indices.value[0]
+        attn_logits_topk_indices = layer.attn.logits_topk_indices[0]
        self.assertIsNotNone(attn_logits_topk_indices)
        self.assertEqual(
          attn_logits_topk_indices.shape,
@@ -520,7 +520,7 @@ def test_sow_intermediates(self, sow_config):
         self.assertFalse(hasattr(layer.attn, 'logits_topk_indices'))
       if sow_config.mlp_hidden_topk:
         self.assertTrue(hasattr(layer.mlp, 'hidden_topk_values'))
-        ffw_hidden_topk_values = layer.mlp.hidden_topk_values.value[0]
+        ffw_hidden_topk_values = layer.mlp.hidden_topk_values[0]
         self.assertIsNotNone(ffw_hidden_topk_values)
         self.assertEqual(
           ffw_hidden_topk_values.shape,
@@ -531,7 +531,7 @@ def test_sow_intermediates(self, sow_config):
          ),
        )
        self.assertTrue(hasattr(layer.mlp, 'hidden_topk_indices'))
-        ffw_hidden_topk_indices = layer.mlp.hidden_topk_indices.value[0]
+        ffw_hidden_topk_indices = layer.mlp.hidden_topk_indices[0]
        self.assertIsNotNone(ffw_hidden_topk_indices)
        self.assertEqual(
          ffw_hidden_topk_indices.shape,

flax/configurations.py

Lines changed: 45 additions & 1 deletion
@@ -201,6 +201,38 @@ def static_bool_env(varname: str, default: bool) -> bool:
   )


+def str_flag(name: str, *, default: str, help: str) -> FlagHolder[str]:
+  """Set up a string flag.
+
+  Example::
+
+    some_string = str_flag(
+        name='flax_some_string',
+        default='default_value',
+        help='Some string configuration.',
+    )
+
+  Now the ``FLAX_SOME_STRING`` shell environment variable can be used to
+  control the process-level value of the flag, in addition to using e.g.
+  ``config.update("flax_some_string", "new_value")`` directly.
+
+  Args:
+    name: converted to lowercase to define the name of the flag. It is
+      converted to uppercase to define the corresponding shell environment
+      variable.
+    default: a default value for the flag.
+    help: used to populate the docstring of the returned flag holder object.
+
+  Returns:
+    A flag holder object for accessing the value of the flag.
+  """
+  name = name.lower()
+  config._add_option(name, static_str_env(name.upper(), default))
+  fh = FlagHolder[str](name, help)
+  setattr(Config, name, property(lambda _: fh.value, doc=help))
+  return fh
+
+
 def static_int_env(varname: str, default: int | None) -> int | None:
   """Read an environment variable and interpret it as an integer.

@@ -222,6 +254,18 @@ def static_int_env(varname: str, default: int | None) -> int | None:
     ) from None


+def static_str_env(varname: str, default: str) -> str:
+  """Read an environment variable and interpret it as a string.
+
+  Args:
+    varname: the name of the variable
+    default: the default string value
+  Returns:
+    string return value derived from defaults and environment.
+  """
+  return os.getenv(varname, default)
+
+
 # Flax Global Configuration Variables:

 flax_filter_frames = bool_flag(
@@ -294,5 +338,5 @@ def static_int_env(varname: str, default: int | None) -> int | None:
 flax_hijax_variable = bool_flag(
   name='flax_hijax_variable',
   default=False,
-  help='Whether to enable HiJAX support for `nnx.Variable`.',
+  help='Whether to use hijax for `nnx.Variable`. Options are "pytree", "hijax", and "ref".',
 )
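The env-var fallback that `static_str_env` adds is just `os.getenv` with a default; a minimal standalone sketch of that behavior (the `FLAX_SOME_STRING` name comes from the `str_flag` docstring example above):

import os

def static_str_env(varname: str, default: str) -> str:
  # Mirrors the helper added above: environment value if set, else the default.
  return os.getenv(varname, default)

print(static_str_env('FLAX_SOME_STRING', 'default_value'))  # -> 'default_value'
os.environ['FLAX_SOME_STRING'] = 'new_value'
print(static_str_env('FLAX_SOME_STRING', 'default_value'))  # -> 'new_value'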
