Skip to content

Commit cce5921

Browse files
Commit message: deprecate .value
Commit cce5921 (1 parent: cd37bc9)

32 files changed

+254
-158
lines changed

docs_nnx/index.rst

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,11 @@ Basic usage
107107
model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization
108108
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
109109

110-
@nnx.jit # automatic state management for JAX transforms
111-
def train_step(model, optimizer, x, y, rngs):
112-
def loss_fn(model):
113-
y_pred = model(x, rngs) # call methods directly
114-
return ((y_pred - y) ** 2).mean()
115-
110+
@nnx.jit # automatic state propagation
111+
def train_step(model, optimizer, x, y):
112+
loss_fn = lambda model: ((model(x) - y) ** 2).mean()
116113
loss, grads = nnx.value_and_grad(loss_fn)(model)
117114
optimizer.update(model, grads) # in-place updates
118-
119115
return loss
120116

121117

docs_nnx/nnx_basics.ipynb

Lines changed: 27 additions & 27 deletions
Large diffs are not rendered by default.

docs_nnx/nnx_basics.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class Linear(nnx.Module):
4040
self.din, self.dout = din, dout
4141
4242
def __call__(self, x: jax.Array):
43-
return x @ self.w + self.b
43+
return x @ self.w + self.b[None]
4444
```
4545

4646
Also note that the inner values of `Variable`s can be accessed using the `value` property, but for convenience they implement all numeric operators and can be used directly in arithmetic expressions (as shown in the code above).
@@ -73,12 +73,12 @@ class Counter(nnx.Module):
7373
self.count = Count(jnp.array(0))
7474
7575
def __call__(self):
76-
self.count.value += 1
76+
self.count[...] += 1
7777
7878
counter = Counter()
79-
print(f'{counter.count.value = }')
79+
print(f'{counter.count[...] = }')
8080
counter()
81-
print(f'{counter.count.value = }')
81+
print(f'{counter.count[...] = }')
8282
```
8383

8484
Mutable references are usually avoided in JAX. But Flax NNX provides sound mechanisms

examples/gemma/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def assign_val_fn(
7474
mapped_path: tuple[str | int, ...],
7575
val: Any,
7676
) -> dict[tuple[str, ...], Any]:
77-
state[mapped_path].value = val
77+
state[mapped_path].set_value(val)
7878
return state
7979

8080
mdl: M = nnx.eval_shape(module_factory)

examples/gemma/helpers_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,11 @@ def _map_key_fn(key: tuple[str, ...]) -> tuple[str | int, ...]:
137137
np.testing.assert_array_equal(output, linen_output)
138138
for i in range(len(num_features)):
139139
np.testing.assert_array_equal(
140-
mdl.layers[i].layers[0].mean.value,
140+
mdl.layers[i].layers[0].mean[...],
141141
linen_vars['batch_stats'][f'layers_{i}']['layers_0']['mean'],
142142
)
143143
np.testing.assert_array_equal(
144-
mdl.layers[i].layers[0].var.value,
144+
mdl.layers[i].layers[0].var[...],
145145
linen_vars['batch_stats'][f'layers_{i}']['layers_0']['var'],
146146
)
147147

examples/gemma/layers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ def __init__(
4444
self.w = nnx.Param(kernel_init(rngs.params(), shape, dtype))
4545

4646
def __call__(self, x: ArrayLike) -> Array:
47-
return jnp.einsum(self.einsum_str, x, self.w.value)
47+
return jnp.einsum(self.einsum_str, x, self.w[...])
4848

4949
@property
5050
def shape(self) -> Shape:
51-
return self.w.value.shape
51+
return self.w.shape
5252

5353

5454
class RMSNorm(nnx.Module):
@@ -65,12 +65,12 @@ def __init__(
6565
self.scale = nnx.Param(scale_init(rngs.params(), dim, dtype))
6666

6767
def __call__(self, x: Array) -> Array:
68-
dtype = self.scale.value.dtype
68+
dtype = self.scale.dtype
6969
var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
7070
normed_inputs = jnp.asarray(x * jax.lax.rsqrt(var + 1e-06), dtype=dtype)
7171
# normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is
7272
# a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to
7373
# a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs.
74-
scale = jnp.expand_dims(self.scale.value, axis=range(len(x.shape) - 1))
74+
scale = jnp.expand_dims(self.scale, axis=range(len(x.shape) - 1))
7575
normed_inputs = normed_inputs * (1 + scale)
7676
return normed_inputs

examples/gemma/modules.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,15 @@ def encode(self, x: ArrayLike) -> Array:
6363
return x
6464

6565
def decode(self, x: ArrayLike) -> Array:
66-
return jnp.dot(x, self.input_embedding.value.T)
66+
return jnp.dot(x, self.input_embedding.T)
6767

6868
@property
6969
def embed_dim(self):
70-
return self.input_embedding.value.shape[1]
70+
return self.input_embedding.shape[1]
7171

7272
@property
7373
def num_embed(self):
74-
return self.input_embedding.value.shape[0]
74+
return self.input_embedding.shape[0]
7575

7676

7777
class Attention(nnx.Module):

examples/gemma/sampler_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,9 @@ def test_forbidden_tokens(self):
232232
transformer_config, rngs=nnx.Rngs(params=0)
233233
)
234234
# Pre-cook the embedding matrix so that the output is deterministic.
235-
transformer.embedder.input_embedding.value = jnp.eye(
235+
transformer.embedder.input_embedding.set_value(jnp.eye(
236236
vocab.GetPieceSize(), 32
237-
)
237+
))
238238
sampler = sampler_lib.Sampler(
239239
transformer=transformer,
240240
vocab=vocab,

examples/gemma/sow_lib.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,11 @@ def merge(self, decoding_step, layer: nnx.Module):
4949
if field.name.startswith('attn_'):
5050
step_value = getattr(
5151
layer.attn, field.name.replace('attn_', '')
52-
).value[0]
52+
)[0]
5353
elif field.name.startswith('mlp_'):
54-
step_value = getattr(layer.mlp, field.name.replace('mlp_', '')).value[
55-
0
56-
]
54+
step_value = getattr(layer.mlp, field.name.replace('mlp_', ''))[0]
5755
else:
58-
step_value = getattr(layer, field.name).value[0]
56+
step_value = getattr(layer, field.name)[0]
5957
except AttributeError as exc:
6058
raise ValueError(
6159
f'Intermediate {field.name} is not in the step intermediates.'
@@ -93,7 +91,7 @@ def merge(self, decoding_step, transformer: nnx.Module):
9391
if self.embeddings is not None:
9492
try:
9593
self.embeddings = self.embeddings.at[:, decoding_step + 1, ...].set(
96-
transformer.embeddings.value[0][:, 0, ...]
94+
transformer.embeddings[0][:, 0, ...]
9795
)
9896
except AttributeError as exc:
9997
raise ValueError(

examples/gemma/transformer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -487,10 +487,10 @@ def _assign_linen_params_to_nnx_state(
487487
if 'gate_proj' in mapped_path:
488488
if transpose_gating_einsum:
489489
val = jnp.swapaxes(val, 1, 2)
490-
state[mapped_path].value = val[0]
491-
state[mapped_path[:-2] + ('up_proj', 'kernel')].value = val[1]
490+
state[mapped_path].set_value(val[0])
491+
state[mapped_path[:-2] + ('up_proj', 'kernel')].set_value(val[1])
492492
else:
493-
state[mapped_path].value = val
493+
state[mapped_path].set_value(val)
494494
return state
495495

496496

0 commit comments

Comments (0)