diff --git a/docs_nnx/api_reference/flax.nnx/graph.rst b/docs_nnx/api_reference/flax.nnx/graph.rst index 867525583..da9fe2f16 100644 --- a/docs_nnx/api_reference/flax.nnx/graph.rst +++ b/docs_nnx/api_reference/flax.nnx/graph.rst @@ -31,7 +31,12 @@ graph .. autofunction:: find_duplicates .. autofunction:: pure -.. autofunction:: to_refs -.. autofunction:: to_arrays +.. autofunction:: immutable +.. autofunction:: mutable +.. autofunction:: as_hijax_vars +.. autofunction:: as_pytree_vars +.. autofunction:: as_ref_vars +.. autofunction:: as_array_vars +.. autofunction:: as_pytree_vars .. autofunction:: flatten .. autofunction:: unflatten diff --git a/docs_nnx/guides/array_ref.ipynb b/docs_nnx/guides/array_ref.ipynb deleted file mode 100644 index 1df71c629..000000000 --- a/docs_nnx/guides/array_ref.ipynb +++ /dev/null @@ -1,602 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "15c2d208", - "metadata": {}, - "source": [ - "# Array Refs (experimental)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "99809892", - "metadata": {}, - "outputs": [], - "source": [ - "from flax import nnx\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import optax" - ] - }, - { - "cell_type": "markdown", - "id": "787cf22a", - "metadata": {}, - "source": [ - "## Basics" - ] - }, - { - "cell_type": "markdown", - "id": "d896c926", - "metadata": {}, - "source": [ - "### Array Refs 101" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "cae099ce", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[1] = ArrayRef([1, 2, 3], dtype=int32)\n", - "[2] = ArrayRef([2, 3, 4], dtype=int32)\n" - ] - } - ], - "source": [ - "a_ref = jax.new_ref(jnp.array([1, 2, 3]))\n", - "\n", - "@jax.jit\n", - "def increment(a_ref: jax.Ref): # no return!\n", - " array: jax.Array = a_ref[...] # access\n", - " a_ref[...] = array + 1 # update\n", - "\n", - "print(\"[1] =\", a_ref); increment(a_ref); print(\"[2] =\", a_ref)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "fb081f49", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "module @jit_increment attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n", - " func.func public @main(%arg0: tensor<3xi32> {tf.aliasing_output = 0 : i32}) -> (tensor<3xi32> {jax.result_info = \"\"}) {\n", - " %c = stablehlo.constant dense<1> : tensor\n", - " %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3xi32>\n", - " %1 = stablehlo.add %arg0, %0 : tensor<3xi32>\n", - " return %1 : tensor<3xi32>\n", - " }\n", - "}\n", - "\n" - ] - } - ], - "source": [ - "@jax.jit\n", - "def inc(x):\n", - " x[...] 
+= 1\n", - "\n", - "print(increment.lower(a_ref).as_text())" - ] - }, - { - "cell_type": "markdown", - "id": "26969861", - "metadata": {}, - "source": [ - "### Variables Refs" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "8c3da93c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "variable.has_ref = True\n", - "\n", - "[1] = \u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArrayRef([1, 2, 3], dtype=int32)\n", - "\u001b[38;2;255;213;3m)\u001b[0m\n", - "[2] = \u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArrayRef([2, 3, 4], dtype=int32)\n", - "\u001b[38;2;255;213;3m)\u001b[0m\n" - ] - } - ], - "source": [ - "variable = nnx.Variable(jnp.array([1, 2, 3]), use_ref=True)\n", - "print(f\"{variable.has_ref = }\\n\")\n", - "\n", - "print(\"[1] =\", variable); increment(variable); print(\"[2] =\", variable)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "0a55df94", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "variable.has_ref = True\n" - ] - } - ], - "source": [ - "with nnx.use_refs(True):\n", - " variable = nnx.Variable(jnp.array([1, 2, 3]))\n", - "\n", - "print(f\"{variable.has_ref = }\")" - ] - }, - { - "cell_type": "markdown", - "id": "839332be", - "metadata": {}, - "source": [ - "Mention `nnx.use_refs` can be used as global flag" - ] - }, - { - "cell_type": "markdown", - "id": "1b2632f1", - "metadata": {}, - "source": [ - "### Changing Status" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "b7b1f421", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "nnx.to_refs(model) = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 6 (24 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mbias\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArrayRef\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m3\u001b[0m,\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", - " \u001b[38;2;255;213;3m)\u001b[0m,\n", - " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArrayRef\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m1\u001b[0m, \u001b[38;2;182;207;169m3\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", - " \u001b[38;2;255;213;3m)\u001b[0m\n", - 
"\u001b[38;2;255;213;3m)\u001b[0m\n", - "nnx.to_arrays(refs_model) = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 6 (24 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mbias\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m3\u001b[0m,\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", - " \u001b[38;2;255;213;3m)\u001b[0m,\n", - " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m1\u001b[0m, \u001b[38;2;182;207;169m3\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", - " \u001b[38;2;255;213;3m)\u001b[0m\n", - "\u001b[38;2;255;213;3m)\u001b[0m\n" - ] - } - ], - "source": [ - "class Linear(nnx.Module):\n", - " def __init__(self, in_features, out_features, rngs: nnx.Rngs):\n", - " self.kernel = nnx.Param(jax.random.normal(rngs(), (in_features, out_features)))\n", - " self.bias = nnx.Param(jnp.zeros(out_features))\n", - "\n", - " def __call__(self, x):\n", - " return x @ self.kernel + self.bias[None]\n", - "\n", - "model = Linear(1, 3, rngs=nnx.Rngs(0)) # without array refs\n", - "refs_model = nnx.to_refs(model) # convert to array refs\n", - "arrays_model = nnx.to_arrays(refs_model) # convert to regular arrays\n", - "\n", - "print(\"nnx.to_refs(model) =\", refs_model)\n", - "print(\"nnx.to_arrays(refs_model) =\", arrays_model)" - ] - }, - { - "cell_type": "markdown", - "id": "f4e35e75", - "metadata": {}, - "source": [ - "## Examples" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "5400fe58", - "metadata": {}, - "outputs": [], - "source": [ - "class Block(nnx.Module):\n", - " def __init__(self, din, dmid, dout, rngs: nnx.Rngs):\n", - " self.linear = Linear(din, dmid, rngs=rngs)\n", - " self.bn = nnx.BatchNorm(dmid, rngs=rngs)\n", - " self.dropout = nnx.Dropout(0.1, rngs=rngs)\n", - " self.linear_out = Linear(dmid, dout, rngs=rngs)\n", - "\n", - " def __call__(self, x):\n", - " x = nnx.gelu(self.dropout(self.bn(self.linear(x))))\n", - " return self.linear_out(x)" - ] - }, - { - "cell_type": "markdown", - "id": "ba980b6b", - "metadata": {}, - "source": [ - "### Training Loop" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "566c4249", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(1.000178, dtype=float32)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "with nnx.use_refs(True):\n", - " model = Block(2, 64, 3, rngs=nnx.Rngs(0))\n", - " optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", - "\n", - 
"@jax.jit\n", - "def train_step(model, optimizer, x, y):\n", - " graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)\n", - " def loss_fn(params):\n", - " model = nnx.merge(graphdef, params, nondiff)\n", - " return ((model(x) - y) ** 2).mean()\n", - "\n", - " loss, grads = jax.value_and_grad(loss_fn)(nnx.to_arrays(params)) # freeze ArrayRefs for jax.grad\n", - " optimizer.update(model, grads)\n", - "\n", - " return loss\n", - "\n", - "train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3)))" - ] - }, - { - "cell_type": "markdown", - "id": "1dea99c1", - "metadata": {}, - "source": [ - "### Scan Over Layers" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "d8136be4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "y = [[ 0.82840395 -0.25364894]\n", - " [ 4.9552917 4.93638 ]\n", - " [-7.6721525 -3.4668717 ]]\n" - ] - } - ], - "source": [ - "@nnx.vmap\n", - "def create_stack(rngs):\n", - " return Block(2, 64, 2, rngs=rngs)\n", - "\n", - "with nnx.use_refs(True):\n", - " block_stack = create_stack(nnx.Rngs(0).fork(split=8))\n", - "\n", - "def scan_fn(x, block):\n", - " x = block(x)\n", - " return x, None\n", - "\n", - "x = jax.random.uniform(jax.random.key(0), (3, 2))\n", - "y, _ = jax.lax.scan(scan_fn, x, block_stack)\n", - "\n", - "print(\"y = \", y)" - ] - }, - { - "cell_type": "markdown", - "id": "7ca18a0d", - "metadata": {}, - "source": [ - "## Limitations" - ] - }, - { - "cell_type": "markdown", - "id": "1dd39c79", - "metadata": {}, - "source": [ - "### MutableArray Outputs" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "c6062d19", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Error: function create_model at /var/folders/qj/tkq3kvtd66z1t36rfyj9vg0w016bdd/T/ipykernel_43144/1421484665.py:1 traced for jit returned a mutable array reference of type Ref{float32[64]} at output tree path result.bn.bias.value, but mutable array references cannot be returned.\n", - "\n", - "The returned mutable array was created on line /Users/cgarciae/repos/flax/flax/nnx/variablelib.py:250:17 (Variable.__init__).\n" - ] - } - ], - "source": [ - "@jax.jit\n", - "def create_model(rngs):\n", - " return Block(2, 64, 3, rngs=rngs)\n", - "\n", - "try:\n", - " with nnx.use_refs(True):\n", - " model = create_model(nnx.Rngs(0))\n", - "except Exception as e:\n", - " print(f\"Error:\", e)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "8bb1e9e7", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "model.linear = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 192 (768 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mbias\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 64 (256 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArrayRef\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m64\u001b[0m,\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", - " \u001b[38;2;255;213;3m)\u001b[0m,\n", - " 
\u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 128 (512 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArrayRef\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m, \u001b[38;2;182;207;169m64\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", - " \u001b[38;2;255;213;3m)\u001b[0m\n", - "\u001b[38;2;255;213;3m)\u001b[0m\n" - ] - } - ], - "source": [ - "with nnx.use_refs(False): # <-- disable array refs\n", - " model = create_model(nnx.Rngs(0))\n", - "\n", - "model = nnx.to_refs(model) # convert to mutable after creation\n", - "\n", - "print(\"model.linear =\", model.linear)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "3a078025", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "model.linear = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 192 (768 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mbias\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 64 (256 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArrayRef\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m64\u001b[0m,\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", - " \u001b[38;2;255;213;3m)\u001b[0m,\n", - " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 128 (512 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArrayRef\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m, \u001b[38;2;182;207;169m64\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", - " \u001b[38;2;255;213;3m)\u001b[0m\n", - "\u001b[38;2;255;213;3m)\u001b[0m\n" - ] - } - ], - "source": [ - "@nnx.jit\n", - "def create_model(rngs):\n", - " return Block(2, 64, 3, rngs=rngs)\n", - "\n", - "with nnx.use_refs(True):\n", - " model = create_model(nnx.Rngs(0))\n", - "\n", - "print(\"model.linear =\", model.linear)" - ] - }, - { - "cell_type": "markdown", - "id": "609bed7c", - "metadata": {}, - "source": [ - "### Reference Sharing (aliasing)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "045d03c1", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ValueError: only one reference to a mutable array may be passed as an argument to a function, but when tracing f at /var/folders/qj/tkq3kvtd66z1t36rfyj9vg0w016bdd/T/ipykernel_43144/1563421490.py:9 for jit the 
mutable array reference of type Ref{int32[]} appeared at both a and b.\n" - ] - } - ], - "source": [ - "def get_error(f, *args):\n", - " try:\n", - " return f(*args)\n", - " except Exception as e:\n", - " return f\"{type(e).__name__}: {e}\"\n", - " \n", - "x = jax.new_ref(jnp.array(0))\n", - "\n", - "@jax.jit\n", - "def f(a, b):\n", - " ...\n", - "\n", - "print(get_error(f, x, x))" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "bc2e87e5", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "SharedVariables ValueError: only one reference to a mutable array may be passed as an argument to a function, but when tracing g at /var/folders/qj/tkq3kvtd66z1t36rfyj9vg0w016bdd/T/ipykernel_43144/1828746469.py:13 for jit the mutable array reference of type Ref{int32[]} appeared at both pytree.a.value and pytree.c.value.\n", - "SharedModules ValueError: only one reference to a mutable array may be passed as an argument to a function, but when tracing g at /var/folders/qj/tkq3kvtd66z1t36rfyj9vg0w016bdd/T/ipykernel_43144/1828746469.py:13 for jit the mutable array reference of type Ref{float32[1]} appeared at both pytree.d.bias.value and pytree.f.bias.value.\n" - ] - } - ], - "source": [ - "class SharedVariables(nnx.Pytree):\n", - " def __init__(self):\n", - " self.a = nnx.Variable(jnp.array(0))\n", - " self.b = nnx.Variable(jnp.array(1))\n", - " self.c = self.a\n", - "\n", - "class SharedModules(nnx.Pytree):\n", - " def __init__(self):\n", - " self.d = Linear(1, 1, rngs=nnx.Rngs(0))\n", - " self.e = Linear(1, 1, rngs=nnx.Rngs(0))\n", - " self.f = self.d\n", - "\n", - "@jax.jit\n", - "def g(pytree):\n", - " ...\n", - "\n", - "with nnx.use_refs(True):\n", - " shared_variables = SharedVariables()\n", - " shared_modules = SharedModules()\n", - "\n", - "print(\"SharedVariables\", get_error(g, shared_variables))\n", - "print(\"SharedModules\", get_error(g, shared_modules))" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "6298f3d9", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "shared variables duplicates: [[('a',), ('c',)]]\n", - "shared modules duplicates: [[('d',), ('f',)]]\n" - ] - } - ], - "source": [ - "if (duplicates := nnx.find_duplicates(shared_variables)):\n", - " print(\"shared variables duplicates:\", duplicates)\n", - "\n", - "if (duplicates := nnx.find_duplicates(shared_modules)):\n", - " print(\"shared modules duplicates: \", duplicates)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "00854d38", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", - " \u001b[38;2;156;220;254m'a'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArrayRef(0, dtype=int32, weak_type=True)\n", - " \u001b[38;2;255;213;3m)\u001b[0m,\n", - " \u001b[38;2;156;220;254m'b'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArrayRef(1, dtype=int32, weak_type=True)\n", - " \u001b[38;2;255;213;3m)\u001b[0m\n", - 
"\u001b[38;2;255;213;3m})\u001b[0m\n", - "updated \u001b[38;2;79;201;177mSharedVariables\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Variable: 2 (8 B)\u001b[0m\n", - " \u001b[38;2;156;220;254ma\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArrayRef(10, dtype=int32)\n", - " \u001b[38;2;255;213;3m)\u001b[0m,\n", - " \u001b[38;2;156;220;254mb\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArrayRef(1, dtype=int32, weak_type=True)\n", - " \u001b[38;2;255;213;3m)\u001b[0m,\n", - " \u001b[38;2;156;220;254mc\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArrayRef(10, dtype=int32)\n", - " \u001b[38;2;255;213;3m)\u001b[0m\n", - "\u001b[38;2;255;213;3m)\u001b[0m\n" - ] - } - ], - "source": [ - "@jax.jit\n", - "def h(graphdef, state):\n", - " obj = nnx.merge(graphdef, state)\n", - " obj.a[...] += 10\n", - "\n", - "graphdef, state = nnx.split(shared_variables)\n", - "print(state) # split deduplicates the state\n", - "\n", - "h(graphdef, state)\n", - "\n", - "print(\"updated\", shared_variables)" - ] - } - ], - "metadata": { - "jupytext": { - "formats": "ipynb,md:myst" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs_nnx/guides/array_ref.md b/docs_nnx/guides/array_ref.md deleted file mode 100644 index 1d00c77f5..000000000 --- a/docs_nnx/guides/array_ref.md +++ /dev/null @@ -1,242 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.13.8 ---- - -# Array Refs (experimental) - -```{code-cell} ipython3 -from flax import nnx -import jax -import jax.numpy as jnp -import optax -``` - -## Basics - -+++ - -### Array Refs 101 - -```{code-cell} ipython3 -a_ref = jax.new_ref(jnp.array([1, 2, 3])) - -@jax.jit -def increment(a_ref: jax.Ref): # no return! - array: jax.Array = a_ref[...] # access - a_ref[...] = array + 1 # update - -print("[1] =", a_ref); increment(a_ref); print("[2] =", a_ref) -``` - -```{code-cell} ipython3 -@jax.jit -def inc(x): - x[...] 
+= 1 - -print(increment.lower(a_ref).as_text()) -``` - -### Variables Refs - -```{code-cell} ipython3 -variable = nnx.Variable(jnp.array([1, 2, 3]), use_ref=True) -print(f"{variable.has_ref = }\n") - -print("[1] =", variable); increment(variable); print("[2] =", variable) -``` - -```{code-cell} ipython3 -with nnx.use_refs(True): - variable = nnx.Variable(jnp.array([1, 2, 3])) - -print(f"{variable.has_ref = }") -``` - -Mention `nnx.use_refs` can be used as global flag - -+++ - -### Changing Status - -```{code-cell} ipython3 -class Linear(nnx.Module): - def __init__(self, in_features, out_features, rngs: nnx.Rngs): - self.kernel = nnx.Param(jax.random.normal(rngs(), (in_features, out_features))) - self.bias = nnx.Param(jnp.zeros(out_features)) - - def __call__(self, x): - return x @ self.kernel + self.bias[None] - -model = Linear(1, 3, rngs=nnx.Rngs(0)) # without array refs -refs_model = nnx.to_refs(model) # convert to array refs -arrays_model = nnx.to_arrays(refs_model) # convert to regular arrays - -print("nnx.to_refs(model) =", refs_model) -print("nnx.to_arrays(refs_model) =", arrays_model) -``` - -## Examples - -```{code-cell} ipython3 -class Block(nnx.Module): - def __init__(self, din, dmid, dout, rngs: nnx.Rngs): - self.linear = Linear(din, dmid, rngs=rngs) - self.bn = nnx.BatchNorm(dmid, rngs=rngs) - self.dropout = nnx.Dropout(0.1, rngs=rngs) - self.linear_out = Linear(dmid, dout, rngs=rngs) - - def __call__(self, x): - x = nnx.gelu(self.dropout(self.bn(self.linear(x)))) - return self.linear_out(x) -``` - -### Training Loop - -```{code-cell} ipython3 -with nnx.use_refs(True): - model = Block(2, 64, 3, rngs=nnx.Rngs(0)) - optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) - -@jax.jit -def train_step(model, optimizer, x, y): - graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) - def loss_fn(params): - model = nnx.merge(graphdef, params, nondiff) - return ((model(x) - y) ** 2).mean() - - loss, grads = jax.value_and_grad(loss_fn)(nnx.to_arrays(params)) # freeze ArrayRefs for jax.grad - optimizer.update(model, grads) - - return loss - -train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3))) -``` - -### Scan Over Layers - -```{code-cell} ipython3 -@nnx.vmap -def create_stack(rngs): - return Block(2, 64, 2, rngs=rngs) - -with nnx.use_refs(True): - block_stack = create_stack(nnx.Rngs(0).fork(split=8)) - -def scan_fn(x, block): - x = block(x) - return x, None - -x = jax.random.uniform(jax.random.key(0), (3, 2)) -y, _ = jax.lax.scan(scan_fn, x, block_stack) - -print("y = ", y) -``` - -## Limitations - -+++ - -### MutableArray Outputs - -```{code-cell} ipython3 -@jax.jit -def create_model(rngs): - return Block(2, 64, 3, rngs=rngs) - -try: - with nnx.use_refs(True): - model = create_model(nnx.Rngs(0)) -except Exception as e: - print(f"Error:", e) -``` - -```{code-cell} ipython3 -with nnx.use_refs(False): # <-- disable array refs - model = create_model(nnx.Rngs(0)) - -model = nnx.to_refs(model) # convert to mutable after creation - -print("model.linear =", model.linear) -``` - -```{code-cell} ipython3 -@nnx.jit -def create_model(rngs): - return Block(2, 64, 3, rngs=rngs) - -with nnx.use_refs(True): - model = create_model(nnx.Rngs(0)) - -print("model.linear =", model.linear) -``` - -### Reference Sharing (aliasing) - -```{code-cell} ipython3 -def get_error(f, *args): - try: - return f(*args) - except Exception as e: - return f"{type(e).__name__}: {e}" - -x = jax.new_ref(jnp.array(0)) - -@jax.jit -def f(a, b): - ... 
- -print(get_error(f, x, x)) -``` - -```{code-cell} ipython3 -class SharedVariables(nnx.Pytree): - def __init__(self): - self.a = nnx.Variable(jnp.array(0)) - self.b = nnx.Variable(jnp.array(1)) - self.c = self.a - -class SharedModules(nnx.Pytree): - def __init__(self): - self.d = Linear(1, 1, rngs=nnx.Rngs(0)) - self.e = Linear(1, 1, rngs=nnx.Rngs(0)) - self.f = self.d - -@jax.jit -def g(pytree): - ... - -with nnx.use_refs(True): - shared_variables = SharedVariables() - shared_modules = SharedModules() - -print("SharedVariables", get_error(g, shared_variables)) -print("SharedModules", get_error(g, shared_modules)) -``` - -```{code-cell} ipython3 -if (duplicates := nnx.find_duplicates(shared_variables)): - print("shared variables duplicates:", duplicates) - -if (duplicates := nnx.find_duplicates(shared_modules)): - print("shared modules duplicates: ", duplicates) -``` - -```{code-cell} ipython3 -@jax.jit -def h(graphdef, state): - obj = nnx.merge(graphdef, state) - obj.a[...] += 10 - -graphdef, state = nnx.split(shared_variables) -print(state) # split deduplicates the state - -h(graphdef, state) - -print("updated", shared_variables) -``` diff --git a/docs_nnx/hijax/index.rst b/docs_nnx/hijax/index.rst new file mode 100644 index 000000000..a2d374eb6 --- /dev/null +++ b/docs_nnx/hijax/index.rst @@ -0,0 +1,58 @@ +Hijax (experimental) +==================== + + + +---- + +Basic usage +^^^^^^^^^^^^ + +.. testsetup:: + + import jax + import jax.numpy as jnp + + current_mode = nnx.using_hijax() + +.. testcode:: + + from flax import nnx + import optax + + nnx.use_hijax(True) + + class Model(nnx.Module): + def __init__(self, din, dmid, dout, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dmid, rngs=rngs) + self.bn = nnx.BatchNorm(dmid, rngs=rngs) + self.dropout = nnx.Dropout(0.2) + self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) + + def __call__(self, x, rngs): + x = nnx.relu(self.dropout(self.bn(self.linear(x)), rngs=rngs)) + return self.linear_out(x) + + model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization + optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + + @jax.jit + def train_step(model, optimizer, rngs, x, y): + graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) + def loss_fn(params): + model = nnx.merge(graphdef, params, nondiff) + return ((model(x, rngs) - y) ** 2).mean() + loss, grads = jax.value_and_grad(loss_fn)(nnx.immutable(params)) + optimizer.update(model, grads) # in-place updates + return loss + + nnx.use_hijax(current_mode) # clean up for CI tests + + +---- + +.. 
toctree:: + :hidden: + :maxdepth: 2 + + variable diff --git a/docs_nnx/hijax/variable.ipynb b/docs_nnx/hijax/variable.ipynb new file mode 100644 index 000000000..b25941f09 --- /dev/null +++ b/docs_nnx/hijax/variable.ipynb @@ -0,0 +1,680 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "15c2d208", + "metadata": {}, + "source": [ + "# Variable" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "99809892", + "metadata": {}, + "outputs": [], + "source": [ + "from flax import nnx\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import optax\n", + "\n", + "current_mode = nnx.using_hijax()" + ] + }, + { + "cell_type": "markdown", + "id": "b36e6464", + "metadata": {}, + "source": [ + "## Hijax" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "396a07a3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n", + "1\n" + ] + } + ], + "source": [ + "v = nnx.Variable(jnp.array(0), is_hijax=True)\n", + "\n", + "@jax.jit\n", + "def inc(v):\n", + " v[...] += 1\n", + "\n", + "print(v[...]); inc(v); print(v[...])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2ab7d801", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{ \u001b[34;1mlambda \u001b[39;22m; a\u001b[35m:Variable()\u001b[39m. \u001b[34;1mlet\n", + " \u001b[39;22mjit[\n", + " name=inc\n", + " jaxpr={ \u001b[34;1mlambda \u001b[39;22m; a\u001b[35m:Variable()\u001b[39m. \u001b[34;1mlet\n", + " \u001b[39;22mb\u001b[35m:i32[]\u001b[39m = get_variable[avals=(ShapedArray(int32[], weak_type=True),)] a\n", + " c\u001b[35m:i32[]\u001b[39m = add b 1:i32[]\n", + " _\u001b[35m:i32[]\u001b[39m = get_variable[avals=(ShapedArray(int32[], weak_type=True),)] a\n", + " set_variable[\n", + " treedef=PyTreeDef(CustomNode(Variable[(('has_ref', False), ('is_hijax', True), ('is_mutable', True))], [*]))\n", + " var_type=\n", + " ] a c\n", + " \u001b[34;1min \u001b[39;22m() }\n", + " ] a\n", + " \u001b[34;1min \u001b[39;22m() }\n" + ] + } + ], + "source": [ + "v = nnx.Variable(jnp.array(0), is_hijax=True)\n", + "print(jax.make_jaxpr(inc)(v))" + ] + }, + { + "cell_type": "markdown", + "id": "39070460", + "metadata": {}, + "source": [ + "Pytree values:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "fcd0de3f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 2 (8 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;207;144;120m'a'\u001b[0m: Array(0, dtype=int32, weak_type=True), \u001b[38;2;207;144;120m'b'\u001b[0m: Array(2, dtype=int32, weak_type=True)\u001b[38;2;255;213;3m}\u001b[0m,\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 2 (8 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;207;144;120m'a'\u001b[0m: Array(1, dtype=int32, weak_type=True), \u001b[38;2;207;144;120m'b'\u001b[0m: Array(4, dtype=int32, weak_type=True)\u001b[38;2;255;213;3m}\u001b[0m,\n", + " 
\u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n" + ] + } + ], + "source": [ + "v = nnx.Variable({'a': jnp.array(0), 'b': jnp.array(2)}, is_hijax=True)\n", + "\n", + "@jax.jit\n", + "def inc_and_double(v):\n", + " v['a'] += 1\n", + " v['b'] *= 2\n", + "\n", + "print(v); inc_and_double(v); print(v)" + ] + }, + { + "cell_type": "markdown", + "id": "f0cfe954", + "metadata": {}, + "source": [ + "Dynamic state structure:" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "0d83a130", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Before: \u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;255;213;3m}\u001b[0m,\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n", + "After: \u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;207;144;120m'y_mean'\u001b[0m: Array(-1.1782329, dtype=float32)\u001b[38;2;255;213;3m}\u001b[0m,\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n" + ] + } + ], + "source": [ + "rngs = nnx.Rngs(0)\n", + "x = rngs.uniform((4, 5))\n", + "w = rngs.normal((5, 3))\n", + "metrics = nnx.Variable({}, is_hijax=True)\n", + "\n", + "@jax.jit\n", + "def linear(x, w, metrics: nnx.Variable):\n", + " y = x @ w\n", + " metrics['y_mean'] = jnp.mean(y)\n", + " return y\n", + "\n", + "print(\"Before:\", metrics)\n", + "y = linear(x, w, metrics)\n", + "print(\"After:\", metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0a55df94", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([1, 2, 3], dtype=int32),\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n" + ] + } + ], + "source": [ + "# set default Variable mode for the rest of the guide\n", + "nnx.use_hijax(True)\n", + "\n", + "variable = nnx.Variable(jnp.array([1, 2, 3]))\n", + "\n", + "print(variable)" + ] + }, + { + "cell_type": "markdown", + "id": "1b2632f1", + "metadata": {}, + "source": [ + "### Mutability" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "b7b1f421", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "nnx.immutable(model) = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 3 (12 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + " 
\u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m1\u001b[0m, \u001b[38;2;182;207;169m3\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mis_mutable\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mFalse\u001b[0m,\n", + " \u001b[38;2;156;220;254mwas_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n", + "nnx.mutable(model) = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 3 (12 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m1\u001b[0m, \u001b[38;2;182;207;169m3\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n" + ] + } + ], + "source": [ + "class Linear(nnx.Module):\n", + " def __init__(self, in_features, out_features, rngs: nnx.Rngs):\n", + " self.kernel = nnx.Param(rngs.normal((in_features, out_features)))\n", + "\n", + " def __call__(self, x):\n", + " return x @ self.kernel\n", + "\n", + "model = Linear(1, 3, rngs=nnx.Rngs(0))\n", + "\n", + "print(f\"{nnx.immutable(model) = !s}\")\n", + "print(f\"{nnx.mutable(model) = !s}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "594cb65e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ImmutableVariableError: Cannot mutate Variable as it is marked as immutable. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ImmutableVariableError)\n" + ] + } + ], + "source": [ + "v = nnx.Variable(jnp.array(0))\n", + "v_immut = nnx.immutable(v)\n", + "assert not v_immut.is_mutable\n", + "\n", + "try:\n", + " v_immut[...] 
+= 1 # raises an error\n", + "except Exception as e:\n", + " print(f\"{type(e).__name__}: {e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "58692a37", + "metadata": {}, + "source": [ + "### Ref" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fcd4fb4f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=int32, weak_type=True),\n", + " \u001b[38;2;156;220;254mhas_ref\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m,\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n", + "Ref(0, dtype=int32, weak_type=True)\n" + ] + } + ], + "source": [ + "v = nnx.Variable(jnp.array(0))\n", + "v_ref = nnx.as_ref_vars(v)\n", + "assert v_ref.has_ref\n", + "print(v_ref)\n", + "print(v_ref.get_raw_value())" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "18256668", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "immutable = \u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=int32, weak_type=True),\n", + " \u001b[38;2;156;220;254mis_mutable\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mFalse\u001b[0m,\n", + " \u001b[38;2;156;220;254mhad_ref\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m,\n", + " \u001b[38;2;156;220;254mwas_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n", + "mutable = \u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=int32, weak_type=True),\n", + " \u001b[38;2;156;220;254mhas_ref\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m,\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n" + ] + } + ], + "source": [ + "v_immut = nnx.immutable(v_ref)\n", + "assert not v_immut.has_ref\n", + "print(\"immutable =\", v_immut)\n", + "\n", + "v_ref = nnx.mutable(v_immut)\n", + "assert v_ref.has_ref\n", + "print(\"mutable =\", v_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "f4e35e75", + "metadata": {}, + "source": [ + "### Examples" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "5400fe58", + "metadata": {}, + "outputs": [], + "source": [ + "class Block(nnx.Module):\n", + " def __init__(self, din, dmid, dout, rngs: nnx.Rngs):\n", + " self.linear = Linear(din, dmid, rngs=rngs)\n", + " self.bn = nnx.BatchNorm(dmid, rngs=rngs)\n", + " self.dropout = nnx.Dropout(0.1, rngs=rngs)\n", + " self.linear_out = Linear(dmid, dout, rngs=rngs)\n", + "\n", + " def __call__(self, x):\n", + " x = nnx.gelu(self.dropout(self.bn(self.linear(x))))\n", + " return self.linear_out(x)" + ] + }, + { + "cell_type": "markdown", + "id": "ba980b6b", + "metadata": {}, + "source": [ + "#### Training Loop" + ] + }, + { + "cell_type": "code", + 
"execution_count": 12, + "id": "566c4249", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss = 1.000178\n", + "loss = 0.9700456\n", + "loss = 0.93967044\n" + ] + } + ], + "source": [ + "# hijax Variables by default\n", + "model = Block(2, 64, 3, rngs=nnx.Rngs(0))\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", + "\n", + "@jax.jit\n", + "def train_step(model, optimizer, x, y):\n", + " graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)\n", + " def loss_fn(params):\n", + " model = nnx.merge(graphdef, params, nondiff)\n", + " return ((model(x) - y) ** 2).mean()\n", + "\n", + " loss, grads = jax.value_and_grad(loss_fn)(nnx.immutable(params)) # lojax Variables for jax.grad\n", + " optimizer.update(model, grads)\n", + "\n", + " return loss\n", + "\n", + "for _ in range(3):\n", + " loss = train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3)))\n", + " print(f\"{loss = !s}\")" + ] + }, + { + "cell_type": "markdown", + "id": "1dea99c1", + "metadata": {}, + "source": [ + "#### Scan Over Layers" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "d8136be4", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: does not work with hijax yet\n", + "# @jax.vmap\n", + "# def create_stack(rngs):\n", + "# return nnx.immutable(Block(2, 64, 2, rngs=rngs))\n", + "\n", + "# block_stack = nnx.mutable(create_stack(nnx.Rngs(0).fork(split=8)))\n", + "\n", + "# def scan_fn(x, block):\n", + "# x = block(x)\n", + "# return x, None\n", + "\n", + "# x = jax.random.uniform(jax.random.key(0), (3, 2))\n", + "# y, _ = jax.lax.scan(scan_fn, x, block_stack)\n", + "\n", + "# print(\"y = \", y)" + ] + }, + { + "cell_type": "markdown", + "id": "7ca18a0d", + "metadata": {}, + "source": [ + "### Limitations" + ] + }, + { + "cell_type": "markdown", + "id": "1dd39c79", + "metadata": {}, + "source": [ + "#### Mutable Outputs" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "c6062d19", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Error: mutable hitypes should use lo_ty_qdd instead\n" + ] + } + ], + "source": [ + "@jax.jit\n", + "def create_model(rngs):\n", + " return Block(2, 64, 3, rngs=rngs)\n", + "\n", + "try:\n", + " model = create_model(nnx.Rngs(0))\n", + "except Exception as e:\n", + " print(f\"Error:\", e)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "8bb1e9e7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model.linear = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 128 (512 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m, \u001b[38;2;182;207;169m64\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " 
\u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n" + ] + } + ], + "source": [ + "@jax.jit\n", + "def create_model(rngs):\n", + " return nnx.immutable(Block(2, 64, 3, rngs=rngs))\n", + "\n", + "model = nnx.mutable(create_model(nnx.Rngs(0)))\n", + "\n", + "print(\"model.linear =\", model.linear)" + ] + }, + { + "cell_type": "markdown", + "id": "609bed7c", + "metadata": {}, + "source": [ + "#### Reference Sharing (aliasing)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "045d03c1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "None\n" + ] + } + ], + "source": [ + "# NOTE: doesn't currently fail on the jax side\n", + "def get_error(f, *args):\n", + " try:\n", + " return f(*args)\n", + " except Exception as e:\n", + " return f\"{type(e).__name__}: {e}\"\n", + "\n", + "x = nnx.Variable(jnp.array(0))\n", + "\n", + "@jax.jit\n", + "def f(a, b):\n", + " ...\n", + "\n", + "print(get_error(f, x, x))" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "bc2e87e5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "None\n" + ] + } + ], + "source": [ + "# NOTE: doesn't currently fail on the jax side\n", + "class Shared(nnx.Pytree):\n", + " def __init__(self):\n", + " self.a = nnx.Variable(jnp.array(0))\n", + " self.b = self.a\n", + " self.c = Linear(1, 1, rngs=nnx.Rngs(0))\n", + " self.d = self.c\n", + "\n", + "@jax.jit\n", + "def g(pytree):\n", + " ...\n", + "\n", + "shared = Shared()\n", + "\n", + "print(get_error(g, shared))" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "6298f3d9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Duplicates found:\n", + "- [('a',), ('b',)]\n", + "- [('c',), ('d',)]\n" + ] + } + ], + "source": [ + "print(\"Duplicates found:\")\n", + "if (all_duplicates := nnx.find_duplicates(shared)):\n", + " for duplicates in all_duplicates:\n", + " print(\"-\", duplicates)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "00854d38", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "before: \u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=int32, weak_type=True),\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n", + "after: \u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(10, dtype=int32, weak_type=True),\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n" + ] + } + ], + "source": [ + "@jax.jit\n", + "def h(graphdef, state):\n", + " obj = nnx.merge(graphdef, state)\n", + " obj.a[...] 
+= 10\n", + "\n", + "graphdef, state = nnx.split(shared)\n", + "print(\"before:\", state.a) # split deduplicates the state\n", + "\n", + "h(graphdef, state)\n", + "\n", + "print(\"after:\", shared.a)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "195296c8", + "metadata": {}, + "outputs": [], + "source": [ + "# clean up for CI tests\n", + "_ = nnx.use_hijax(current_mode)" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs_nnx/hijax/variable.md b/docs_nnx/hijax/variable.md new file mode 100644 index 000000000..22053c0a5 --- /dev/null +++ b/docs_nnx/hijax/variable.md @@ -0,0 +1,274 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + +# Variable + +```{code-cell} ipython3 +from flax import nnx +import jax +import jax.numpy as jnp +import optax + +current_mode = nnx.using_hijax() +``` + +## Hijax + +```{code-cell} ipython3 +v = nnx.Variable(jnp.array(0), is_hijax=True) + +@jax.jit +def inc(v): + v[...] += 1 + +print(v[...]); inc(v); print(v[...]) +``` + +```{code-cell} ipython3 +v = nnx.Variable(jnp.array(0), is_hijax=True) +print(jax.make_jaxpr(inc)(v)) +``` + +Pytree values: + +```{code-cell} ipython3 +v = nnx.Variable({'a': jnp.array(0), 'b': jnp.array(2)}, is_hijax=True) + +@jax.jit +def inc_and_double(v): + v['a'] += 1 + v['b'] *= 2 + +print(v); inc_and_double(v); print(v) +``` + +Dynamic state structure: + +```{code-cell} ipython3 +rngs = nnx.Rngs(0) +x = rngs.uniform((4, 5)) +w = rngs.normal((5, 3)) +metrics = nnx.Variable({}, is_hijax=True) + +@jax.jit +def linear(x, w, metrics: nnx.Variable): + y = x @ w + metrics['y_mean'] = jnp.mean(y) + return y + +print("Before:", metrics) +y = linear(x, w, metrics) +print("After:", metrics) +``` + +```{code-cell} ipython3 +# set default Variable mode for the rest of the guide +nnx.use_hijax(True) + +variable = nnx.Variable(jnp.array([1, 2, 3])) + +print(variable) +``` + +### Mutability + +```{code-cell} ipython3 +class Linear(nnx.Module): + def __init__(self, in_features, out_features, rngs: nnx.Rngs): + self.kernel = nnx.Param(rngs.normal((in_features, out_features))) + + def __call__(self, x): + return x @ self.kernel + +model = Linear(1, 3, rngs=nnx.Rngs(0)) + +print(f"{nnx.immutable(model) = !s}") +print(f"{nnx.mutable(model) = !s}") +``` + +```{code-cell} ipython3 +v = nnx.Variable(jnp.array(0)) +v_immut = nnx.immutable(v) +assert not v_immut.is_mutable + +try: + v_immut[...] 
+= 1 # raises an error +except Exception as e: + print(f"{type(e).__name__}: {e}") +``` + +### Ref + +```{code-cell} ipython3 +v = nnx.Variable(jnp.array(0)) +v_ref = nnx.as_ref_vars(v) +assert v_ref.has_ref +print(v_ref) +print(v_ref.get_raw_value()) +``` + +```{code-cell} ipython3 +v_immut = nnx.immutable(v_ref) +assert not v_immut.has_ref +print("immutable =", v_immut) + +v_ref = nnx.mutable(v_immut) +assert v_ref.has_ref +print("mutable =", v_ref) +``` + +### Examples + +```{code-cell} ipython3 +class Block(nnx.Module): + def __init__(self, din, dmid, dout, rngs: nnx.Rngs): + self.linear = Linear(din, dmid, rngs=rngs) + self.bn = nnx.BatchNorm(dmid, rngs=rngs) + self.dropout = nnx.Dropout(0.1, rngs=rngs) + self.linear_out = Linear(dmid, dout, rngs=rngs) + + def __call__(self, x): + x = nnx.gelu(self.dropout(self.bn(self.linear(x)))) + return self.linear_out(x) +``` + +#### Training Loop + +```{code-cell} ipython3 +# hijax Variables by default +model = Block(2, 64, 3, rngs=nnx.Rngs(0)) +optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + +@jax.jit +def train_step(model, optimizer, x, y): + graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) + def loss_fn(params): + model = nnx.merge(graphdef, params, nondiff) + return ((model(x) - y) ** 2).mean() + + loss, grads = jax.value_and_grad(loss_fn)(nnx.immutable(params)) # lojax Variables for jax.grad + optimizer.update(model, grads) + + return loss + +for _ in range(3): + loss = train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3))) + print(f"{loss = !s}") +``` + +#### Scan Over Layers + +```{code-cell} ipython3 +# TODO: does not work with hijax yet +# @jax.vmap +# def create_stack(rngs): +# return nnx.immutable(Block(2, 64, 2, rngs=rngs)) + +# block_stack = nnx.mutable(create_stack(nnx.Rngs(0).fork(split=8))) + +# def scan_fn(x, block): +# x = block(x) +# return x, None + +# x = jax.random.uniform(jax.random.key(0), (3, 2)) +# y, _ = jax.lax.scan(scan_fn, x, block_stack) + +# print("y = ", y) +``` + +### Limitations + ++++ + +#### Mutable Outputs + +```{code-cell} ipython3 +@jax.jit +def create_model(rngs): + return Block(2, 64, 3, rngs=rngs) + +try: + model = create_model(nnx.Rngs(0)) +except Exception as e: + print(f"Error:", e) +``` + +```{code-cell} ipython3 +@jax.jit +def create_model(rngs): + return nnx.immutable(Block(2, 64, 3, rngs=rngs)) + +model = nnx.mutable(create_model(nnx.Rngs(0))) + +print("model.linear =", model.linear) +``` + +#### Reference Sharing (aliasing) + +```{code-cell} ipython3 +# NOTE: doesn't currently fail on the jax side +def get_error(f, *args): + try: + return f(*args) + except Exception as e: + return f"{type(e).__name__}: {e}" + +x = nnx.Variable(jnp.array(0)) + +@jax.jit +def f(a, b): + ... + +print(get_error(f, x, x)) +``` + +```{code-cell} ipython3 +# NOTE: doesn't currently fail on the jax side +class Shared(nnx.Pytree): + def __init__(self): + self.a = nnx.Variable(jnp.array(0)) + self.b = self.a + self.c = Linear(1, 1, rngs=nnx.Rngs(0)) + self.d = self.c + +@jax.jit +def g(pytree): + ... + +shared = Shared() + +print(get_error(g, shared)) +``` + +```{code-cell} ipython3 +print("Duplicates found:") +if (all_duplicates := nnx.find_duplicates(shared)): + for duplicates in all_duplicates: + print("-", duplicates) +``` + +```{code-cell} ipython3 +@jax.jit +def h(graphdef, state): + obj = nnx.merge(graphdef, state) + obj.a[...] 
+= 10 + +graphdef, state = nnx.split(shared) +print("before:", state.a) # split deduplicates the state + +h(graphdef, state) + +print("after:", shared.a) +``` + +```{code-cell} ipython3 +# clean up for CI tests +_ = nnx.use_hijax(current_mode) +``` diff --git a/docs_nnx/index.rst b/docs_nnx/index.rst index 58490edc1..d2714a411 100644 --- a/docs_nnx/index.rst +++ b/docs_nnx/index.rst @@ -107,15 +107,11 @@ Basic usage model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) - @nnx.jit # automatic state management for JAX transforms + @nnx.jit # automatic state propagation def train_step(model, optimizer, x, y): - def loss_fn(model): - y_pred = model(x) # call methods directly - return ((y_pred - y) ** 2).mean() - + loss_fn = lambda model: ((model(x) - y) ** 2).mean() loss, grads = nnx.value_and_grad(loss_fn)(model) optimizer.update(model, grads) # in-place updates - return loss @@ -197,6 +193,7 @@ Learn more key_concepts guides_basic guides_advanced + hijax/index migrating/index examples/index nnx_glossary diff --git a/docs_nnx/nnx_basics.ipynb b/docs_nnx/nnx_basics.ipynb index ab43de2b5..727478d45 100644 --- a/docs_nnx/nnx_basics.ipynb +++ b/docs_nnx/nnx_basics.ipynb @@ -48,7 +48,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -59,7 +59,7 @@ " self.din, self.dout = din, dout\n", "\n", " def __call__(self, x: jax.Array):\n", - " return x @ self.w + self.b" + " return x @ self.w + self.b[None]" ] }, { @@ -84,31 +84,10 @@ "[[1.5643291 0.94782424 0.37971854 1.0724319 0.22112393]]\n" ] }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/cgarciae/repos/flax/.venv/lib/python3.11/site-packages/treescope/renderers.py:251: UserWarning: Ignoring error inside wrapper hook :\n", - "Traceback (most recent call last):\n", - " File \"/Users/cgarciae/repos/flax/.venv/lib/python3.11/site-packages/treescope/renderers.py\", line 225, in _render_subtree\n", - " postprocessed_result = hook(\n", - " ^^^^^\n", - " File \"/Users/cgarciae/repos/flax/.venv/lib/python3.11/site-packages/treescope/_internal/handlers/autovisualizer_hook.py\", line 47, in use_autovisualizer_if_present\n", - " result = autoviz(node, path)\n", - " ^^^^^^^^^^^^^^^^^^^\n", - " File \"/Users/cgarciae/repos/flax/.venv/lib/python3.11/site-packages/treescope/_internal/api/array_autovisualizer.py\", line 306, in __call__\n", - " jax.sharding.PositionalSharding\n", - " File \"/Users/cgarciae/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/deprecations.py\", line 54, in getattr\n", - " raise AttributeError(message)\n", - "AttributeError: jax.sharding.PositionalSharding was deprecated in JAX v0.6.0 and removed in JAX v0.7.0\n", - "\n", - " warnings.warn(\n" - ] - }, { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -120,7 +99,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -164,8 +143,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "counter.count.value = Array(0, dtype=int32, weak_type=True)\n", - "counter.count.value = Array(1, dtype=int32, weak_type=True)\n" + "counter.count[...] = Array(0, dtype=int32, weak_type=True)\n", + "counter.count[...] = Array(1, dtype=int32, weak_type=True)\n" ] } ], @@ -177,12 +156,12 @@ " self.count = Count(jnp.array(0))\n", "\n", " def __call__(self):\n", - " self.count.value += 1\n", + " self.count[...] += 1\n", "\n", "counter = Counter()\n", - "print(f'{counter.count.value = }')\n", + "print(f'{counter.count[...] = }')\n", "counter()\n", - "print(f'{counter.count.value = }')" + "print(f'{counter.count[...] = }')" ] }, { @@ -212,7 +191,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -224,7 +203,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -273,13 +252,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -291,7 +270,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -340,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -415,7 +394,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -427,7 +406,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -480,13 +459,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -498,7 +477,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -547,7 +526,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -559,7 +538,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -571,7 +550,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -672,7 +651,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -684,7 +663,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -696,7 +675,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -708,7 +687,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" diff --git a/docs_nnx/nnx_basics.md b/docs_nnx/nnx_basics.md index 149e859d2..c2aca4cae 100644 --- a/docs_nnx/nnx_basics.md +++ b/docs_nnx/nnx_basics.md @@ -40,7 +40,7 @@ class Linear(nnx.Module): self.din, self.dout = din, dout def __call__(self, x: jax.Array): - return x @ self.w + self.b + return x @ self.w + self.b[None] ``` Also note that the inner values of `Variable`s can be accessed using the `value` property, but for convenience they implement all numeric operators and can be used directly in arithmetic expressions (as shown in the code above). @@ -73,12 +73,12 @@ class Counter(nnx.Module): self.count = Count(jnp.array(0)) def __call__(self): - self.count.value += 1 + self.count[...] += 1 counter = Counter() -print(f'{counter.count.value = }') +print(f'{counter.count[...] = }') counter() -print(f'{counter.count.value = }') +print(f'{counter.count[...] = }') ``` Mutable references are usually avoided in JAX. But Flax NNX provides sound mechanisms diff --git a/examples/gemma/helpers.py b/examples/gemma/helpers.py index f74845bed..eeb02848d 100644 --- a/examples/gemma/helpers.py +++ b/examples/gemma/helpers.py @@ -74,7 +74,7 @@ def assign_val_fn( mapped_path: tuple[str | int, ...], val: Any, ) -> dict[tuple[str, ...], Any]: - state[mapped_path].value = val + state[mapped_path].set_value(val) return state mdl: M = nnx.eval_shape(module_factory) diff --git a/examples/gemma/helpers_test.py b/examples/gemma/helpers_test.py index 8d5e899f9..dd7e5fe4e 100644 --- a/examples/gemma/helpers_test.py +++ b/examples/gemma/helpers_test.py @@ -137,11 +137,11 @@ def _map_key_fn(key: tuple[str, ...]) -> tuple[str | int, ...]: np.testing.assert_array_equal(output, linen_output) for i in range(len(num_features)): np.testing.assert_array_equal( - mdl.layers[i].layers[0].mean.value, + mdl.layers[i].layers[0].mean[...], linen_vars['batch_stats'][f'layers_{i}']['layers_0']['mean'], ) np.testing.assert_array_equal( - mdl.layers[i].layers[0].var.value, + mdl.layers[i].layers[0].var[...], linen_vars['batch_stats'][f'layers_{i}']['layers_0']['var'], ) diff --git a/examples/gemma/layers.py b/examples/gemma/layers.py index f764c61a0..5fb959ada 100644 --- a/examples/gemma/layers.py +++ b/examples/gemma/layers.py @@ -44,11 +44,11 @@ def __init__( self.w = nnx.Param(kernel_init(rngs.params(), shape, dtype)) def __call__(self, x: ArrayLike) -> Array: - return jnp.einsum(self.einsum_str, x, self.w.value) + return jnp.einsum(self.einsum_str, x, self.w[...]) @property def shape(self) -> Shape: - return self.w.value.shape + return self.w.shape class RMSNorm(nnx.Module): @@ -65,12 +65,12 @@ def __init__( self.scale = nnx.Param(scale_init(rngs.params(), dim, dtype)) def __call__(self, x: Array) -> Array: - dtype = self.scale.value.dtype + dtype = self.scale.dtype var = jnp.mean(jnp.square(x), axis=-1, keepdims=True) normed_inputs = jnp.asarray(x * jax.lax.rsqrt(var + 1e-06), dtype=dtype) # normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is # a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to # a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs. 
- scale = jnp.expand_dims(self.scale.value, axis=range(len(x.shape) - 1)) + scale = jnp.expand_dims(self.scale, axis=range(len(x.shape) - 1)) normed_inputs = normed_inputs * (1 + scale) return normed_inputs diff --git a/examples/gemma/modules.py b/examples/gemma/modules.py index 56c426fbd..48c9a018a 100644 --- a/examples/gemma/modules.py +++ b/examples/gemma/modules.py @@ -63,15 +63,15 @@ def encode(self, x: ArrayLike) -> Array: return x def decode(self, x: ArrayLike) -> Array: - return jnp.dot(x, self.input_embedding.value.T) + return jnp.dot(x, self.input_embedding.T) @property def embed_dim(self): - return self.input_embedding.value.shape[1] + return self.input_embedding.shape[1] @property def num_embed(self): - return self.input_embedding.value.shape[0] + return self.input_embedding.shape[0] class Attention(nnx.Module): diff --git a/examples/gemma/sampler_test.py b/examples/gemma/sampler_test.py index 307b0e43e..8d2ed5a83 100644 --- a/examples/gemma/sampler_test.py +++ b/examples/gemma/sampler_test.py @@ -232,9 +232,9 @@ def test_forbidden_tokens(self): transformer_config, rngs=nnx.Rngs(params=0) ) # Pre-cook the embedding matrix so that the output is deterministic. - transformer.embedder.input_embedding.value = jnp.eye( + transformer.embedder.input_embedding.set_value(jnp.eye( vocab.GetPieceSize(), 32 - ) + )) sampler = sampler_lib.Sampler( transformer=transformer, vocab=vocab, diff --git a/examples/gemma/sow_lib.py b/examples/gemma/sow_lib.py index 6bc808501..7580cdfe2 100644 --- a/examples/gemma/sow_lib.py +++ b/examples/gemma/sow_lib.py @@ -49,13 +49,11 @@ def merge(self, decoding_step, layer: nnx.Module): if field.name.startswith('attn_'): step_value = getattr( layer.attn, field.name.replace('attn_', '') - ).value[0] + )[0] elif field.name.startswith('mlp_'): - step_value = getattr(layer.mlp, field.name.replace('mlp_', '')).value[ - 0 - ] + step_value = getattr(layer.mlp, field.name.replace('mlp_', ''))[0] else: - step_value = getattr(layer, field.name).value[0] + step_value = getattr(layer, field.name)[0] except AttributeError as exc: raise ValueError( f'Intermediate {field.name} is not in the step intermediates.' @@ -93,7 +91,7 @@ def merge(self, decoding_step, transformer: nnx.Module): if self.embeddings is not None: try: self.embeddings = self.embeddings.at[:, decoding_step + 1, ...].set( - transformer.embeddings.value[0][:, 0, ...] + transformer.embeddings[0][:, 0, ...] 
) except AttributeError as exc: raise ValueError( diff --git a/examples/gemma/transformer.py b/examples/gemma/transformer.py index 842121e96..54fb6748e 100644 --- a/examples/gemma/transformer.py +++ b/examples/gemma/transformer.py @@ -487,10 +487,10 @@ def _assign_linen_params_to_nnx_state( if 'gate_proj' in mapped_path: if transpose_gating_einsum: val = jnp.swapaxes(val, 1, 2) - state[mapped_path].value = val[0] - state[mapped_path[:-2] + ('up_proj', 'kernel')].value = val[1] + state[mapped_path].set_value(val[0]) + state[mapped_path[:-2] + ('up_proj', 'kernel')].set_value(val[1]) else: - state[mapped_path].value = val + state[mapped_path].set_value(val) return state diff --git a/examples/gemma/transformer_test.py b/examples/gemma/transformer_test.py index 3d30f9277..97916604b 100644 --- a/examples/gemma/transformer_test.py +++ b/examples/gemma/transformer_test.py @@ -461,7 +461,7 @@ def test_sow_intermediates(self, sow_config): if sow_config.embeddings: self.assertTrue(hasattr(transformer, 'embeddings')) - embeddings = transformer.embeddings.value[0] + embeddings = transformer.embeddings[0] self.assertEqual( embeddings.shape, (batch_size, sequence_length, config.embed_dim), @@ -472,7 +472,7 @@ def test_sow_intermediates(self, sow_config): for layer in transformer.layers: if sow_config.rs_after_attention: self.assertTrue(hasattr(layer, 'rs_after_attention')) - rs_after_attention = layer.rs_after_attention.value[0] + rs_after_attention = layer.rs_after_attention[0] self.assertIsNotNone(rs_after_attention) self.assertEqual( rs_after_attention.shape, @@ -482,7 +482,7 @@ def test_sow_intermediates(self, sow_config): self.assertFalse(hasattr(layer, 'rs_after_attention')) if sow_config.rs_after_ffw: self.assertTrue(hasattr(layer, 'rs_after_ffw')) - rs_after_ffw = layer.rs_after_ffw.value[0] + rs_after_ffw = layer.rs_after_ffw[0] self.assertIsNotNone(rs_after_ffw) self.assertEqual( rs_after_ffw.shape, @@ -492,7 +492,7 @@ def test_sow_intermediates(self, sow_config): self.assertFalse(hasattr(layer, 'rs_after_ffw')) if sow_config.attn_logits_topk: self.assertTrue(hasattr(layer.attn, 'logits_topk_values')) - attn_logits_topk_values = layer.attn.logits_topk_values.value[0] + attn_logits_topk_values = layer.attn.logits_topk_values[0] self.assertIsNotNone(attn_logits_topk_values) self.assertEqual( attn_logits_topk_values.shape, @@ -504,7 +504,7 @@ def test_sow_intermediates(self, sow_config): ), ) self.assertTrue(hasattr(layer.attn, 'logits_topk_indices')) - attn_logits_topk_indices = layer.attn.logits_topk_indices.value[0] + attn_logits_topk_indices = layer.attn.logits_topk_indices[0] self.assertIsNotNone(attn_logits_topk_indices) self.assertEqual( attn_logits_topk_indices.shape, @@ -520,7 +520,7 @@ def test_sow_intermediates(self, sow_config): self.assertFalse(hasattr(layer.attn, 'logits_topk_indices')) if sow_config.mlp_hidden_topk: self.assertTrue(hasattr(layer.mlp, 'hidden_topk_values')) - ffw_hidden_topk_values = layer.mlp.hidden_topk_values.value[0] + ffw_hidden_topk_values = layer.mlp.hidden_topk_values[0] self.assertIsNotNone(ffw_hidden_topk_values) self.assertEqual( ffw_hidden_topk_values.shape, @@ -531,7 +531,7 @@ def test_sow_intermediates(self, sow_config): ), ) self.assertTrue(hasattr(layer.mlp, 'hidden_topk_indices')) - ffw_hidden_topk_indices = layer.mlp.hidden_topk_indices.value[0] + ffw_hidden_topk_indices = layer.mlp.hidden_topk_indices[0] self.assertIsNotNone(ffw_hidden_topk_indices) self.assertEqual( ffw_hidden_topk_indices.shape, diff --git 
a/examples/nnx_toy_examples/mutable_array_basic.py b/examples/nnx_toy_examples/hijax_basic.py similarity index 81% rename from examples/nnx_toy_examples/mutable_array_basic.py rename to examples/nnx_toy_examples/hijax_basic.py index 7386163c1..e267976c3 100644 --- a/examples/nnx_toy_examples/mutable_array_basic.py +++ b/examples/nnx_toy_examples/hijax_basic.py @@ -33,11 +33,11 @@ def dataset(batch_size): class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): - self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) + self.w = nnx.Param(rngs.params.uniform((din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): - return x @ self.w[...] + self.b[None] + return x @ self.w + self.b[None] class Count(nnx.Variable[nnx.A]): @@ -54,23 +54,23 @@ def __call__(self, x): self.count[...] += 1 return self.linear2(jax.nn.relu(self.linear1(x)) * 0.5) -with nnx.use_refs(True): - model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0)) - optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.1), wrt=nnx.Param) +nnx.use_hijax(True) + +model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0)) +optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.1), wrt=nnx.Param) @jax.jit def train_step(model, optimizer, x, y): - graphdef, params, counts = nnx.split(model, nnx.Param, Count) + graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) def loss_fn(params): - model = nnx.merge(graphdef, params, counts) + model = nnx.merge(graphdef, params, nondiff) return jnp.mean((y - model(x)) ** 2) - grads = jax.grad(loss_fn)(nnx.to_arrays(params)) + grads = jax.grad(loss_fn)(nnx.immutable(params)) optimizer.update(model, grads) - @jax.jit def test_step(model: MLP, x, y): return {'loss': jnp.mean((y - model(x)) ** 2)} @@ -87,7 +87,7 @@ def test_step(model: MLP, x, y): if step >= total_steps - 1: break -print('times called:', model.count.value) +print('times called:', model.count[...]) y_pred = model(X) diff --git a/examples/nnx_toy_examples/mutable_array_demo.py b/examples/nnx_toy_examples/hijax_demo.py similarity index 81% rename from examples/nnx_toy_examples/mutable_array_demo.py rename to examples/nnx_toy_examples/hijax_demo.py index 6d9619444..5b14be4f0 100644 --- a/examples/nnx_toy_examples/mutable_array_demo.py +++ b/examples/nnx_toy_examples/hijax_demo.py @@ -49,9 +49,8 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.w = nnx.Param(initializer(rngs.params(), (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) - # [...] is used to access the array def __call__(self, x: jax.Array): - return x @ self.w[...] + self.b[None] + return x @ self.w + self.b[None] # Block implements linear, batch norm, and dropout. Its behavior @@ -88,21 +87,21 @@ def __call__( self, x: jax.Array, *, rngs: nnx.Rngs | None = None ) -> jax.Array: # ----------- linear -------------------- - x = x @ self.w[...] + self.b[None] + x = x @ self.w + self.b[None] # ----------- batch norm ---------------- if self.use_stats: - mean = self.mean[...] - var = self.var[...] + mean = self.mean + var = self.var else: mean = jnp.mean(x, axis=0) var = jnp.var(x, axis=0) # ema updates - # stop gradient is used until a ArrayRef supports updates from grad tracers + # stop gradient is used until a Hijax supports updates from grad tracers sg = jax.lax.stop_gradient - self.mean[...] = sg(self.mu * self.mean[...] + (1 - self.mu) * mean) - self.var[...] = sg(self.mu * self.var[...] + (1 - self.mu) * var) + self.mean[...] 
= sg(self.mu * self.mean + (1 - self.mu) * mean) + self.var[...] = sg(self.mu * self.var + (1 - self.mu) * var) x = (x - mean[None]) / jnp.sqrt(var[None] + 1e-5) - x = x * self.scale[...] + self.bias[...] + x = x * self.scale + self.bias # ----------- dropout ------------------- if not self.deterministic and self.dropout_rate > 0.0: assert rngs is not None @@ -125,7 +124,7 @@ def __init__( use_scan: bool = True, rngs: nnx.Rngs, ): - self.count: jax.Ref = jax.new_ref(jnp.array(0)) + self.count = nnx.Variable(jnp.array(0)) self.block_in = Block(din, dhidden, rngs=rngs) self.linear_out = Linear(dhidden, dout, rngs=rngs) @@ -136,11 +135,15 @@ def __init__( @jax.vmap def create_block(rngs, /): - return nnx.to_arrays(Block(dhidden, dhidden, rngs=rngs)) + # return nnx.stateless(Block(dhidden, dhidden, rngs=rngs)) + return Block(dhidden, dhidden, rngs=rngs) - self.blocks = nnx.to_refs(create_block(rngs.fork(split=num_blocks))) + # self.blocks = nnx.stateful(create_block(rngs.fork(split=num_blocks))) + self.blocks = create_block(rngs.fork(split=num_blocks)) else: - self.blocks = nnx.List([Block(dhidden, dhidden, rngs=rngs) for i in range(num_blocks)]) + self.blocks = nnx.List( + [Block(dhidden, dhidden, rngs=rngs) for i in range(num_blocks)] + ) def __call__(self, x: jax.Array, *, rngs: nnx.Rngs | None = None): self.count[...] += 1 @@ -169,7 +172,7 @@ class OptState(nnx.Variable): ... # Optimizer are an interesting case as they are inherently stateful and -# pose a good use case for ArrayRef. Here we implement SGD with +# pose a good use case for MutableHijax. Here we implement SGD with # momentum. The optimizer receives the params as constructor arguments but doesn't # hold a reference to them, it only uses the params to initialize its state # by creating new OptState Variables that reuse the param's metadata. @@ -180,40 +183,36 @@ def __init__(self, params, lr: float, decay: float = 0.9): def make_opt_state(x): if isinstance(x, nnx.Variable): - return OptState(jnp.zeros_like(x.value), **x.get_metadata()) + return OptState(jnp.zeros_like(x[...]), **x.get_metadata()) else: return OptState(jnp.zeros_like(x)) - self.momentum = nnx.data(jax.tree.map( - make_opt_state, - params, - is_leaf=lambda x: isinstance(x, nnx.Variable), - )) + self.momentum = nnx.data(jax.tree.map(make_opt_state, params)) # during the update we simply map over (params, momentum, grads), # for each triplet we implement the SGD update rule which updates # both the optimizer's state (momentum) and the params in place. def update(self, params, grads): - params = nnx.pure(params) - grads = nnx.pure(grads) - momentum = nnx.pure(self.momentum) - def update_fn( - param: jax.Ref, momentum: jax.Ref, grad: jax.Array + param: nnx.Variable[jax.Array], + momentum: nnx.Variable[jax.Array], + grad: nnx.Variable[jax.Array], ): - momentum[...] = self.decay * momentum[...] + (1 - self.decay) * grad[...] - param[...] -= self.lr * momentum[...] + momentum[...] = self.decay * momentum + (1 - self.decay) * grad + param[...] 
-= self.lr * momentum + + # is_leaf might not be necesarry as MutableHijaxVariable are not pytreees + jax.tree.map(update_fn, params, self.momentum, grads) - jax.tree.map(update_fn, params, momentum, grads) # ## Training +nnx.use_hijax(True) -with nnx.use_refs(True): - rngs = nnx.Rngs(params=0, dropout=1) - model = Model( - num_blocks=3, din=1, dhidden=256, dout=1, use_scan=False, rngs=rngs - ) - optimizer = SGD(params=nnx.state(model, nnx.Param), lr=3e-3, decay=0.99) +rngs = nnx.Rngs(params=0, dropout=1) +model = Model( + num_blocks=3, din=1, dhidden=256, dout=1, use_scan=False, rngs=rngs +) +optimizer = SGD(params=nnx.state(model, nnx.Param), lr=3e-3, decay=0.99) # Create a copy of the model structure and set its attributes to eval model. # This works because they share the underlying ArrayRefs so both models @@ -237,13 +236,14 @@ def loss_fn(params): loss = jnp.mean((model(x, rngs=rngs) - y) ** 2) return loss - # For the time being we have to use 'freeze' make the Variables immutable - # as 'jax.grad' doesn't support ArrayRefs yet. - grads = jax.grad(loss_fn)(nnx.to_arrays(params)) + # For the time being we have to use 'immutable' + # as 'jax.grad' doesn't support QDD types yet. + grads = jax.grad(loss_fn)(nnx.immutable(params)) # 'update' mutates the optimizer's state and the params in place # so we don't need to return anything 🚀 optimizer.update(params, grads) + # simple test step that computes the loss @jax.jit def test_step(model: Model, x, y): diff --git a/flax/configurations.py b/flax/configurations.py index 9240b697f..131a850bb 100644 --- a/flax/configurations.py +++ b/flax/configurations.py @@ -201,6 +201,38 @@ def static_bool_env(varname: str, default: bool) -> bool: ) +def str_flag(name: str, *, default: str, help: str) -> FlagHolder[str]: + """Set up a string flag. + + Example:: + + some_string = str_flag( + name='flax_some_string', + default='default_value', + help='Some string configuration.', + ) + + Now the ``FLAX_SOME_STRING`` shell environment variable can be used to + control the process-level value of the flag, in addition to using e.g. + ``config.update("flax_some_string", "new_value")`` directly. + + Args: + name: converted to lowercase to define the name of the flag. It is + converted to uppercase to define the corresponding shell environment + variable. + default: a default value for the flag. + help: used to populate the docstring of the returned flag holder object. + + Returns: + A flag holder object for accessing the value of the flag. + """ + name = name.lower() + config._add_option(name, static_str_env(name.upper(), default)) + fh = FlagHolder[str](name, help) + setattr(Config, name, property(lambda _: fh.value, doc=help)) + return fh + + def static_int_env(varname: str, default: int | None) -> int | None: """Read an environment variable and interpret it as an integer. @@ -222,6 +254,18 @@ def static_int_env(varname: str, default: int | None) -> int | None: ) from None +def static_str_env(varname: str, default: str) -> str: + """Read an environment variable and interpret it as a string. + + Args: + varname: the name of the variable + default: the default string value + Returns: + string return value derived from defaults and environment. 
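+
+  Example::
+
+    # 'FLAX_SOME_STRING' is the same illustrative variable name used in the
+    # ``str_flag`` docstring above; any uppercase name works here.
+    os.environ['FLAX_SOME_STRING'] = 'hijax'
+    assert static_str_env('FLAX_SOME_STRING', 'pytree') == 'hijax'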
+ """ + return os.getenv(varname, default) + + # Flax Global Configuration Variables: flax_filter_frames = bool_flag( @@ -294,5 +338,5 @@ def static_int_env(varname: str, default: int | None) -> int | None: flax_hijax_variable = bool_flag( name='flax_hijax_variable', default=False, - help='Whether to enable HiJAX support for `nnx.Variable`.', + help='Whether to use hijax for `nnx.Variable`. Options are "pytree", "hijax", and "ref".', ) \ No newline at end of file diff --git a/flax/errors.py b/flax/errors.py index e34be0cd3..6e9d2b061 100644 --- a/flax/errors.py +++ b/flax/errors.py @@ -351,6 +351,22 @@ def __init__(self, col, variable_name, scope_path): ) +class ImmutableVariableError(FlaxError): + """You cannot update a variable that is marked as immutable. + + This error occurs when attempting to modify a Variable that has been set to + 'immutable' mode. Variables in immutable mode are read-only and cannot be + changed after creation. + + To fix this error, either: + 1. Use a different variable mode (e.g., 'qdd' or 'pytree') + 2. Or ensure you're not trying to modify the variable's value + """ + + def __init__(self, message): + super().__init__(message) + + class JaxTransformError(FlaxError): """JAX transforms and Flax modules cannot be mixed. diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index f162dc3ef..a203b0576 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -68,8 +68,12 @@ from .graph import MergeContext as MergeContext from .graph import merge_context as merge_context from .graph import variables as variables -from .graph import to_arrays as to_arrays -from .graph import to_refs as to_refs +from .graph import as_ref_vars as as_ref_vars +from .graph import as_array_vars as as_array_vars +from .graph import as_hijax_vars as as_hijax_vars +from .graph import as_pytree_vars as as_pytree_vars +from .graph import immutable as immutable +from .graph import mutable as mutable from .graph import pure as pure from .graph import cached_partial as cached_partial from .graph import flatten as flatten @@ -197,8 +201,8 @@ from .variablelib import variable_type_from_name as variable_type_from_name from .variablelib import variable_name_from_type as variable_name_from_type from .variablelib import register_variable_name as register_variable_name -from .variablelib import use_refs as use_refs -from .variablelib import using_refs as using_refs +from .variablelib import use_hijax as use_hijax +from .variablelib import using_hijax as using_hijax from .visualization import display as display from .extract import to_tree as to_tree from .extract import from_tree as from_tree diff --git a/flax/nnx/bridge/module.py b/flax/nnx/bridge/module.py index ed30d2895..bf913e386 100644 --- a/flax/nnx/bridge/module.py +++ b/flax/nnx/bridge/module.py @@ -391,7 +391,7 @@ def _get_variables(self) -> tp.Mapping: if isinstance( variable, variablelib.Variable ) and bridge_variables.is_vanilla_variable(variable): - leaf = variable.value + leaf = variable.get_value() else: leaf = bridge_variables.to_linen_var(variable) diff --git a/flax/nnx/bridge/variables.py b/flax/nnx/bridge/variables.py index e354f2e70..dd7624e86 100644 --- a/flax/nnx/bridge/variables.py +++ b/flax/nnx/bridge/variables.py @@ -64,7 +64,7 @@ def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]': def get_partition_spec(self) -> jax.sharding.PartitionSpec: """Returns the ``Partitionspec`` for this partitioned value.""" nnx_var = self.to_nnx_variable() - spec = spmd.get_partition_spec(nnx_var).raw_value + spec = 
spmd.get_partition_spec(nnx_var).get_raw_value() assert isinstance(spec, jax.sharding.PartitionSpec) return spec @@ -78,11 +78,11 @@ def is_vanilla_variable(vs: variablelib.Variable) -> bool: Returns False only if it has non-empty hooks or any non-built-in attribute. """ for key, value in vs.get_metadata().items(): - if key.endswith('_hooks'): - if value != (): - return False - else: - return False + if key in ('is_hijax', 'has_ref', 'is_mutable', 'eager_sharding'): + continue + if key.endswith('_hooks') and value == (): + continue + return False return True @@ -91,11 +91,11 @@ def to_linen_var(vs: variablelib.Variable) -> meta.AxisMetadata: if 'linen_meta_type' in metadata: linen_type = metadata['linen_meta_type'] if hasattr(linen_type, 'from_nnx_metadata'): - return linen_type.from_nnx_metadata({'value': vs.value, **metadata}) - return linen_type(vs.value, **metadata) + return linen_type.from_nnx_metadata({'value': vs.get_value(), **metadata}) + return linen_type(vs.get_value(), **metadata) if is_vanilla_variable(vs): - return vs.value - return NNXMeta(type(vs), vs.value, metadata) + return vs.get_value() + return NNXMeta(type(vs), vs.get_value(), metadata) def get_col_name(keypath: tp.Sequence[Any]) -> str: diff --git a/flax/nnx/bridge/wrappers.py b/flax/nnx/bridge/wrappers.py index 6ab594e81..50a1041e5 100644 --- a/flax/nnx/bridge/wrappers.py +++ b/flax/nnx/bridge/wrappers.py @@ -413,7 +413,7 @@ def _to_linen_var(x): if self.metadata_fn is not None: return self.metadata_fn(x) # pylint: disable=too-many-function-args else: - return x.value + return x.get_value() return x collection_state = nnx.traversals.unflatten_mapping(flat_state) diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index fd8db8124..061992227 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -63,7 +63,7 @@ def check_consistent_aliasing( lambda: f'Trying to extract graph node from different trace level, got {value!r}' ) if isinstance(value, graph.Variable): - if not value._trace_state.is_valid(): + if not value._can_update: raise ValueError( f'Cannot extract graph node from different trace level, got {value!r}' ) diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 1d6466003..1b1b27622 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -20,7 +20,7 @@ import threading import typing as tp -import jax.experimental +import jax.core from flax import config from flax.nnx import filterlib, reprlib, traversals, variablelib @@ -261,7 +261,11 @@ def is_node(x: tp.Any) -> bool: def is_graph_node(x: tp.Any) -> bool: - return type(x) in GRAPH_REGISTRY or variablelib.is_array_ref(x) or isinstance(x, variablelib.Variable) + return ( + type(x) in GRAPH_REGISTRY + or variablelib.is_array_ref(x) + or isinstance(x, Variable) + ) def is_node_type(x: type[tp.Any]) -> bool: @@ -761,7 +765,7 @@ def make_mutable_arraydef(value: variablelib.Ref): if is_variable: assert isinstance(node, Variable) assert index is not None - prev_inner_value = node.raw_value + prev_inner_value = node.get_raw_value() if variablelib.is_array_ref(prev_inner_value): array_refdef, inner_value = make_mutable_arraydef(prev_inner_value) else: @@ -772,13 +776,13 @@ def make_mutable_arraydef(value: variablelib.Ref): else: leaf = node # type: ignore[assignment] if inner_value is not prev_inner_value: - leaf.raw_value = inner_value + leaf.set_raw_value(inner_value) variabledef = VariableDef( - type=type(node), + type=node.var_type, # type: ignore index=index, outer_index=ref_outer_index.get(node, None) if ref_outer_index else None, - 
metadata=HashableMapping(node._var_metadata), + metadata=HashableMapping(node.get_metadata()), array_refdef=array_refdef, ) if type(inner_value) is not Repeated: @@ -944,7 +948,7 @@ def _graph_fingerprint( variable_index = new_ref_index[value] = ctx.next_index ctx.next_index += 1 append_fn(variable_index) - for key_value in value._var_metadata.items(): + for key_value in value.get_metadata().items(): append_fn(key_value) elif not isinstance(value, (jax.Array, np.ndarray)): append_fn(value) @@ -1048,7 +1052,7 @@ def _check_graph_fingerprint( # append_fn(variable_index) if variable_index != next(fp_iterator): return False - for key_value in value._var_metadata.items(): + for key_value in value.get_metadata().items(): # append_fn(key_value) if key_value != next(fp_iterator): return False @@ -1197,10 +1201,10 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf): raise RuntimeError(f'Expected a no update for ArrayRef but got {leaf}.') elif type(leaf) in (NoUpdate, Repeated): raise ValueError( - f"Expected a ArrayRefOutput type but got '{leaf.value}.'" + f"Expected a ArrayRefOutput type but got '{leaf}.'" ) elif type(leaf) is ArrayRefOutput: - array_ref = variablelib.new_ref(leaf.value) + array_ref = jax.new_ref(leaf.value) elif variablelib.is_array_ref(leaf): array_ref = leaf else: @@ -1224,11 +1228,14 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf): value = next(leaves_iter) assert type(variabledef.array_refdef) is ArrayRefDef if isinstance(value, Variable): - value = value.copy() if copy_variables else value - inner_value = value.raw_value + copy_ref = not isinstance( + value.get_raw_value(), (NoUpdate, Repeated, ArrayRefOutput) + ) + value = value.copy(_copy_ref=copy_ref) if copy_variables else value + inner_value = value.get_raw_value() array_ref = get_mutable_array(variabledef.array_refdef, inner_value) if array_ref is not inner_value: - value.raw_value = array_ref + value.set_raw_value(array_ref) else: # if value is an array or array ref, we need call get_mutable_array # to register it in the index_ref @@ -1236,7 +1243,10 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf): else: value = next(leaves_iter) if isinstance(value, Variable) and copy_variables: - value = value.copy() + copy_ref = not isinstance( + value.get_raw_value(), (NoUpdate, Repeated, ArrayRefOutput) + ) + value = value.copy(_copy_ref=copy_ref) # when idxmap is present, check if the Varable exists there # and update existing variables if it does @@ -1252,7 +1262,7 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf): elif isinstance(value, Variable): variable.update_from_state(value) else: - variable.raw_value = value + variable.set_raw_value(value) else: # variabledef.index not in index_ref_cache # variable reference does not exist outside, create a new one if isinstance(value, Variable): @@ -1437,12 +1447,12 @@ def _update_variable(node: Variable, value): # can happen when using standalone Variables with `grad` pass else: - if is_array_ref(node.raw_value) and ( + if is_array_ref(node.get_raw_value()) and ( isinstance(value, jax.Array) or is_array_ref(value) ): node[...] = value[...] 
else: - node.raw_value = value + node.set_raw_value(value, _unsafe_bypass_check=True) if isinstance(node, Variable): _update_variable(node, state) @@ -1464,7 +1474,10 @@ def _update_variable(node: Variable, value): f'type {type(node).__name__}' ) if isinstance(value, Variable): - value = value.copy() + copy_ref = not isinstance( + value.get_raw_value(), (NoUpdate, Repeated, ArrayRefOutput) + ) + value = value.copy(_copy_ref=copy_ref) node_impl.set_key(node, key, value) continue @@ -1780,7 +1793,7 @@ def flatten( # type: ignore[invalid-annotation] else: paths = None leaves = [ - variable.raw_value for variable in node_static_cache.variables + variable.get_raw_value() for variable in node_static_cache.variables ] else: graphdef, flat_state = flatten( @@ -1916,7 +1929,7 @@ def unflatten( # type: ignore[invalid-annotation] if isinstance(leaf, Variable): variable.update_from_state(leaf) else: - variable.raw_value = leaf + variable.set_raw_value(leaf) self.index_ref.update(static_cache_node.new_index_ref) else: # uncached node, create it @@ -2550,7 +2563,7 @@ def pop( >>> assert hasattr(model, 'i') >>> intermediates = nnx.pop(model, nnx.Intermediate) - >>> assert intermediates['i'].value[0].shape == (1, 3) + >>> assert intermediates['i'][0].shape == (1, 3) >>> assert not hasattr(model, 'i') Args: @@ -2609,18 +2622,17 @@ def clone(node: Node, variables: bool = True) -> Node: return merge(graphdef, state, copy=variables) -def _mutable_like(path, x): - return (isinstance(x, Variable) and x.has_ref) or variablelib.is_array_ref(x) - - -def to_arrays( +def vars_as( node: A, /, *, - only: filterlib.Filter = _mutable_like, + is_hijax: bool | None = None, + has_ref: bool | None = None, + is_mutable: bool | None = None, + only: filterlib.Filter = ..., allow_duplicates: bool = False, ) -> A: - """Converts a structure of array refs to regular arrays. + """Converts a structure of Variable to 'ref' mode. Example:: @@ -2628,18 +2640,17 @@ def to_arrays( >>> import jax >>> import jax.numpy as jnp ... - >>> node = [jax.new_ref(jnp.array(1.0)), jnp.array(2.0)] - >>> assert isinstance(node[0], jax.Ref) - ... - >>> frozen_node = nnx.to_arrays(node) - >>> assert isinstance(frozen_node[0], jax.Array) + >>> node = [nnx.Variable(jnp.array(1.0)), nnx.Variable(jnp.array(2.0))] + >>> node = nnx.as_ref_vars(node) + >>> assert node[0].mode == 'ref' + >>> assert node[1].mode == 'ref' - If the structure contains duplicate array refs, a ValueError is raised:: + If the structure contains duplicate arrays a ValueError is raised:: - >>> shared_array = jax.new_ref(jnp.array(1.0)) - >>> node = [shared_array, shared_array] + >>> shared = nnx.Variable(jnp.array(1.0)) + >>> node = [shared, shared] >>> try: - ... nnx.to_arrays(node) + ... nnx.as_ref_vars(node) ... except ValueError as e: ... print(e) Found duplicate at paths: @@ -2649,20 +2660,36 @@ def to_arrays( --- ``only`` is a `Filter `__ - that can be used to specify which array refs to freeze:: + that can be used to specify which arrays to convert to array refs. - >>> node = [jax.new_ref(jnp.array(1.0)), jax.new_ref(jnp.array(2.0))] - >>> frozen_node = nnx.to_arrays(node, only=lambda path, x: path[0] == 0) + >>> node = [nnx.Variable(jnp.array(1.0)), nnx.Variable(jnp.array(2.0))] + >>> mutable_node = nnx.as_ref_vars(node, only=lambda path, x: path[0] == 0) ... 
- >>> assert isinstance(frozen_node[0], jax.Array) - >>> assert isinstance(frozen_node[1], jax.Ref) + >>> assert mutable_node[0].mode == 'ref' + >>> assert mutable_node[1].mode == 'pytree' Args: - node: A structure potentially containing array refs. - only: A Filter to specify which array refs to freeze. + node: A structure potentially containing arrays. + only: A Filter to specify which arrays to convert to array refs. Returns: - A structure with the frozen arrays. + A structure with the array refs. """ + new_attrs: dict[str, bool] = {} + if is_hijax is not None: + new_attrs['is_hijax'] = is_hijax + if has_ref is not None: + new_attrs['has_ref'] = has_ref + if is_mutable is not None: + new_attrs['is_mutable'] = is_mutable + + def _different_vars(path, x): + return isinstance(x, Variable) and any( + getattr(x, attr) != value for attr, value in new_attrs.items() + ) + + only = filterlib.All(_different_vars, only) + predicate = filterlib.to_predicate(only) + if not allow_duplicates and ( all_duplicates := find_duplicates(node, only=only) ): @@ -2674,18 +2701,23 @@ def to_arrays( duplicates_strs += '\n ---' raise ValueError(f'Found duplicate at paths:{duplicates_strs}') - graphdef, mutable_state, rest = split(node, only, ...) # type: ignore[misc] - frozen_state = jax.tree.map(lambda x: x[...], mutable_state) - node = merge(graphdef, frozen_state, rest) - return node - + def _to_refs(jax_path, x): + if predicate(jax_to_nnx_path(jax_path), x): + assert isinstance(x, Variable) + variable = x.copy(**new_attrs) + return variable + return x -def _array_like(path, x): - return (isinstance(x, Variable) and not x.has_ref) or isinstance(x, jax.Array) + node = jax.tree.map_with_path( + _to_refs, node, is_leaf=lambda x: isinstance(x, Variable) + ) + return node -def to_refs(node: A, /, only: filterlib.Filter = _array_like) -> A: - """Converts a structure of arrays to array refs. +def as_ref_vars( + node: A, /, *, only: filterlib.Filter = ..., allow_duplicates: bool = False +) -> A: + """Converts a structure of Variable to 'ref' mode. Example:: @@ -2693,17 +2725,17 @@ def to_refs(node: A, /, only: filterlib.Filter = _array_like) -> A: >>> import jax >>> import jax.numpy as jnp ... - >>> node = [jnp.array(1.0), jax.new_ref(jnp.array(2.0))] - >>> mutable_node = nnx.to_refs(node) - >>> assert isinstance(mutable_node[0], jax.Ref) - >>> assert isinstance(mutable_node[1], jax.Ref) + >>> node = [nnx.Variable(jnp.array(1.0)), nnx.Variable(jnp.array(2.0))] + >>> node = nnx.as_ref_vars(node) + >>> assert node[0].mode == 'ref' + >>> assert node[1].mode == 'ref' If the structure contains duplicate arrays a ValueError is raised:: - >>> shared_array = jnp.array(1.0) - >>> node = [shared_array, shared_array] + >>> shared = nnx.Variable(jnp.array(1.0)) + >>> node = [shared, shared] >>> try: - ... nnx.to_refs(node) + ... nnx.as_ref_vars(node) ... except ValueError as e: ... print(e) Found duplicate at paths: @@ -2715,11 +2747,11 @@ def to_refs(node: A, /, only: filterlib.Filter = _array_like) -> A: ``only`` is a `Filter `__ that can be used to specify which arrays to convert to array refs. - >>> node = [jnp.array(1.0), jnp.array(2.0)] - >>> mutable_node = nnx.to_refs(node, only=lambda path, x: path[0] == 0) + >>> node = [nnx.Variable(jnp.array(1.0)), nnx.Variable(jnp.array(2.0))] + >>> mutable_node = nnx.as_ref_vars(node, only=lambda path, x: path[0] == 0) ... 
- >>> assert isinstance(mutable_node[0], jax.Ref) - >>> assert isinstance(mutable_node[1], jax.Array) + >>> assert mutable_node[0].mode == 'ref' + >>> assert mutable_node[1].mode == 'pytree' Args: node: A structure potentially containing arrays. @@ -2727,19 +2759,59 @@ def to_refs(node: A, /, only: filterlib.Filter = _array_like) -> A: Returns: A structure with the array refs. """ - if all_duplicates := find_duplicates(node, only=only): - duplicates_strs = '\n ---' - for node_duplicates in all_duplicates: - for path in node_duplicates: - path_str = '/'.join(map(str, path)) - duplicates_strs += f'\n {path_str}' - duplicates_strs += '\n ---' - raise ValueError(f'Found duplicate at paths:{duplicates_strs}') + return vars_as( + node, has_ref=True, only=only, allow_duplicates=allow_duplicates + ) - graphdef, frozen_state, rest = split(node, only, ...) # type: ignore[misc] - mutable_state = jax.tree.map(variablelib.new_ref, frozen_state) - node = merge(graphdef, mutable_state, rest) - return node + +def as_array_vars( + node: A, /, *, only: filterlib.Filter = ..., allow_duplicates: bool = False +) -> A: + """ """ + return vars_as( + node, has_ref=False, only=only, allow_duplicates=allow_duplicates + ) + + +def as_hijax_vars( + node: A, /, *, only: filterlib.Filter = ..., mutable: bool = True +) -> A: + """ """ + return vars_as(node, is_hijax=True, is_mutable=mutable, only=only) + + +def as_pytree_vars( + node: A, /, *, allow_duplicates: bool = False, only: filterlib.Filter = ... +) -> A: + """ """ + return vars_as( + node, is_hijax=False, allow_duplicates=allow_duplicates, only=only + ) + + +def immutable( + node: A, + /, + *, + allow_duplicates: bool = False, + only: filterlib.Filter = ..., +) -> A: + """ """ + return vars_as( + node, is_mutable=False, allow_duplicates=allow_duplicates, only=only + ) + +def mutable( + node: A, + /, + *, + allow_duplicates: bool = False, + only: filterlib.Filter = ..., +) -> A: + """ """ + return vars_as( + node, is_mutable=True, allow_duplicates=allow_duplicates, only=only + ) def pure(tree: A) -> A: @@ -2781,7 +2853,9 @@ def pure(tree: A) -> A: def _pure_fn(x): if isinstance(x, Variable): - return x.raw_value + return pure(x.get_raw_value()) + elif variablelib.is_array_ref(x): + return x[...] return x return jax.tree.map( @@ -2881,7 +2955,9 @@ def pure_caller(accessor: DelayedAccessor, *args, **kwargs): return CallableProxy(pure_caller) # type: ignore -def set_metadata(node: tp.Any, /, *, only: filterlib.Filter = Variable, **metadata: tp.Mapping[str, tp.Any]) -> None: +def set_metadata( + node: tp.Any, /, *, only: filterlib.Filter = Variable, **metadata: tp.Any +) -> None: """Sets the metadata of all :class:`Variable` objects in the given graph node in-place. 
Example:: diff --git a/flax/nnx/module.py b/flax/nnx/module.py index aa32a7edf..83fafd587 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -182,7 +182,7 @@ def sow( f"Expected '{name}' to be of type '{variable_type.__name__}', " f"got '{type(variable).__name__}'" ) - variable.raw_value = reduce_fn(variable.raw_value, value) + variable.set_value(reduce_fn(variable.get_value(), value)) else: reduced_value = reduce_fn(init_fn(), value) setattr(self, name, variable_type(reduced_value)) diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py index 12d8d4f7c..cbef380db 100644 --- a/flax/nnx/nn/normalization.py +++ b/flax/nnx/nn/normalization.py @@ -368,7 +368,7 @@ def __call__( mask=mask, ) # stop_gradient only for flax_array_ref - if self.mean.has_ref or self.var.has_ref: + if self.mean._can_update or self.var._can_update: stop_gradient = jax.lax.stop_gradient else: stop_gradient = lambda x: x diff --git a/flax/nnx/pytreelib.py b/flax/nnx/pytreelib.py index 53353cc39..59d7d227d 100644 --- a/flax/nnx/pytreelib.py +++ b/flax/nnx/pytreelib.py @@ -235,10 +235,10 @@ def _collect_stats( node_stats[id(node)] = stats if isinstance(node, Variable): - var_type = type(node) + var_type = node.var_type if issubclass(var_type, nnx.RngState): var_type = nnx.RngState - size_bytes = SizeBytes.from_any(node.raw_value) + size_bytes = SizeBytes.from_any(node.get_value()) if size_bytes: stats[var_type] = size_bytes @@ -355,6 +355,7 @@ def _graph_node_meta_call(cls: tp.Type[P], *args, **kwargs) -> P: return node +@jax.tree_util.register_static @dataclasses.dataclass(frozen=True, repr=False) class ArrayRepr(reprlib.Representable): shape: tp.Tuple[int, ...] @@ -507,12 +508,8 @@ def _setattr(self, name, value: tp.Any) -> None: vars(self)[name] = value def _check_value(self, key, value, new_status: AttributeStatus | None): - def _has_arrays(leaves): - return any( - isinstance(leaf, (np.ndarray, jax.Array)) - or variablelib.is_array_ref(leaf) - for leaf in leaves - ) + def _has_data(leaves): + return any(is_data(leaf) for leaf in leaves) def _get_annotations(leaves): return { @@ -547,7 +544,7 @@ def _has_visited(x): f' _.{key} = nnx.data(...)\n\n' ) - if _has_arrays(leaves): + if _has_data(leaves): # check no data in nnx.static assignments if new_status is not None: if not new_status.is_data and new_status.explicit: @@ -663,7 +660,8 @@ def __nnx_repr__(self): def to_shape_dtype(value): if isinstance(value, Variable): return value.replace( - raw_value=jax.tree.map(to_shape_dtype, value.raw_value) + value=jax.tree.map(to_shape_dtype, value.get_value()), + _copy_ref=False, ) elif variablelib.is_array_ref(value) and np.prod(value.shape) > 1: return MutableArrayRepr(value.shape, value.dtype) @@ -838,10 +836,10 @@ class Object(Pytree, pytree=False): """Base class for NNX objects that are not pytrees.""" def __init_subclass__(cls, **kwargs): - pytree = kwargs.pop('pytree', False) + pytree = kwargs.pop('immutable', False) if pytree is not False: raise ValueError( - "Object is not a pytree, but 'pytree' was explicitly set to " + "Object is not a pytree, but 'immutable' was explicitly set to " f'{pytree!r} for type {cls}.' 
) super().__init_subclass__(pytree=pytree, **kwargs) diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py index bfcf5afb6..2c8479b3f 100644 --- a/flax/nnx/rnglib.py +++ b/flax/nnx/rnglib.py @@ -20,10 +20,9 @@ from jax import random import jax.numpy as jnp -from flax import errors, struct +from flax import struct from flax import typing from flax.nnx import graph -from flax.nnx import variablelib from flax.nnx.nn import initializers from flax.nnx.variablelib import Variable from flax.nnx import filterlib @@ -116,10 +115,7 @@ def __init__( self.count = RngCount(count, tag=tag) def __call__(self) -> jax.Array: - if not self.count.has_ref and not self.count._trace_state.is_valid(): - raise errors.TraceContextError( - f'Cannot mutate {type(self).__name__} from a different trace level' - ) + self.count._check_can_update() key = random.fold_in(self.key[...], self.count[...]) self.count[...] += 1 return key @@ -382,7 +378,7 @@ def __init__( for tag, key in rngs.items(): if isinstance(key, RngStream): - key = key.key[...] + key = key.key.get_value() stream = RngStream( key=key, tag=tag, @@ -826,14 +822,11 @@ def split_rngs_wrapper(*args, **kwargs): and predicate((*path, 'count'), stream.count) ): key = stream() - backups.append((stream, stream.key.raw_value, stream.count.raw_value)) + backups.append((stream, stream.key[...], stream.count[...])) key = random.split(key, splits) if squeeze: key = key[0] - if variablelib.is_array_ref(stream.key.raw_value): - stream.key.raw_value = variablelib.new_ref(key) # type: ignore[assignment] - else: - stream.key.value = key + stream.key.set_value(key) if squeeze: counts_shape = stream.count.shape elif isinstance(splits, int): @@ -841,11 +834,7 @@ def split_rngs_wrapper(*args, **kwargs): else: counts_shape = (*splits, *stream.count.shape) - count = jnp.zeros(counts_shape, dtype=jnp.uint32) - if variablelib.is_array_ref(stream.count.raw_value): - stream.count.raw_value = variablelib.new_ref(count) # type: ignore[assignment] - else: - stream.count.value = count + stream.count.set_value(jnp.zeros(counts_shape, dtype=jnp.uint32)) return SplitBackups(backups) @@ -992,10 +981,10 @@ def fork_rngs_wrapper(*args, **kwargs): ): forked_stream = stream.fork(split=splits) # backup the original stream state - backups.append((stream, stream.key.raw_value, stream.count.raw_value)) + backups.append((stream, stream.key[...], stream.count[...])) # apply the forked key and count to the original stream - stream.key.raw_value = forked_stream.key.raw_value - stream.count.raw_value = forked_stream.count.raw_value + stream.key.set_value(forked_stream.key.get_value()) + stream.count.set_value(forked_stream.count.get_value()) return SplitBackups(backups) @@ -1004,7 +993,7 @@ def backup_keys(node: tp.Any, /): backups: list[StreamBackup] = [] for _, stream in graph.iter_graph(node): if isinstance(stream, RngStream): - backups.append((stream, stream.key.raw_value)) + backups.append((stream, stream.key[...])) return backups def _scalars_only( @@ -1090,13 +1079,13 @@ def reseed( if stream.key.tag in stream_keys: key = rngs[stream.key.tag]() key = policy(path, key, stream.key.shape) - stream.key.value = key - stream.count.value = jnp.zeros(key.shape, dtype=jnp.uint32) + stream.key.set_value(key) + stream.count.set_value(jnp.zeros(key.shape, dtype=jnp.uint32)) def restore_rngs(backups: tp.Iterable[StreamBackup], /): for backup in backups: stream = backup[0] - stream.key.raw_value = backup[1] + stream.key.set_value(backup[1]) if len(backup) == 3: - stream.count.raw_value = backup[2] # count + 
stream.count.set_value(backup[2]) # count diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py index 8d0e5bf5c..04e5cca31 100644 --- a/flax/nnx/statelib.py +++ b/flax/nnx/statelib.py @@ -248,7 +248,7 @@ def __init__( super().__setattr__('_mapping', _mapping) @property - def raw_mapping(self) -> tp.Mapping[K, tp.Mapping[K, tp.Any] | V]: + def raw_mapping(self) -> dict[K, tp.Mapping[K, tp.Any] | V]: return self._mapping # type: ignore def __contains__(self, key) -> bool: @@ -521,7 +521,7 @@ def to_pure_dict( """ # Works for nnx.Variable if extract_fn is None: - extract_fn = lambda x: x.value if isinstance(x, variablelib.Variable) else x + extract_fn = lambda x: x.get_value() if isinstance(x, variablelib.Variable) else x flat_values = {k: extract_fn(x) for k, x in to_flat_state(state)} return traversals.unflatten_mapping(flat_values) @@ -831,6 +831,6 @@ def create_path_filters(state: State): value_paths: dict[tp.Any, set[PathParts]] = {} for path, value in flat_state: if isinstance(value, variablelib.Variable): - value = value.raw_value + value = value.get_value() value_paths.setdefault(value, set()).add(path) return {filterlib.PathIn(*value_paths[value]): value for value in value_paths} \ No newline at end of file diff --git a/flax/nnx/summary.py b/flax/nnx/summary.py index 20b6a2689..f597d5af2 100644 --- a/flax/nnx/summary.py +++ b/flax/nnx/summary.py @@ -96,7 +96,7 @@ def _collect_stats( var_type = type(value) if issubclass(var_type, nnx.RngState): var_type = nnx.RngState - size_bytes = SizeBytes.from_any(value.value) + size_bytes = SizeBytes.from_any(value.get_value()) if var_type in stats: stats[var_type] += size_bytes else: @@ -455,11 +455,15 @@ def do_vjp(*args, **kwargs): for var_type in variable_types: attributes = {} + variable: variablelib.Variable for name, variable in node_info.variable_groups[var_type].items(): - value = variable.value + value = variable.get_value() value_repr = _render_array(value) if _has_shape_dtype(value) else '' metadata = variable.get_metadata() - + metadata.pop('is_hijax') + metadata.pop('has_ref') + metadata.pop('is_mutable') + metadata.pop('eager_sharding', None) if metadata: attributes[name] = { 'value': value_repr, diff --git a/flax/nnx/training/optimizer.py b/flax/nnx/training/optimizer.py index a2197c919..892f27151 100644 --- a/flax/nnx/training/optimizer.py +++ b/flax/nnx/training/optimizer.py @@ -28,9 +28,6 @@ M = tp.TypeVar('M', bound=nnx.Module) F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) -# TODO: add tests and docstrings - - class OptState(Variable): """Any optimizer state""" @@ -52,7 +49,7 @@ class OptVariable(OptState): def to_opt_state(tree): def _to_opt_state(x): if isinstance(x, Variable): - opt_state = OptVariable(x.value, **x.get_metadata()) # type: ignore + opt_state = OptVariable(x.get_value(), **x.get_metadata()) # type: ignore else: opt_state = OptArray(x) return opt_state @@ -209,10 +206,10 @@ def update(self, model: M, grads, /, **kwargs): **kwargs: additional keyword arguments passed to the tx.update, to support ``GradientTransformationExtraArgs``, such as ``optax.scale_by_backtracking_linesearch``. 
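 
     Example (sketch; which extra keyword arguments are accepted depends on
     the ``optax`` transformation being used)::
 
       optimizer.update(model, grads, value=loss)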
""" - param_arrays = nnx.to_arrays(nnx.pure(nnx.state(model, self.wrt))) - grad_arrays = nnx.to_arrays(nnx.pure(nnx.state(grads, self.wrt))) - opt_state_arrays = nnx.to_arrays(nnx.pure(self.opt_state)) - kwargs_arrays = nnx.to_arrays(nnx.pure(kwargs)) + param_arrays = nnx.pure(nnx.state(model, self.wrt)) + grad_arrays = nnx.pure(nnx.state(grads, self.wrt)) + opt_state_arrays = nnx.pure(self.opt_state) + kwargs_arrays = nnx.pure(kwargs) updates, new_opt_state = self.tx.update( grad_arrays, opt_state_arrays, param_arrays, **kwargs_arrays diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index 61bcfa0bb..fe531ef3f 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -28,7 +28,7 @@ statelib, variablelib, ) -from flax.typing import MISSING, Missing +from flax.typing import MISSING, Missing, PathParts F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) P = tp.ParamSpec('P') @@ -69,7 +69,7 @@ def shardings(self) -> tuple[tp.Any, ...]: return self._shardings def map_prefix( - self, path: variablelib.PathParts, variable: variablelib.Variable + self, path: PathParts, variable: variablelib.Variable ) -> tp.Any: for filter, sharding in zip(self.filters, self.shardings): predicate = filterlib.to_predicate(filter) diff --git a/flax/nnx/transforms/transforms.py b/flax/nnx/transforms/transforms.py index 54027c6ff..2985e32aa 100644 --- a/flax/nnx/transforms/transforms.py +++ b/flax/nnx/transforms/transforms.py @@ -25,6 +25,7 @@ from flax.nnx import ( extract, graph, + variablelib, ) from flax.nnx.module import Module from flax.nnx.proxy_caller import ( @@ -125,6 +126,60 @@ def check_and_call(accessor: DelayedAccessor, *args, **kwargs): # ------------------------------- # simple transforms # ------------------------------- +@dataclasses.dataclass(frozen=True) +class ValueMetadata: + var_type: type[variablelib.Variable] + value: tp.Any + metadata: dict[str, tp.Any] + + +def _flatten_value_metadata( + value_metadata: tp.Union[tp.Any, ValueMetadata], +): + metadata = tuple(sorted(value_metadata.metadata.items())) + return (value_metadata.value,), (value_metadata.var_type, metadata) + + +def _unflatten_value_metadata(aux_data, children): + var_type, metadata_items = aux_data + metadata = dict(metadata_items) + return ValueMetadata(var_type=var_type, value=children[0], metadata=metadata) + + +jax.tree_util.register_pytree_node( + ValueMetadata, + _flatten_value_metadata, + _unflatten_value_metadata, +) + + +def _to_value_metadata(node): + def to_value_metadata(x): + if isinstance(x, variablelib.Variable): + value = x.get_raw_value() + if variablelib.is_array_ref(value): + value = value[...] 
+ metadata = x.get_metadata() + return ValueMetadata(var_type=x.var_type, value=value, metadata=metadata) + return x + + return jax.tree.map( + to_value_metadata, + node, + is_leaf=lambda x: isinstance(x, variablelib.Variable), + ) + + +def _to_variable(node): + def to_variable(x): + if isinstance(x, ValueMetadata): + var = x.var_type._new(x.value, x.metadata) + return var + return x + + return jax.tree.map( + to_variable, node, is_leaf=lambda x: isinstance(x, ValueMetadata) + ) def eval_shape( @@ -146,10 +201,10 @@ def eval_shape( def _eval_shape_fn(*args, **kwargs): args, kwargs = extract.from_tree((args, kwargs)) out = f(*args, **kwargs) - return graph.to_arrays(extract.to_tree(out), allow_duplicates=True) + return _to_value_metadata(extract.to_tree(out)) out = jax.eval_shape(_eval_shape_fn, *args, **kwargs) - return extract.from_tree(out) + return extract.from_tree(_to_variable(out)) @dataclasses.dataclass(eq=False) class CheckifyFn: diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 6d4c358a1..b516fc9aa 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -14,29 +14,36 @@ # pytype: skip-file from __future__ import annotations -import contextlib import dataclasses import functools from functools import partial import threading import typing as tp from typing import Any +import warnings + from flax import config +from jax._src import hijax +from jax._src import core as jax_core +from jax._src import effects +from jax._src import ad_util +import itertools as it import jax import treescope # type: ignore[import-untyped] from flax import errors from flax.core import spmd as core_spmd -from flax.nnx import filterlib, reprlib, tracers, visualization -from flax.typing import MISSING, Missing, PathParts, SizeBytes +from flax.nnx import reprlib, tracers, visualization +from flax.typing import MISSING, Missing, SizeBytes import jax.tree_util as jtu -import jax.numpy as jnp from jax._src.state.types import AbstractRef A = tp.TypeVar('A') B = tp.TypeVar('B') +C = tp.TypeVar('C') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) +P = tp.TypeVar('P', bound=property) V = tp.TypeVar('V', bound='Variable[Any]') GetValueHook = tp.Callable[['Variable[A]', A], A] SetValueHook = tp.Callable[['Variable[A]', A], A] @@ -50,109 +57,71 @@ # The following ensures we avoid an ImportError or DeprecationWarning. if hasattr(jax, 'new_ref') and hasattr(jax, 'Ref'): # JAX v0.7.2 or newer - from jax import new_ref from jax import Ref elif hasattr(jax, 'array_ref') and hasattr(jax, 'ArrayRef'): # JAX v0.7.1 - from jax import array_ref as new_ref # type: ignore[import-untyped] from jax import ArrayRef as Ref # type: ignore[import-untyped] else: # JAX v0.7.0 or older - from jax.experimental import mutable_array as new_ref from jax.experimental import MutableArray as Ref @dataclasses.dataclass class VariableContext(threading.local): - mutable_variable_stack: list[bool] = dataclasses.field(default_factory=list) + variable_hijax_stack: list[bool] = dataclasses.field(default_factory=list) + variable_ref_stack: list[bool] = dataclasses.field( + default_factory=lambda: [False] + ) VARIABLE_CONTEXT = VariableContext() -def using_refs() -> bool: - """Returns whether Variables are using ArrayRefs by default. +def using_hijax(): + """ """ + if VARIABLE_CONTEXT.variable_hijax_stack: + return VARIABLE_CONTEXT.variable_hijax_stack[-1] - Example:: + return config.flax_hijax_variable - >>> from flax import nnx - ... 
- >>> nnx.using_refs() - False - >>> nnx.use_refs(True) - <...> - >>> nnx.using_refs() - True - >>> nnx.use_refs(False) - <...> - >>> nnx.using_refs() - False - - - Returns: - A boolean indicating if Variables are using ArrayRefs by default. - """ - if VARIABLE_CONTEXT.mutable_variable_stack: - return VARIABLE_CONTEXT.mutable_variable_stack[-1] + +def use_hijax(value: bool, /): + """ """ + if VARIABLE_CONTEXT.variable_hijax_stack: + prev_value = VARIABLE_CONTEXT.variable_hijax_stack[-1] + VARIABLE_CONTEXT.variable_hijax_stack[-1] = value else: - return config.flax_array_ref + prev_value = None + VARIABLE_CONTEXT.variable_hijax_stack.append(value) + return UseHijaxContext(prev_value, value) -def use_refs(value: bool, /): - """Sets whether Variables should use ArrayRefs by default or not. +class UseHijaxContext: + def __init__(self, prev_value: bool | None, new_value: bool): + self.prev_value: bool | None = prev_value + self.new_value: bool = new_value - Example usage:: + def __enter__(self): + if self.prev_value is not None: + VARIABLE_CONTEXT.variable_hijax_stack.insert(-1, self.prev_value) - >>> from flax import nnx - >>> # Use ArrayRefs by default - >>> nnx.use_refs(True) - <...> - >>> # Variable will now use ArrayRefs - >>> v = nnx.Variable(jax.numpy.ones((2, 3))) - >>> v.has_ref - True - >>> v.raw_value - Ref(...) - >>> nnx.use_refs(False) - <...> - - It can also be used as a context manager to temporarily - change the default behavior for a block of code:: - - >>> nnx.use_refs(False) - <...> - >>> with nnx.use_refs(True): - ... v = nnx.Variable(jax.numpy.ones((2, 3))) - ... v.has_ref - True - >>> # it will reset outside - >>> v = nnx.Variable(jax.numpy.ones((2, 3))) - >>> v.has_ref - False - - Args: - value: A boolean indicating if Variables should use ArrayRefs by default. - - Returns: - A context manager that resets the context to the previous value. 
- """ - # prev_value = VARIABLE_CONTEXT.mutable_variable_stack[-1] if VARIABLE_CONTEXT.mutable_variable_stack else None - # VARIABLE_CONTEXT.mutable_variable_stack.append(value) - if VARIABLE_CONTEXT.mutable_variable_stack: - prev_value = VARIABLE_CONTEXT.mutable_variable_stack[-1] - VARIABLE_CONTEXT.mutable_variable_stack[-1] = value - else: - prev_value = None - VARIABLE_CONTEXT.mutable_variable_stack.append(value) - return _clean_mutable_arrays_context(prev_value) + def __exit__(self, exc_type, exc_value, traceback): + VARIABLE_CONTEXT.variable_hijax_stack.pop() + + def __call__(self, f: F) -> F: + # undo eager stack change + VARIABLE_CONTEXT.variable_hijax_stack.pop() + if self.prev_value is not None: + VARIABLE_CONTEXT.variable_hijax_stack.append(self.prev_value) -@contextlib.contextmanager -def _clean_mutable_arrays_context(prev_value: bool | None): - if prev_value is not None: - VARIABLE_CONTEXT.mutable_variable_stack.insert(-1, prev_value) - try: - yield - finally: - VARIABLE_CONTEXT.mutable_variable_stack.pop() + @functools.wraps(f) + def use_hijax_wrapper(*args, **kwargs): + VARIABLE_CONTEXT.variable_hijax_stack.append(self.new_value) + try: + return f(*args, **kwargs) + finally: + VARIABLE_CONTEXT.variable_hijax_stack.pop() + + return use_hijax_wrapper # type: ignore[return-value] def is_array_ref(x) -> tp.TypeGuard[Ref]: @@ -172,12 +141,694 @@ class VariableMetadata(tp.Generic[A]): metadata: tp.Mapping[str, tp.Any] = dataclasses.field(default_factory=dict) +PyTreeDef = tp.Any + +#--------------------------------- +# hijax +#--------------------------------- + +def _new_hijax_variable(var_type: type[Variable]) -> HijaxVariable: + variable = var_type._new(None, {}) + (), treedef = jax.tree.flatten(variable) + return new_variable_p.bind(treedef=treedef, var_type=var_type) + + +def _get_hijax_state(hijax_var) -> Variable: + tys: VariableQDD = jax_core.cur_qdd(hijax_var) + leaf_vals = get_variable_p.bind(hijax_var, avals=tuple(tys.leaf_avals)) + variable = jax.tree.unflatten(tys.treedef, leaf_vals) + return variable + + +def _set_hijax_state(hijax_var, variable: Variable): + leaves, treedef = jax.tree.flatten(variable) + set_variable_p.bind( + hijax_var, *leaves, treedef=treedef, var_type=type(variable) + ) + + +def _new_hijax_from_variable(variable: Variable) -> HijaxVariable: + hijax_var = _new_hijax_variable(type(variable)) + _set_hijax_state(hijax_var, variable) + return hijax_var + + + +@dataclasses.dataclass(frozen=True) +class VariableQDD(jax_core.QuasiDynamicData): + leaf_avals: tuple[jax_core.AbstractValue, ...] + treedef: PyTreeDef + + def to_tangent_qdd(self): + leaf_avals = tuple(a.to_tangent_aval() for a in self.leaf_avals) + return VariableQDD(leaf_avals, self.treedef) + + def normalize(self): + leaf_types = tuple(a.normalize() for a in self.leaf_avals) + return VariableQDD(leaf_types, self.treedef) + + +class VariableEffect(effects.Effect): ... 
+ + +variable_effect = VariableEffect() +effects.control_flow_allowed_effects.add_type(VariableEffect) + + +class NewVariable(hijax.HiPrimitive): + def is_high(self, *, treedef, var_type) -> bool: + return True # type: ignore + + def abstract_eval(self, *, treedef, var_type: type[Variable]): + variable = var_type._new(None, {}) + leaves, treedef = jax.tree.flatten(variable) + qdd = VariableQDD(tuple(leaves), treedef) + return jax_core.AvalQDD(AbstractVariable(var_type), qdd), {variable_effect} + + def to_lojax(self, *, treedef, var_type: type[Variable]): + return HijaxVariable._new(None, {}, var_type) + + def jvp(_, primals, tangents, *, treedef, var_type): + raise NotImplementedError('jvp not implemented for NewHijaxVariable') + + def transpose(_, *args, treedef, var_type): + raise NotImplementedError('transpose not implemented for NewHijaxVariable') + + +new_variable_p = NewVariable(f'new_variable') + + +class SetVariable(hijax.HiPrimitive): + multiple_results = True + + def is_high(self, *leaf_avals, treedef, var_type) -> bool: + return True # type: ignore + + # TODO: upstream this to Box + def impl(self, hijax_var: HijaxVariable, *leaves, treedef, var_type): + variable: Variable = jax.tree.unflatten(treedef, leaves) + object.__setattr__(hijax_var, '_raw_value', variable._raw_value) + object.__setattr__(hijax_var, '_metadata', variable._var_metadata) + return [] + + def abstract_eval(self, hijax_var_type, *leaf_avals, treedef, var_type): + hijax_var_type.mutable_qdd.update(VariableQDD(leaf_avals, treedef)) + return [], {variable_effect} # TODO better typechecking... + + def to_lojax(_, hijax_var: HijaxVariable, *leaves, treedef, var_type): + variable: Variable = jax.tree.unflatten(treedef, leaves) + object.__setattr__(hijax_var, '_raw_value', variable._raw_value) + object.__setattr__(hijax_var, '_metadata', variable._var_metadata) + return [] + + def jvp(_, primals, tangents, *, treedef, var_type): + variable: Variable + variable, *vals = primals + variable_dot: Variable + variable_dot, *val_dots = tangents + if type(variable_dot._raw_value) is ad_util.Zero: + raise Exception( + "can't differentiate Variable._set operation, " + 'did you forget jax.lax.stop_gradient?' 
+ ) + set_variable_p.bind( + variable, *vals, treedef=treedef, var_type=type(variable) + ) + set_variable_p.bind( + variable_dot, *val_dots, treedef=treedef, var_type=type(variable_dot) + ) + return [], [] + + def transpose(_, *args, treedef, var_type): + raise NotImplementedError('transpose not implemented for SetHijaxVariable') + + +set_variable_p = SetVariable(f'set_variable') + + +class GetVariable(hijax.HiPrimitive): + multiple_results = True + + def abstract_eval(self, abstract_var, *, avals): + return avals, {variable_effect} + + def to_lojax(_, hijax_var: HijaxVariable, *, avals): + return jax.tree.leaves(hijax_var._raw_value) + + def jvp(_, primals, tangents, *, avals): + (box,), (variable_dot,) = primals, tangents + return ( + get_variable_p.bind(box, avals=avals), + get_variable_p.bind( + variable_dot, avals=tuple(a.to_tangent_aval() for a in avals) + ), + ) + + def transpose(_, *args): + raise NotImplementedError('transpose not implemented for GetHijaxVariable') + + +get_variable_p = GetVariable(f'get_variable') + + +# --------------------------------- +# HijaxVariable +# --------------------------------- +def _variable_has_changed(old: Variable, new: Variable) -> bool: + old_structure = jax.tree.structure(old) + new_structure = jax.tree.structure(new) + if old_structure != new_structure: # type: ignore[operator] + return True + old_leaves = jax.tree.leaves(old) + new_leaves = jax.tree.leaves(new) + return any(o is not n for o, n in zip(old_leaves, new_leaves)) + + +def _as_hijax_property(name: str, *, get: bool, set: bool) -> property: + """Creates a property that operates on the hijax type.""" + + def _getter_wrapper(hijax_var): + variable = _get_hijax_state(hijax_var) + old_state = jax.tree.map(lambda x: x, variable) + out = getattr(variable, name) + if _variable_has_changed(old_state, variable): + _set_hijax_state(hijax_var, variable) + return out + + def _setter_wrapper(hijax_var, value): + variable = _get_hijax_state(hijax_var) + setattr(variable, name, value) + _set_hijax_state(hijax_var, variable) + + _hijax_property = property( + fget=_getter_wrapper if get else None, + fset=_setter_wrapper if set else None, + ) + return _hijax_property # type: ignore[return] + + +def _as_aval_property(p: property) -> jax_core.aval_property: + """Wraps a property `p` operate on the aval type.""" + _aval_property = jax_core.aval_property(fget=p.fget) + return _aval_property # type: ignore[return] + + +def _as_hijax_attribute(name: str) -> property: + """Creates a property that operates on the hijax type.""" + + def _getter_wrapper(hijax_var): + variable = _get_hijax_state(hijax_var) + old_state = jax.tree.map(lambda x: x, variable) + out = getattr(variable, name) + if _variable_has_changed(old_state, variable): + _set_hijax_state(hijax_var, variable) + return out + + _getter_wrapper.__name__ = name + _hijax_property = property(fget=_getter_wrapper) + + return _hijax_property # type: ignore[return] + + +def _as_hijax_method(name: str) -> tp.Any: + """Creates a method that operates on the hijax type.""" + + def hijax_method_wrapper(hijax_var, *args, **kwargs): + variable = _get_hijax_state(hijax_var) + old_state = jax.tree.map(lambda x: x, variable) + method = getattr(variable, name) + out = method(*args, **kwargs) + if _variable_has_changed(old_state, variable): + _set_hijax_state(hijax_var, variable) + return out + + hijax_method_wrapper.__name__ = name + + return hijax_method_wrapper + + +def _as_tracer_method(name: str): + def op(self, hijax_var, *args, **kwargs): + variable = 
_get_hijax_state(hijax_var) + old_state = jax.tree.map(lambda x: x, variable) + out = getattr(variable, name)(*args, **kwargs) + if _variable_has_changed(old_state, variable): + _set_hijax_state(hijax_var, variable) + return out + + op.__name__ = name + return op + + +class HijaxVariableMeta(type): + def __instancecheck__(self, instance): + if super().__instancecheck__(instance): + return True + + if isinstance(instance, jax_core.Tracer): + ty = jax_core.typeof(instance) + return isinstance(ty, AbstractVariable) + return False + + +class HijaxVariable( + tp.Generic[A], reprlib.Representable, metaclass=HijaxVariableMeta +): # type: ignore + __slots__ = ('_raw_value', '_metadata', '_var_type') + _raw_value: A + _metadata: dict[str, tp.Any] + _var_type: type[Variable[tp.Any]] + + @classmethod + def _new( + cls, + value, + metadata: dict[str, tp.Any], + var_type: type[Variable[A]], + ): + hijax_var = object.__new__(cls) + object.__setattr__(hijax_var, '_raw_value', value) + object.__setattr__(hijax_var, '_metadata', metadata) + object.__setattr__(hijax_var, '_var_type', var_type) + return hijax_var + + __init__ = _as_hijax_method('__init__') + + @property + def value(self) -> A: + raise NotImplementedError( + 'HijaxVariable.value property is not implemented. For Variable[Array] instances use:\n\n' + ' variable[...]\n\n' + 'For other Variable types use:\n\n' + ' variable.get_value()\n' + ) + + @value.setter + def value(self, new_value: A): + raise NotImplementedError( + 'HijaxVariable.value property is not implemented. For Variable[Array] instances use:\n\n' + ' variable[...] = new_value\n\n' + 'For other Variable types use:\n\n' + ' variable.set_value(new_value)\n' + ) + + @property + def var_type(self) -> type[Variable[A]]: + return self._var_type + + _trace_state = _as_hijax_property('_trace_state', get=True, set=False) + _can_update = _as_hijax_property('_can_update', get=True, set=False) + _check_can_update = _as_hijax_method('_check_can_update') + __getattr__ = _as_hijax_method('__getattr__') + __setattr__ = _as_hijax_method('__setattr__') + __delattr__ = _as_hijax_method('__delattr__') + type = _as_hijax_property('type', get=True, set=False) + is_hijax = _as_hijax_property('is_hijax', get=True, set=False) + has_ref = _as_hijax_property('has_ref', get=True, set=False) + is_mutable = _as_hijax_property('is_mutable', get=True, set=False) + get_metadata = _as_hijax_method('get_metadata') + set_metadata = _as_hijax_method('set_metadata') + + def copy_from(self, other: Variable[A] | HijaxVariable[A]) -> None: + if isinstance(other, HijaxVariable): + other = _get_hijax_state(other) + variable = _get_hijax_state(self) + variable.copy_from(other) # type: ignore[arg-type] + _set_hijax_state(self, variable) + + def update_from_state(self, variable_state: Variable[A] | HijaxVariable[A]): + if isinstance(variable_state, HijaxVariable): + variable_state = _get_hijax_state(variable_state) + variable = _get_hijax_state(self) + variable.update_from_state(variable_state) # type: ignore[arg-type] + _set_hijax_state(self, variable) + + get_raw_value = _as_hijax_method('get_raw_value') + set_raw_value = _as_hijax_method('set_raw_value') + set_value = _as_hijax_method('set_value') + get_value = _as_hijax_method('get_value') + create_value = _as_hijax_method('create_value') + set_raw_value = _as_hijax_method('set_raw_value') + add_axis = _as_hijax_method('add_axis') + remove_axis = _as_hijax_method('remove_axis') + copy = _as_hijax_method('copy') + replace = _as_hijax_method('replace') + to_state = 
_as_hijax_method('to_state') + + @classmethod + def from_metadata(cls, value: A, metadata: dict[str, tp.Any]): + return cls._var_type.from_metadata(value, metadata) # type: ignore[misc] + + __nnx_repr__ = _as_hijax_method('__nnx_repr__') + __treescope_repr__ = _as_hijax_method('__treescope_repr__') + + # -------------------------------------------- + # proxy methods + # -------------------------------------------- + __jax_array__ = _as_hijax_method('__jax_array__') + __getitem__ = _as_hijax_method('__getitem__') + __setitem__ = _as_hijax_method('__setitem__') + __delitem__ = _as_hijax_method('__delitem__') + __call__ = _as_hijax_method('__call__') + __len__ = _as_hijax_method('__len__') + __iter__ = _as_hijax_method('__iter__') + __contains__ = _as_hijax_method('__contains__') + __add__ = _as_hijax_method('__add__') + __sub__ = _as_hijax_method('__sub__') + __mul__ = _as_hijax_method('__mul__') + __matmul__ = _as_hijax_method('__matmul__') + __truediv__ = _as_hijax_method('__truediv__') + __floordiv__ = _as_hijax_method('__floordiv__') + __mod__ = _as_hijax_method('__mod__') + __divmod__ = _as_hijax_method('__divmod__') + __pow__ = _as_hijax_method('__pow__') + __lshift__ = _as_hijax_method('__lshift__') + __rshift__ = _as_hijax_method('__rshift__') + __and__ = _as_hijax_method('__and__') + __xor__ = _as_hijax_method('__xor__') + __or__ = _as_hijax_method('__or__') + __radd__ = _as_hijax_method('__radd__') + __rsub__ = _as_hijax_method('__rsub__') + __rmul__ = _as_hijax_method('__rmul__') + __rmatmul__ = _as_hijax_method('__rmatmul__') + __rtruediv__ = _as_hijax_method('__rtruediv__') + __rfloordiv__ = _as_hijax_method('__rfloordiv__') + __rmod__ = _as_hijax_method('__rmod__') + __rdivmod__ = _as_hijax_method('__rdivmod__') + __rpow__ = _as_hijax_method('__rpow__') + __rlshift__ = _as_hijax_method('__rlshift__') + __rrshift__ = _as_hijax_method('__rrshift__') + __rand__ = _as_hijax_method('__rand__') + __rxor__ = _as_hijax_method('__rxor__') + __ror__ = _as_hijax_method('__ror__') + __iadd__ = _as_hijax_method('__iadd__') + __isub__ = _as_hijax_method('__isub__') + __imul__ = _as_hijax_method('__imul__') + __imatmul__ = _as_hijax_method('__imatmul__') + __itruediv__ = _as_hijax_method('__itruediv__') + __ifloordiv__ = _as_hijax_method('__ifloordiv__') + __imod__ = _as_hijax_method('__imod__') + __ipow__ = _as_hijax_method('__ipow__') + __ilshift__ = _as_hijax_method('__ilshift__') + __irshift__ = _as_hijax_method('__irshift__') + __iand__ = _as_hijax_method('__iand__') + __ixor__ = _as_hijax_method('__ixor__') + __ior__ = _as_hijax_method('__ior__') + __neg__ = _as_hijax_method('__neg__') + __pos__ = _as_hijax_method('__pos__') + __abs__ = _as_hijax_method('__abs__') + __invert__ = _as_hijax_method('__invert__') + __complex__ = _as_hijax_method('__complex__') + __int__ = _as_hijax_method('__int__') + __float__ = _as_hijax_method('__float__') + __index__ = _as_hijax_method('__index__') + __round__ = _as_hijax_method('__round__') + __trunc__ = _as_hijax_method('__trunc__') + __floor__ = _as_hijax_method('__floor__') + __ceil__ = _as_hijax_method('__ceil__') + + # -------------------------------------------- + # hijax interface + # -------------------------------------------- + + def cur_qdd(self): + return self.type_state() + + @property + def ty(self): + return AbstractVariable(self._var_type) + + def type_state(self): + variable = self._var_type._new(self._raw_value, self._metadata) + leaves, treedef = jax.tree.flatten(variable) + leaf_avals = tuple(map(jax_core.typeof, leaves)) + return 
VariableQDD(leaf_avals, treedef) + + +hijax.register_hitype(HijaxVariable, lambda b: b.ty) + + +# --------------------------------- +# AbstractVariable +# --------------------------------- +class AbstractVariable(tp.Generic[A], hijax.MutableHiType): + __slots__ = ['_var_type'] + _var_type: type[Variable[A]] + # forwarded to value + var_type = jax_core.aval_property(lambda self: self.aval._var_type) + is_hijax = _as_aval_property(HijaxVariable.is_hijax) + has_ref = _as_aval_property(HijaxVariable.has_ref) + is_mutable = _as_aval_property(HijaxVariable.is_mutable) + _trace_state = _as_aval_property(HijaxVariable._trace_state) + _can_update = _as_aval_property(HijaxVariable._can_update) + _check_can_update = jax_core.aval_method(HijaxVariable._check_can_update) + + def __init__(self, var_type: type[Variable[A]]): + object.__setattr__(self, '_var_type', var_type) + + @property + def dtype(self): + raise AttributeError + + @property + def ndim(self): + raise AttributeError + + @property + def size(self): + raise AttributeError + + @property + def shape(self): + raise AttributeError + + def __getattr__(self, name: str): + # Forward unknown attributes to the value + if hasattr(AbstractVariable, name): + raise AttributeError + if name.startswith('_'): + raise AttributeError + return _as_aval_property(_as_hijax_attribute(name)) + + # __setattr__ supported via __getattr__ + # __delattr__ CURRENTLY NOT SUPPORTED + type = _as_aval_property(HijaxVariable.type) + get_metadata = jax_core.aval_method(HijaxVariable.get_metadata) + set_metadata = jax_core.aval_method(HijaxVariable.set_metadata) + copy_from = jax_core.aval_method(HijaxVariable.copy_from) + update_from_state = jax_core.aval_method(HijaxVariable.update_from_state) + get_raw_value = jax_core.aval_method(HijaxVariable.get_raw_value) + set_raw_value = jax_core.aval_method(HijaxVariable.set_raw_value) + set_value = jax_core.aval_method(HijaxVariable.set_value) + get_value = jax_core.aval_method(HijaxVariable.get_value) + create_value = jax_core.aval_method(HijaxVariable.create_value) + set_raw_value = jax_core.aval_method(HijaxVariable.set_raw_value) + add_axis = jax_core.aval_method(HijaxVariable.add_axis) + remove_axis = jax_core.aval_method(HijaxVariable.remove_axis) + replace = jax_core.aval_method(HijaxVariable.replace) + + @jax_core.aval_method + def from_metadata(self, value, metadata: dict[str, tp.Any]): + aval: AbstractVariable = self.aval # type: ignore + variable = aval._var_type.from_metadata(value, metadata) + return variable + + copy = jax_core.aval_method(HijaxVariable.copy) + replace = jax_core.aval_method(HijaxVariable.replace) + to_state = jax_core.aval_method(HijaxVariable.to_state) + + def __str__(self): + return f'{self._var_type.__name__}()' + + def __repr__(self): + return f'{self._var_type.__name__}()' + + @jax_core.aval_method + def __treescope_repr__(self, path, subtree_renderer): + raise NotImplementedError + + # --------------------------------- + # proxy methods + # --------------------------------- + __jax_array__ = jax_core.aval_method(HijaxVariable.__jax_array__) + _getitem = _as_tracer_method('__getitem__') + _setitem = _as_tracer_method('__setitem__') + # __delitem__ CURRENTLY NOT SUPPORTED + # __call__ CURRENTLY NOT SUPPORTED + _len = _as_tracer_method('__len__') + _iter = _as_tracer_method('__iter__') + # __contains__ CURRENTLY NOT SUPPORTED + _add = _as_tracer_method('__add__') + _sub = _as_tracer_method('__sub__') + _mul = _as_tracer_method('__mul__') + _matmul = _as_tracer_method('__matmul__') + _truediv = 
_as_tracer_method('__truediv__') + _floordiv = _as_tracer_method('__floordiv__') + _mod = _as_tracer_method('__mod__') + _divmod = _as_tracer_method('__divmod__') + _pow = _as_tracer_method('__pow__') + _lshift = _as_tracer_method('__lshift__') + _rshift = _as_tracer_method('__rshift__') + _and = _as_tracer_method('__and__') + _xor = _as_tracer_method('__xor__') + _or = _as_tracer_method('__or__') + _radd = _as_tracer_method('__radd__') + _rsub = _as_tracer_method('__rsub__') + _rmul = _as_tracer_method('__rmul__') + _rmatmul = _as_tracer_method('__rmatmul__') + _rtruediv = _as_tracer_method('__rtruediv__') + _rfloordiv = _as_tracer_method('__rfloordiv__') + _rmod = _as_tracer_method('__rmod__') + _rdivmod = _as_tracer_method('__rdivmod__') + _rpow = _as_tracer_method('__rpow__') + _rlshift = _as_tracer_method('__rlshift__') + _rrshift = _as_tracer_method('__rrshift__') + _rand = _as_tracer_method('__rand__') + _rxor = _as_tracer_method('__rxor__') + _ror = _as_tracer_method('__ror__') + # _iadd CURRENTLY NOT SUPPORTED + # _isub CURRENTLY NOT SUPPORTED + # _imul CURRENTLY NOT SUPPORTED + # _imatmul CURRENTLY NOT SUPPORTED + # _itruediv CURRENTLY NOT SUPPORTED + # _ifloordiv CURRENTLY NOT SUPPORTED + # _imod CURRENTLY NOT SUPPORTED + # _ipow CURRENTLY NOT SUPPORTED + # _ilshift CURRENTLY NOT SUPPORTED + # _irshift CURRENTLY NOT SUPPORTED + # _iand CURRENTLY NOT SUPPORTED + # _ixor CURRENTLY NOT SUPPORTED + # _ior CURRENTLY NOT SUPPORTED + _neg = _as_tracer_method('__neg__') + _pos = _as_tracer_method('__pos__') + _abs = _as_tracer_method('__abs__') + _invert = _as_tracer_method('__invert__') + _complex = _as_tracer_method('__complex__') + _int = _as_tracer_method('__int__') + _float = _as_tracer_method('__float__') + _index = _as_tracer_method('__index__') + _round = _as_tracer_method('__round__') + _trunc = _as_tracer_method('__trunc__') + _floor = _as_tracer_method('__floor__') + _ceil = _as_tracer_method('__ceil__') + + # -------------------------------- + # hijax interface + # -------------------------------- + has_qdd = True + + def __hash__(self): + return hash(AbstractVariable) + + def __eq__(self, other): + return isinstance(other, AbstractVariable) + + def str_short(self, short_dtypes=False, **_) -> str: # type: ignore + return f'{self._var_type.__name__}()' + + # mutable interface + def lo_ty_qdd(self, variable_state: VariableQDD) -> list: # type: ignore + return [lo_ty for t in variable_state.leaf_avals for lo_ty in t.lo_ty()] + + def new_from_loval( # type: ignore[override] + self, variable_state: VariableQDD, *lo_vals + ) -> HijaxVariable: + lo_vals_ = iter(lo_vals) + hi_vals = [ + hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty()))) # type: ignore + for hi_ty in variable_state.leaf_avals + ] + assert next(lo_vals_, None) is None + variable: Variable = jax.tree.unflatten(variable_state.treedef, hi_vals) + return HijaxVariable._new( + variable._raw_value, variable._var_metadata, self._var_type + ) # will be mutated + + def read_loval(self, variable_state: VariableQDD, variable) -> list: # type: ignore + leaf_vals, treedef = jax.tree.flatten(_get_hijax_state(variable)) + assert treedef == variable_state.treedef + return [ + lo_val + for hi_ty, hi_val in zip(variable_state.leaf_avals, leaf_vals) + for lo_val in hi_ty.lower_val(hi_val) + ] # type: ignore + + def update_from_loval( # type: ignore[override] + self, box_state: VariableQDD, variable, *lo_vals + ) -> None: + lo_vals_ = iter(lo_vals) + hi_vals = [ + hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty()))) # type: 
ignore + for hi_ty in box_state.leaf_avals + ] + assert next(lo_vals_, None) is None + _set_hijax_state(variable, jax.tree.unflatten(box_state.treedef, hi_vals)) + + def to_tangent_aval(self): + return AbstractVariable(self._var_type) + + +# -------------------------------------------- +# Variable +# -------------------------------------------- + + +def _variable_operator(name: str) -> tp.Callable[[Variable[A], tp.Any], A]: + def variable_operator_method(self, other): + value = self.get_value() + if isinstance(other, Variable): + other = other.get_value() + return getattr(value, name)(other) + + variable_operator_method.__name__ = name + return variable_operator_method + + +def _variable_unary_operator(name: str) -> tp.Callable[[Variable[A]], A]: + def variable_unary_operator_method(self): + value = self.get_value() + return getattr(value, name)() + + variable_unary_operator_method.__name__ = name + return variable_unary_operator_method + + class VariableMeta(type): def __new__(cls, cls_name, bases, attrs): if '__slots__' not in attrs: attrs['__slots__'] = () return super().__new__(cls, cls_name, bases, attrs) + def __instancecheck__(self, instance): + if super().__instancecheck__(instance): + return True + + if isinstance(instance, jax_core.Tracer): + ty = jax_core.typeof(instance) + if isinstance(ty, AbstractVariable): + return issubclass(ty._var_type, self) + if isinstance(instance, HijaxVariable): + return issubclass(instance._var_type, self) + return False + + if not tp.TYPE_CHECKING: + + def __call__(cls, *args, **kwargs): + return cls._variable_meta_call(*args, **kwargs) + + def _variable_meta_call( + cls, + *args, + is_hijax: bool | None = None, + **kwargs, + ): + if is_hijax is None: + is_hijax = using_hijax() + variable = super().__call__(*args, is_hijax=is_hijax, **kwargs) + if is_hijax: + return _new_hijax_from_variable(variable) + return variable + class Variable(tp.Generic[A], reprlib.Representable, metaclass=VariableMeta): """The base class for all ``Variable`` types. Create custom ``Variable`` @@ -239,32 +890,88 @@ class Variable(tp.Generic[A], reprlib.Representable, metaclass=VariableMeta): }) """ - __slots__ = ('raw_value', '_trace_state', '_var_metadata') - - raw_value: A + __slots__ = ('_raw_value', '_trace_state', '_var_metadata') + _raw_value: A _trace_state: tracers.TraceState _var_metadata: dict[str, tp.Any] + @property + def var_type(self): + return type(self) + + @property + def is_hijax(self) -> bool: + return self._var_metadata['is_hijax'] + + @property + def has_ref(self) -> bool: + return self._var_metadata['has_ref'] + + @property + def is_mutable(self) -> bool: + return self._var_metadata['is_mutable'] + + @property + def shape(self: Variable[jax.Array]) -> tuple[int, ...]: + return self.get_value().shape + def __init__( self, - value: tp.Union[A, VariableMetadata[A]], + value: A | VariableMetadata[A], *, - use_ref: bool | None = None, + is_hijax: bool = False, + has_ref: bool = False, + is_mutable: bool = True, + eager_sharding: bool | None = None, **metadata: tp.Any, ): - if use_ref is None: - use_ref = using_refs() - var_t = type(self) - object.__setattr__(self, '_trace_state', tracers.TraceState()) if isinstance(value, VariableMetadata): - metadata.update(value.metadata) + aux_metadata = dict(value.metadata) + if 'is_hijax' in aux_metadata: + if is_hijax is not None and is_hijax != aux_metadata['is_hijax']: + raise ValueError( + 'Cannot specify is_hijax both in VariableMetadata and as an ' + 'argument to Variable constructor.' 
+ ) + is_hijax = aux_metadata.pop('is_hijax') + if 'has_ref' in aux_metadata: + if has_ref is not None and has_ref != aux_metadata['has_ref']: + raise ValueError( + 'Cannot specify has_ref both in VariableMetadata and as an ' + 'argument to Variable constructor.' + ) + has_ref = aux_metadata.pop('has_ref') + if 'is_mutable' in aux_metadata: + if is_mutable is not None and is_mutable != aux_metadata['is_mutable']: + raise ValueError( + 'Cannot specify is_mutable both in VariableMetadata and as an ' + 'argument to Variable constructor.' + ) + is_mutable = aux_metadata.pop('is_mutable') + if 'eager_sharding' in aux_metadata: + if ( + eager_sharding is not None + and eager_sharding != aux_metadata['eager_sharding'] + ): + raise ValueError( + 'Cannot specify eager_sharding both in VariableMetadata and as ' + 'an argument to Variable constructor.' + ) + eager_sharding = aux_metadata.pop('eager_sharding') + metadata.update(aux_metadata) value = tp.cast(A, value.raw_value) - elif is_array_ref(value): - raise ValueError('Cannot pass a ArrayRef directly into Variable init.') - object.__setattr__(self, 'raw_value', value) + if any(is_array_ref(v) for v in jax.tree.leaves(value)): + raise ValueError('Cannot pass a Ref directly into Variable constructor.') + + metadata['is_hijax'] = is_hijax + metadata['has_ref'] = has_ref + metadata['is_mutable'] = is_mutable + object.__setattr__(self, '_trace_state', tracers.TraceState()) + object.__setattr__(self, '_var_metadata', metadata) + object.__setattr__(self, '_raw_value', value) if hasattr(var_t, 'on_get_value') and 'on_get_value' not in metadata: metadata['on_get_value'] = var_t.on_get_value @@ -284,37 +991,54 @@ def __init__( if 'sharding' in metadata: metadata['sharding_names'] = metadata.pop('sharding') - object.__setattr__(self, '_var_metadata', metadata) # run create_value hooks - value = self.create_value(self.raw_value) - - # shard the value if applicable - do_eager_sharding = config.flax_always_shard_variable - if 'eager_sharding' in metadata: - do_eager_sharding = metadata['eager_sharding'] - if do_eager_sharding and 'sharding_names' in metadata: + if 'on_create_value' in metadata: + value = metadata['on_create_value'](self, value) + + if eager_sharding is None: + eager_sharding = config.flax_always_shard_variable + + object.__setattr__(self, '_raw_value', value) + # run create_value hook + value = self.create_value(value) # type: ignore + # shard the _value if applicable + if eager_sharding and 'sharding_names' in metadata: + metadata['eager_sharding'] = eager_sharding value = core_spmd.shard_value( - value, metadata['sharding_names'], metadata.get('sharding_rules', None), - metadata.get('mesh', None)) + value, + metadata['sharding_names'], + metadata.get('sharding_rules', None), + metadata.get('mesh', None), + ) + if has_ref: + value = jax.new_ref(value) # type: ignore + object.__setattr__(self, '_raw_value', value) - # Create the ref out of the array value - if use_ref: - value = new_ref(jnp.asarray(value)) # type: ignore[assignment] # type: ignore[assignment] + @property + def _can_update(self) -> bool: + """Whether the Variable can be updated in-place in the current trace context.""" + if self.is_hijax: + return self.is_mutable + else: + return self.is_mutable and self._trace_state.is_valid() - object.__setattr__(self, 'raw_value', value) + def _check_can_update(self): + if not self.is_mutable: + raise errors.ImmutableVariableError( + f'Cannot mutate {type(self).__name__} as it is marked as immutable.' 
+ ) + if not self.is_hijax and not self._trace_state.is_valid(): + raise errors.TraceContextError( + f'Cannot mutate {type(self).__name__} from a different trace level' + ) def __getattr__(self, name: str) -> tp.Any: if name in object.__getattribute__(self, '_var_metadata'): return self._var_metadata[name] - return getattr(self.raw_value, name) + return getattr(object.__getattribute__(self, '_raw_value'), name) def __setattr__(self, name: str, value: tp.Any): - if not self._trace_state.is_valid() and ( - name != 'value' or not self.has_ref - ): - raise errors.TraceContextError( - f'Cannot mutate {type(self).__name__} from a different trace level' - ) + self._check_can_update() try: object.__setattr__(self, name, value) except AttributeError as e: @@ -326,36 +1050,22 @@ def __setattr__(self, name: str, value: tp.Any): ) from e def __delattr__(self, name: str): - if not self._trace_state.is_valid(): - raise errors.TraceContextError( - f'Cannot mutate {type(self).__name__} from a different trace level' - ) - - if ( - name == 'value' - or name == 'raw_value' - or name == '_var_metadata' - or name == '_trace_state' - ): + self._check_can_update() + try: object.__delattr__(self, name) - else: - del self._var_metadata[name] + except AttributeError as e: + raise AttributeError( + f'Cannot delete attribute {name}. ' + f'To delete Variable metadata use:\n\n' + f" variable.del_metadata('{name}')" + ) from e # NOTE(cgarciae): adding this for backward compatibility with VariableState @property def type(self): """The type of the variable.""" - import warnings - warnings.warn( - "'.type' is deprecated, use 'type(variable)' instead.", - DeprecationWarning, - stacklevel=2, - ) - return type(self) - @property - def has_ref(self) -> bool: - return is_array_ref(self.raw_value) + return type(self) @tp.overload def get_metadata(self) -> dict[str, tp.Any]: ... @@ -372,11 +1082,12 @@ def get_metadata( default: The default value to return if the metadata key is not found. If not provided and the key is not found, raises a KeyError. """ + metadata = self._var_metadata.copy() if name is None: - return self._var_metadata - if name not in self._var_metadata and not isinstance(default, Missing): + return metadata + if name not in metadata and not isinstance(default, Missing): return default - return self._var_metadata[name] + return metadata[name] @tp.overload def set_metadata(self, metadata: dict[str, tp.Any], /) -> None: ... @@ -396,20 +1107,66 @@ def set_metadata(self, *args, **kwargs) -> None: 3. By using keyword arguments, this will update the Variable's metadata with the provided key-value pairs. 
""" - if not self._trace_state.is_valid(): - raise errors.TraceContextError( - f'Cannot mutate {type(self).__name__} from a different trace level' - ) + self._check_can_update() if args and kwargs: raise TypeError( 'Cannot mix positional and keyword arguments in set_metadata' ) if len(args) == 1: - self._var_metadata = dict(args[0]) + metadata = dict(args[0]) + if 'is_hijax' not in metadata: + raise ValueError('metadata is missing required key `is_hijax` key') + if metadata['is_hijax'] != self.is_hijax: + raise ValueError( + f'Cannot change `is_hijax` metadata, expected {self.is_hijax}, ' + f'got {metadata["is_hijax"]}' + ) + if 'has_ref' not in metadata: + raise ValueError('metadata is missing required key `has_ref` key') + if metadata['has_ref'] != self.has_ref: + raise ValueError( + f'Cannot change `has_ref` metadata, expected {self.has_ref}, ' + f'got {metadata["has_ref"]}' + ) + if 'is_mutable' not in metadata: + raise ValueError('metadata is missing required key `is_mutable` key') + if metadata['is_mutable'] != self.is_mutable: + raise ValueError( + f'Cannot change `is_mutable` metadata, expected {self.is_mutable}, ' + f'got {metadata["is_mutable"]}' + ) + self._var_metadata = metadata elif len(args) == 2: name, value = args + if name == 'is_hijax' and value != self.is_hijax: + raise ValueError( + f'Cannot change `is_hijax` metadata, expected {self.is_hijax}, got {value}' + ) + if name == 'has_ref' and value != self.has_ref: + raise ValueError( + f'Cannot change `has_ref` metadata, expected {self.has_ref}, got {value}' + ) + if name == 'is_mutable' and value != self.is_mutable: + raise ValueError( + f'Cannot change `is_mutable` metadata, expected {self.is_mutable}, got {value}' + ) self._var_metadata[name] = value elif kwargs: + if 'is_hijax' in kwargs and kwargs['is_hijax'] != self.is_hijax: + raise ValueError( + f'Cannot change `is_hijax` metadata, expected {self.is_hijax}, ' + f'got {kwargs["is_hijax"]}' + ) + if 'has_ref' in kwargs and kwargs['has_ref'] != self.has_ref: + raise ValueError( + f'Cannot change `has_ref` metadata, expected {self.has_ref}, ' + f'got {kwargs["has_ref"]}' + ) + if 'is_mutable' in kwargs and kwargs['is_mutable'] != self.is_mutable: + raise ValueError( + f'Cannot change `is_mutable` metadata, expected {self.is_mutable}, ' + f'got {kwargs["is_mutable"]}' + ) self._var_metadata.update(kwargs) else: raise TypeError( @@ -417,6 +1174,17 @@ def set_metadata(self, *args, **kwargs) -> None: f'got args={args}, kwargs={kwargs}' ) + def del_metadata(self, name: str) -> None: + """Delete a metadata entry for the Variable. + + Args: + name: The key of the metadata element to delete. + """ + self._check_can_update() + if name in ('is_hijax', 'has_ref', 'is_mutable'): + raise ValueError(f'Cannot delete `{name}` metadata') + del self._var_metadata[name] + def copy_from(self, other: Variable[A]) -> None: if type(self) is not type(other): raise ValueError( @@ -425,50 +1193,132 @@ def copy_from(self, other: Variable[A]) -> None: ) if self is other: return - self.raw_value = other.raw_value + self._raw_value = other._raw_value self._var_metadata.clear() self._var_metadata.update(other.get_metadata()) def update_from_state(self, variable_state: Variable[A]): - if self.has_ref and ( - variable_state.has_ref or isinstance(variable_state.raw_value, jax.Array) - ): - self.raw_value[...] = variable_state.raw_value[...] 
# type: ignore - else: - object.__setattr__(self, 'raw_value', variable_state.raw_value) + self._raw_value = variable_state._raw_value if self._var_metadata != variable_state._var_metadata: - object.__setattr__( - self, '_var_metadata', variable_state._var_metadata.copy() - ) + metadata = variable_state.get_metadata() + metadata['is_hijax'] = self.is_hijax + metadata['has_ref'] = self.has_ref + metadata['is_mutable'] = self.is_mutable + self._var_metadata = metadata + + @tp.final + def get_raw_value(self) -> A: + return self._raw_value + + # @tp.final + def set_raw_value(self, value: A, *, _unsafe_bypass_check: bool = False): + if not _unsafe_bypass_check: + self._check_can_update() + self._raw_value = value + + @property + def raw_value(self) -> A: + warnings.warn( + "'.raw_value' access is now deprecated. Use:\n\n" + ' variable.get_raw_value()\n', + DeprecationWarning, + stacklevel=2, + ) + return self.get_raw_value() + + @raw_value.setter + def raw_value(self, value: A): + warnings.warn( + "'.raw_value' access is now deprecated. Use:\n\n" + ' variable.set_raw_value(value)\n', + DeprecationWarning, + stacklevel=2, + ) + self.set_raw_value(value) @property def value(self) -> A: - value = self.raw_value + warnings.warn( + "'.value' access is now deprecated. For Variable[Array] instances use:\n\n" + ' variable[...]\n\n' + 'For other Variable types use:\n\n' + ' variable.get_value()\n', + DeprecationWarning, + stacklevel=2, + ) + value = self._raw_value if is_array_ref(value): value = value[...] - if 'on_get_value' in self._var_metadata: - value = self._var_metadata['on_get_value'](self, value) - return value + return self.get_value() @value.setter def value(self, value: A): + warnings.warn( + "'.value' access is now deprecated. For Variable[Array] instances use:\n\n" + ' variable[...] = value\n\n' + 'For other Variable types use:\n\n' + ' variable.set_value(value)\n', + DeprecationWarning, + stacklevel=2, + ) + self.set_value(value) + + def create_value(self, value: A): + return value + + def get_value(self, *, index: tp.Any = MISSING) -> A: + value = jax.tree.map(lambda x: x, self._raw_value) # make a copy + if not isinstance(index, Missing): + if is_array_ref(value): + value = value[index] + elif isinstance(value, jax.Array) and index == ...: + pass # skip trivial access + else: + value = value[index] + elif is_array_ref(value): + value = value[...] + if 'on_get_value' in self._var_metadata: + value = self._var_metadata['on_get_value'](self, value) + return value # type: ignore + + def set_value(self, value: A, *, index: tp.Any = MISSING): + value = jax.tree.map(lambda x: x, value) # make a copy if isinstance(value, Variable): raise ValueError( 'Cannot set value to a Variable, use `copy_from` method instead' ) if 'on_set_value' in self._var_metadata: value = self._var_metadata['on_set_value'](self, value) - if self.has_ref: - self.raw_value[...] = value # type: ignore + # update _raw_value + if is_array_ref(self._raw_value): + if isinstance(index, Missing): + self._raw_value[...] = value + else: + self._raw_value[index] = value + elif isinstance(self._raw_value, jax.Array) and ( + not isinstance(index, Missing) + ): + # check if its a full replace to av + if ( + index == ... 
+ and isinstance(value, jax.Array) + and value.shape == self._raw_value[index].shape + and value.dtype == self._raw_value.dtype + and ( + getattr(value, 'sharding', None) + == getattr(self._raw_value, 'sharding', None) + ) + ): + self._raw_value = value + else: + self._raw_value = self._raw_value.at[index].set(value) # type: ignore else: - object.__setattr__(self, 'raw_value', value) - - def create_value(self, value: A): - if 'on_create_value' in self._var_metadata: - value = self._var_metadata['on_create_value'](self, value) - return value + if isinstance(index, Missing): + self._raw_value = value + else: + self._raw_value[index] = value # type: ignore def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): if 'on_add_axis' in self._var_metadata: @@ -479,78 +1329,114 @@ def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): self._var_metadata['on_remove_axis'](self, axis_index, axis_name) @tp.overload - def replace(self, value: B, **kwargs) -> Variable[B]: - ... + def copy(self, value: B, **kwargs) -> Variable[B]: ... @tp.overload - def replace(self, **kwargs) -> Variable[A]: - ... + def copy(self, **kwargs) -> Variable[A]: ... - def replace(self, value: tp.Any = Missing, **kwargs) -> Variable[tp.Any]: - if value is not Missing: - kwargs['raw_value'] = value + def copy( + self, + value: tp.Any = MISSING, + *, + _copy_ref: bool = True, + **updates, + ) -> Variable[tp.Any]: + assert 'raw_value' not in updates + + new_metadata = self.get_metadata() | updates + if 'is_mutable' in updates and not updates['is_mutable']: + if 'has_ref' in updates and updates['has_ref']: + raise ValueError( + 'Cannot set has_ref=True and is_mutable=False simultaneously.' + ) + if 'is_hijax' in updates and updates['is_hijax']: + raise ValueError( + 'Cannot set is_hijax=True and is_mutable=False simultaneously.' + ) + if self.is_mutable: + new_metadata['has_ref'] = False + new_metadata['is_hijax'] = False + if self.has_ref: + new_metadata['had_ref'] = True + if self.is_hijax: + new_metadata['was_hijax'] = True + if ('is_mutable' in updates and updates['is_mutable']) or ( + 'has_ref' in updates and updates['has_ref'] + ): + new_metadata.pop('had_ref', None) + if ('is_mutable' in updates and updates['is_mutable']) or ( + 'is_hijax' in updates and updates['is_hijax'] + ): + new_metadata.pop('was_hijax', None) - # rename `value` to `raw_value` - if 'value' in kwargs: - kwargs['raw_value'] = kwargs.pop('value') + if not isinstance(value, Missing): + pass + elif 'value' in updates: + value = updates.pop('value') + else: + value = self.get_raw_value() + if _copy_ref and is_array_ref(value): + value = value[...] 
- # return `value` if it is a Variable - if 'raw_value' in kwargs and isinstance( - value := kwargs['raw_value'], Variable + if _copy_ref and ( + new_metadata['has_ref'] + or (new_metadata['is_mutable'] and self.get_metadata('had_ref', False)) ): - # remove value from kwargs - kwargs.pop('raw_value') - if type(self) is not type(value): - raise ValueError( - 'Cannot replace value from incompatible container, ' - f'expected {type(self).__name__}, got {type(value).__name__}' - ) - # if kwargs aren't empty, recursively call replace - # else return variable value - if kwargs: - return value.replace(**kwargs) - else: - return value - - # get and update attributes - # return new instance with updated attributes - obj = object.__new__(type(self)) - object.__setattr__(obj, '_trace_state', self._trace_state) - object.__setattr__(obj, 'raw_value', kwargs.pop('raw_value')) - object.__setattr__(obj, '_var_metadata', self.get_metadata() | kwargs) + value = jax.new_ref(value) + new_metadata['has_ref'] = True + if new_metadata['is_mutable'] and self.get_metadata('was_hijax', False): + new_metadata['is_hijax'] = True + + obj = self.from_metadata(value, new_metadata) return obj @classmethod - def from_metadata(cls, value: A, attributes: dict[str, tp.Any]): + def _new( + cls, + value: A, + metadata: dict[str, tp.Any], + ) -> Variable[A]: obj = object.__new__(cls) + # skip __setattr__ for trace_state initialization object.__setattr__(obj, '_trace_state', tracers.TraceState()) - object.__setattr__(obj, 'raw_value', value) - object.__setattr__(obj, '_var_metadata', attributes) - return obj - - def copy(self: Variable[A]) -> Variable[A]: - obj = object.__new__(type(self)) - object.__setattr__(obj, '_trace_state', tracers.TraceState()) - object.__setattr__(obj, 'raw_value', self.raw_value) - object.__setattr__(obj, '_var_metadata', self.get_metadata().copy()) + object.__setattr__(obj, '_var_metadata', metadata) + object.__setattr__(obj, '_raw_value', value) return obj + @classmethod + def from_metadata( + cls, + value: A, + attributes: dict[str, tp.Any], + ) -> Variable[A]: + variable = cls._new(value, dict(attributes)) + if attributes['is_hijax']: + variable = _new_hijax_from_variable(variable) # type: ignore[assignment] + return variable # type: ignore[return-value] + + replace = copy to_state = copy def __nnx_repr__(self): - stats = SizeBytes.from_any(self.raw_value) + stats = SizeBytes.from_any(self._raw_value) if stats: comment = f' # {stats}' else: comment = '' yield reprlib.Object(type=type(self).__name__, comment=comment) - yield reprlib.Attr('value', self.raw_value) + yield reprlib.Attr('value', self.get_value()) for name, value in self._var_metadata.items(): + if name == 'is_hijax' and not value: + continue + if name == 'has_ref' and not value: + continue + if name == 'is_mutable' and value: + continue yield reprlib.Attr(name, value) def __treescope_repr__(self, path, subtree_renderer): - size_bytes = SizeBytes.from_any(self.value) + size_bytes = SizeBytes.from_any(self.get_value()) if size_bytes: stats_repr = f' # {size_bytes}' first_line_annotation = treescope.rendering_parts.comment_color( @@ -559,7 +1445,7 @@ def __treescope_repr__(self, path, subtree_renderer): else: first_line_annotation = None - children = {'value': self.raw_value, **self._var_metadata} + children = {'value': self.get_value(), **self._var_metadata} return visualization.render_object_constructor( object_type=type(self), attributes=children, @@ -586,20 +1472,21 @@ def on_remove_axis( ) -> V: ... 
def __jax_array__(self): - return self.value + return self.get_value() # pickle support def __getstate__(self): return { - 'raw_value': self.raw_value, + '_raw_value': self._raw_value, '_trace_state': self._trace_state, '_var_metadata': self._var_metadata, } def __setstate__(self, state): - object.__setattr__(self, 'raw_value', state['raw_value']) + # skip __setattr__ for trace_state initialization object.__setattr__(self, '_trace_state', state['_trace_state']) object.__setattr__(self, '_var_metadata', state['_var_metadata']) + object.__setattr__(self, '_raw_value', state['_raw_value']) # -------------------------------------------- # proxy methods @@ -615,167 +1502,54 @@ def __getitem__(self: Variable[tuple[B, ...]], key: int) -> B: ... @tp.overload def __getitem__(self, key) -> tp.Any: ... def __getitem__(self, key): - return self.value[key] # type: ignore + return self.get_value(index=key) - def __setitem__(self, key, item_value) -> None: - value = self.value - if isinstance(value, jax.Array): - value = value.at[key].set(item_value) # type: ignore[assignment] - else: - value[key] = item_value # type: ignore - self.value = value # type: ignore + def __setitem__(self, key, value) -> None: + self.set_value(value, index=key) + + def __delitem__(self, key) -> None: + value = self.get_value() + del value[key] # type: ignore + self.set_value(value) # type: ignore def __call__(self, *args, **kwargs) -> tp.Any: - return self.value(*args, **kwargs) # type: ignore + return self.get_value()(*args, **kwargs) # type: ignore def __len__(self) -> int: - return len(self.value) # type: ignore + return len(self.get_value()) # type: ignore def __iter__(self) -> tp.Iterator: - return iter(self.value) # type: ignore + return iter(self.get_value()) # type: ignore def __contains__(self, item) -> bool: - return item in self.value # type: ignore - - def __add__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__add__(other) # type: ignore - - def __sub__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__sub__(other) # type: ignore - - def __mul__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__mul__(other) # type: ignore - - def __matmul__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__matmul__(other) # type: ignore - - def __truediv__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__truediv__(other) # type: ignore - - def __floordiv__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__floordiv__(other) # type: ignore - - def __mod__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__mod__(other) # type: ignore - - def __divmod__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__divmod__(other) # type: ignore - - def __pow__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__pow__(other) # type: ignore - - def __lshift__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__lshift__(other) # type: ignore - - def __rshift__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rshift__(other) # type: ignore - - def __and__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return 
self.value.__and__(other) # type: ignore - - def __xor__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__xor__(other) # type: ignore - - def __or__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__or__(other) # type: ignore - - def __radd__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__radd__(other) # type: ignore - - def __rsub__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rsub__(other) # type: ignore - - def __rmul__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rmul__(other) # type: ignore - - def __rmatmul__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rmatmul__(other) # type: ignore - - def __rtruediv__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rtruediv__(other) # type: ignore - - def __rfloordiv__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rfloordiv__(other) # type: ignore - - def __rmod__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rmod__(other) # type: ignore - - def __rdivmod__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rdivmod__(other) # type: ignore - - def __rpow__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rpow__(other) # type: ignore - - def __rlshift__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rlshift__(other) # type: ignore - - def __rrshift__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rrshift__(other) # type: ignore - - def __rand__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rand__(other) # type: ignore - - def __rxor__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rxor__(other) # type: ignore - - def __ror__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__ror__(other) # type: ignore + return item in self.get_value() # type: ignore + + __add__ = _variable_operator('__add__') + __sub__ = _variable_operator('__sub__') + __mul__ = _variable_operator('__mul__') + __matmul__ = _variable_operator('__matmul__') + __truediv__ = _variable_operator('__truediv__') + __floordiv__ = _variable_operator('__floordiv__') + __mod__ = _variable_operator('__mod__') + __pow__ = _variable_operator('__pow__') + __lshift__ = _variable_operator('__lshift__') + __rshift__ = _variable_operator('__rshift__') + __and__ = _variable_operator('__and__') + __xor__ = _variable_operator('__xor__') + __or__ = _variable_operator('__or__') + __radd__ = _variable_operator('__radd__') + __rsub__ = _variable_operator('__rsub__') + __rmul__ = _variable_operator('__rmul__') + __rmatmul__ = _variable_operator('__rmatmul__') + __rtruediv__ = _variable_operator('__rtruediv__') + __rfloordiv__ = _variable_operator('__rfloordiv__') + __rmod__ = _variable_operator('__rmod__') + __rpow__ = _variable_operator('__rpow__') + __rlshift__ = _variable_operator('__rlshift__') + __rrshift__ = _variable_operator('__rrshift__') + __rand__ = _variable_operator('__rand__') + __rxor__ = 
_variable_operator('__rxor__') + __ror__ = _variable_operator('__ror__') def __iadd__(self: V, other) -> V: raise NotImplementedError( @@ -855,47 +1629,27 @@ def __ior__(self: V, other) -> V: 'Use `variable.value |= x` instead.' ) - def __neg__(self) -> A: - return self.value.__neg__() # type: ignore - - def __pos__(self) -> A: - return self.value.__pos__() # type: ignore - - def __abs__(self) -> A: - return self.value.__abs__() # type: ignore - - def __invert__(self) -> A: - return self.value.__invert__() # type: ignore - - def __complex__(self) -> A: - return self.value.__complex__() # type: ignore - - def __int__(self) -> A: - return self.value.__int__() # type: ignore - - def __float__(self) -> A: - return self.value.__float__() # type: ignore - - def __index__(self) -> A: - return self.value.__index__() # type: ignore - - def __round__(self, ndigits: int) -> A: - return self.value.__round__(ndigits) # type: ignore - - def __trunc__(self) -> A: - return self.value.__trunc__() # type: ignore - - def __floor__(self) -> A: - return self.value.__floor__() # type: ignore - - def __ceil__(self) -> A: - return self.value.__ceil__() # type: ignore + __neg__ = _variable_unary_operator('__neg__') + __pos__ = _variable_unary_operator('__pos__') + __abs__ = _variable_unary_operator('__abs__') + __invert__ = _variable_unary_operator('__invert__') + __complex__ = _variable_unary_operator('__complex__') + __int__ = _variable_unary_operator('__int__') + __float__ = _variable_unary_operator('__float__') + __index__ = _variable_unary_operator('__index__') + __trunc__ = _variable_unary_operator('__trunc__') + __floor__ = _variable_unary_operator('__floor__') + __ceil__ = _variable_unary_operator('__ceil__') + + def __round__(self, ndigits: int = 0) -> A: + return self.get_value().__round__(ndigits) # type: ignore # -------------------------------------------- def __init_subclass__(cls) -> None: + if '__slots__' not in vars(cls): + cls.__slots__ = () # type: ignore[assignment] super().__init_subclass__() - jax.tree_util.register_pytree_with_keys( cls, flatten_with_keys=_variable_flatten_with_keys, @@ -906,13 +1660,13 @@ def __init_subclass__(cls) -> None: def _variable_flatten_with_keys(x: Variable[tp.Any]): metadata = tuple(sorted(x._var_metadata.items())) - node = (jtu.GetAttrKey('value'), x.raw_value) + node = (jtu.GetAttrKey('value'), x._raw_value) return (node,), metadata def _variable_flatten(x: Variable[tp.Any]): metadata = tuple(sorted(x._var_metadata.items())) - return (x.raw_value,), metadata + return (x._raw_value,), metadata def _variable_unflatten( @@ -920,7 +1674,7 @@ def _variable_unflatten( static: tuple[tuple[str, tp.Any], ...], children: tuple[tp.Any], ): - return cls.from_metadata(value=children[0], attributes=dict(static)) + return cls._new(children[0], dict(static)) jax.tree_util.register_pytree_with_keys( @@ -930,9 +1684,9 @@ def _variable_unflatten( flatten_func=_variable_flatten, ) - VariableState = Variable + class Param(Variable[A]): """The canonical learnable parameter. 
All learnable parameters in NNX layer modules will have the ``Param`` :class:`Variable` @@ -1153,32 +1907,6 @@ def wrapper(*args): return wrapper # type: ignore -def split_flat_state( - flat_state: tp.Iterable[tuple[PathParts, Variable]], - filters: tuple[filterlib.Filter, ...], -) -> tuple[list[tuple[PathParts, Variable]], ...]: - predicates = filterlib.filters_to_predicates(filters) - # we have n + 1 states, where n is the number of predicates - # the last state is for values that don't match any predicate - flat_states: tuple[list[tuple[PathParts, Variable]], ...] = ( - tuple([] for _ in predicates) - ) - - for path, value in flat_state: - for i, predicate in enumerate(predicates): - if predicate(path, value): - flat_states[i].append((path, value)) - break - else: - raise ValueError( - 'Non-exhaustive filters, got a non-empty remainder: ' - f'{path} -> {value}.' - '\nUse `...` to match all remaining elements.' - ) - - return flat_states - - ################################################### ### Variable type/class <-> string name mapping ### ################################################### @@ -1232,13 +1960,6 @@ def variable_name_from_type( return name -class _Missing: - pass - - -_MISSING = _Missing() - - @tp.overload def register_variable_name( name: str, @@ -1258,12 +1979,12 @@ def register_variable_name( def register_variable_name( name: str, - typ: type[Variable[A]] | _Missing = _MISSING, + typ: type[Variable[A]] | Missing = MISSING, *, overwrite=False, ) -> type[Variable[A]] | tp.Callable[[type[Variable[A]]], type[Variable[A]]]: """Register a pair of Linen collection name and its NNX type.""" - if typ is _MISSING: + if isinstance(typ, Missing): return partial(register_variable_name, name, overwrite=overwrite) typ = tp.cast(type[Variable[A]], typ) if not overwrite and name in VariableTypeCache: diff --git a/tests/nnx/containers_test.py b/tests/nnx/containers_test.py index 70b51283f..7aa80cb73 100644 --- a/tests/nnx/containers_test.py +++ b/tests/nnx/containers_test.py @@ -34,7 +34,7 @@ def test_on_set_value(self): ) x[...] = 5 - assert x.raw_value == 12 + assert x.get_raw_value() == 12 def test_module_unbox(self): class Foo(nnx.Module): @@ -43,8 +43,8 @@ def __init__(self) -> None: module = Foo() - assert module.x.value == 4 - assert vars(module)['x'].raw_value == 1 + assert module.x.get_value() == 4 + assert vars(module)['x'].get_raw_value() == 1 def test_module_box(self): class Foo(nnx.Module): @@ -58,7 +58,7 @@ def __init__(self) -> None: module.x[...] = 5 assert module.x[...] == 12 - assert vars(module)['x'].raw_value == 12 + assert vars(module)['x'][...] == 12 if __name__ == '__main__': diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index 5dde7c260..1bff96a59 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -46,8 +46,8 @@ def test_flatten(self): refmap = nnx.graph.RefMap() graphdef, flat_state = nnx.graph.flatten(g, ref_index=refmap) - assert flat_state[0][1].value == 2 - assert flat_state[1][1].value == 4 + assert flat_state[0][1].get_value() == 2 + assert flat_state[1][1].get_value() == 4 assert len(refmap) == 2 # 2 Variables assert a['b'] in refmap @@ -156,8 +156,8 @@ def test_update_dynamic(self): state[0]['b'][...] = 3 nnx.update(g, state) - assert g[0]['b'].value == 3 - assert g[2]['b'].value == 3 + assert g[0]['b'][...] == 3 + assert g[2]['b'][...] 
== 3 def test_update_from_pure_dict(self): a = {'a': 1, 'b': nnx.Param(jnp.array(2))} @@ -342,7 +342,7 @@ def __init__(self): m2 = nnx.merge(graphdef, state) assert isinstance(m2.tree, Tree) - assert m2.tree.a.raw_value == 1 + assert m2.tree.a.get_value() == 1 assert m2.tree.b == 'a' assert m2.tree.a is m.tree.a assert m2.tree is not m.tree diff --git a/tests/nnx/integration_test.py b/tests/nnx/integration_test.py index 67eaa71c5..c9aabb62d 100644 --- a/tests/nnx/integration_test.py +++ b/tests/nnx/integration_test.py @@ -146,17 +146,17 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): key = rngs.params() self.w = nnx.Param(jax.random.uniform(key, (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) - self.count = State(0) + self.count = State(jnp.array(0)) def __call__(self, x): - self.count.value += 1 - return x @ self.w.value + self.b.value[None] + self.count[...] += 1 + return x @ self.w + self.b[None] model = Linear(din=12, dout=2, rngs=nnx.Rngs(0)) # forward pass x = jnp.ones((8, 12)) y = model(x) - assert model.count.value == 1 + assert model.count[...] == 1 @nnx.jit def train_step(model, x, y): @@ -176,7 +176,7 @@ def loss_fn(model): # execute the training step train_step(model, x, y) - assert model.count.value == 2 + assert model.count[...] == 2 def test_functional_example(self): class Count(nnx.Variable[A]): @@ -187,17 +187,17 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): key = rngs.params() self.w = nnx.Param(jax.random.uniform(key, (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) - self.count = Count(0) + self.count = Count(jnp.array(0)) def __call__(self, x): - self.count.value += 1 - return x @ self.w.value + self.b.value[None] + self.count[...] += 1 + return x @ self.w + self.b[None] model = Linear(din=12, dout=2, rngs=nnx.Rngs(0)) # forward pass x = jnp.ones((8, 12)) y = model(x) - assert model.count.value == 1 + assert model.count[...] == 1 graphdef, params, counts = nnx.split(model, nnx.Param, Count) @@ -218,7 +218,7 @@ def loss_fn(params): # execute the training step params, counts = train_step(params, counts, x, y) model = nnx.merge(graphdef, params, counts) - assert model.count.value == 2 + assert model.count[...] == 2 def test_intermediates_example(self): class Linear(nnx.Module): @@ -228,7 +228,7 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): - y = x @ self.w.value + self.b.value[None] + y = x @ self.w + self.b[None] self.y = nnx.Intermediate(y) return y @@ -248,7 +248,7 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): - y = x @ self.w.value + self.b.value[None] + y = x @ self.w + self.b[None] self.y = nnx.Intermediate(y) return y @@ -298,6 +298,7 @@ def __call__(self, x): nnx.update(model, restored_pure_dict) assert model(x).shape == (3, 4) # The model still works! 
+ @nnx.use_hijax(True) def test_example_mutable_arrays(self): class Model(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): @@ -310,18 +311,17 @@ def __call__(self, x): x = nnx.relu(self.dropout(self.bn(self.linear(x)))) return self.linear_out(x) - with nnx.use_refs(True): - model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization - optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization + optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) @jax.jit # automatic state management for JAX transforms def train_step(x, y): graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) def loss_fn(params): model = nnx.merge(graphdef, params, nondiff) - return ((model(x) - y) ** 2).mean() # call methods directly + return ((model(x) - y) ** 2).mean() # call methods directly - loss, grads = jax.value_and_grad(loss_fn)(nnx.to_arrays(params)) + loss, grads = jax.value_and_grad(loss_fn)(nnx.immutable(params)) optimizer.update(model, grads) # in-place updates return loss diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index 64bf1dda2..a5e17bd7a 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -1,4 +1,5 @@ # Copyright 2024 The Flax Authors. + # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -333,13 +334,13 @@ def test_clone(self): m2 = nnx.clone(m) assert m is not m2 - assert m2.a[0].value == m2.b.c.value - assert m2.a[1].value == m2.b.d.value + assert m2.a[0].get_value() == m2.b.c.get_value() + assert m2.a[1].get_value() == m2.b.d.get_value() - assert m.a[0].value == m2.a[0].value - assert m.a[1].value == m2.a[1].value - assert m.b.c.value == m2.b.c.value - assert m.b.d.value == m2.b.d.value + assert m.a[0].get_value() == m2.a[0].get_value() + assert m.a[1].get_value() == m2.a[1].get_value() + assert m.b.c.get_value() == m2.b.c.get_value() + assert m.b.d.get_value() == m2.b.d.get_value() def test_sow_basic(self): class Foo(nnx.Module): @@ -354,12 +355,12 @@ def __call__(self, x): assert y1 == 3 assert y2 == 11 - assert m.y.value == (3, 11) + assert m.y.get_value() == (3, 11) intermediates = nnx.pop(m, nnx.Intermediate) assert isinstance(intermediates['y'], nnx.Intermediate) - assert intermediates['y'].value == (3, 11) + assert intermediates['y'].get_value() == (3, 11) assert not hasattr(m, 'y') @@ -550,13 +551,13 @@ def add_submodule(self): def test_create_abstract(self): linear = nnx.eval_shape(lambda: nnx.Linear(2, 3, rngs=nnx.Rngs(0))) - assert linear.kernel.value == jax.ShapeDtypeStruct((2, 3), jnp.float32) - assert linear.bias.value == jax.ShapeDtypeStruct((3,), jnp.float32) + assert linear.kernel.get_value() == jax.ShapeDtypeStruct((2, 3), jnp.float32) + assert linear.bias.get_value() == jax.ShapeDtypeStruct((3,), jnp.float32) def test_create_abstract_stateful(self): linear = nnx.eval_shape(lambda: nnx.Dropout(0.5, rngs=nnx.Rngs(0))) - assert linear.rngs.key.value == jax.ShapeDtypeStruct( + assert linear.rngs.key.get_value() == jax.ShapeDtypeStruct( (), jax.random.key(0).dtype ) @@ -742,13 +743,13 @@ class Foo(nnx.Module): graphdef, state = nnx.split(m) assert len(state) == 4 - assert state['b'].value == 2 + assert state['b'].get_value() == 2 assert isinstance(state['b'], nnx.Variable) - assert state['c'].value == 3 + assert state['c'].get_value() == 3 assert isinstance(state['c'], nnx.Param) - assert state['d'].value == 4 + assert state['d'].get_value() == 4 
assert isinstance(state['d'], nnx.Variable) - assert state['e'].value == 5 + assert state['e'].get_value() == 5 assert isinstance(state['e'], nnx.BatchStat) def test_post_init(self): diff --git a/tests/nnx/mutable_array_test.py b/tests/nnx/mutable_array_test.py index d30a2489d..621f607f1 100644 --- a/tests/nnx/mutable_array_test.py +++ b/tests/nnx/mutable_array_test.py @@ -15,6 +15,7 @@ import dataclasses from absl.testing import absltest import optax +import pytest from flax import nnx import flax.errors import jax @@ -22,15 +23,7 @@ -class TestObject(absltest.TestCase): - @classmethod - def setUpClass(cls): - cls.using_refs = nnx.using_refs() - nnx.use_refs(True) - - @classmethod - def tearDownClass(cls): - nnx.use_refs(cls.using_refs) +class TestPytree(absltest.TestCase): def test_pytree(self): class Foo(nnx.Module): @@ -38,7 +31,7 @@ def __init__(self): self.node = jnp.array(1) self.meta = 1 - m = Foo() + m = nnx.as_ref_vars(Foo()) m = jax.tree.map(lambda x: x + 1, m) @@ -122,18 +115,7 @@ def __init__(self, a): self.assertTrue(nnx.is_data(foo.a)) self.assertEqual(jax.tree.leaves(foo), [MyType(value=42)]) - - -class TestMutableArrayGraph(absltest.TestCase): - @classmethod - def setUpClass(cls): - cls.using_refs = nnx.using_refs() - nnx.use_refs(True) - - @classmethod - def tearDownClass(cls): - nnx.use_refs(cls.using_refs) - +class TestVariableRefMode(absltest.TestCase): def test_split_mutable_array(self): m = jax.new_ref(1) graphdef, state = nnx.split(m) @@ -149,37 +131,33 @@ class Foo(nnx.Module): def __init__(self): self.a = nnx.Param(1) - m = Foo() - self.assertTrue(m.a.has_ref) + m = nnx.as_ref_vars(Foo()) + self.assertEqual(m.a.has_ref, True) - m2 = nnx.to_arrays(m) - self.assertFalse(m2.a.has_ref) + m2 = nnx.immutable(m) + self.assertEqual(m2.a.has_ref, False) self.assertIsNot(m, m2) - m3 = nnx.to_refs(m2) - self.assertTrue(m3.a.has_ref) + m3 = nnx.as_ref_vars(m2) + self.assertEqual(m3.a.has_ref, True) self.assertIsNot(m2, m3) self.assertIsNot(m2.a, m3.a) def test_to_arrays_example(self): + node = [nnx.Variable(1.0), nnx.Variable(2.0, mode='ref')] + mutable_node = nnx.as_ref_vars(node) + assert isinstance(mutable_node[0].get_raw_value(), jax.Ref) + assert isinstance(mutable_node[1].get_raw_value(), jax.Ref) - node = [jnp.array(1.0), jax.new_ref(jnp.array(2.0))] - mutable_node = nnx.to_refs(node) - assert isinstance(mutable_node[0], jax.Ref) - assert isinstance(mutable_node[1], jax.Ref) - - shared_array = jnp.array(1.0) + shared_array = nnx.Variable(1.0, mode='pytree') node = [shared_array, shared_array] - with self.assertRaisesRegex( - ValueError, - 'Found duplicate at path' - ): - nnx.to_refs(node) + with self.assertRaisesRegex(ValueError, 'Found duplicate at path'): + nnx.as_ref_vars(node) - node = [jnp.array(1.0), jnp.array(2.0)] - mutable_node = nnx.to_refs(node, only=lambda path, x: path[0] == 0) - assert isinstance(mutable_node[0], jax.Ref) - assert isinstance(mutable_node[1], jax.Array) + node = [nnx.Variable(1.0), nnx.Variable(2.0)] + mutable_node = nnx.as_ref_vars(node, only=lambda path, x: path[0] == 0) + assert isinstance(mutable_node[0].get_raw_value(), jax.Ref) + assert isinstance(mutable_node[1].get_raw_value(), float) def test_freeze_and_mutable_with_filter(self): class Foo(nnx.Module): @@ -187,31 +165,31 @@ def __init__(self): self.a = nnx.Param(1) self.b = nnx.BatchStat(2) - m = Foo() - self.assertTrue(m.a.has_ref) - self.assertTrue(m.b.has_ref) + m = nnx.as_ref_vars(Foo()) + self.assertEqual(m.a.has_ref, True) + self.assertEqual(m.b.has_ref, True) - m2 = 
nnx.to_arrays(m, only=nnx.BatchStat) - self.assertTrue(m2.a.has_ref) - self.assertFalse(m2.b.has_ref) + m2 = nnx.immutable(m, only=nnx.BatchStat) + self.assertEqual(m2.a.has_ref, True) + self.assertEqual(m2.b.has_ref, False) self.assertIsNot(m, m2) - m3 = nnx.to_refs(m2, nnx.BatchStat) - self.assertTrue(m3.a.has_ref) - self.assertTrue(m3.b.has_ref) + m3 = nnx.as_ref_vars(m2, only=nnx.BatchStat) + self.assertEqual(m3.a.has_ref, True) + self.assertEqual(m3.b.has_ref, True) self.assertIsNot(m2, m3) self.assertIs(m.a, m3.a) def test_freeze_duplicate_error(self): class Foo(nnx.Module): def __init__(self): - self.a = jax.new_ref(1) + self.a = nnx.Param(1, mode='ref') self.b = self.a m = Foo() with self.assertRaisesRegex(ValueError, 'Found duplicate at path'): - nnx.to_arrays(m) + nnx.immutable(m) def test_mutable_array_split(self): class Foo(nnx.Module): @@ -233,7 +211,7 @@ def __init__(self): def test_mutable_array_split_merge_in_variable(self): class Foo(nnx.Module): def __init__(self): - self.a = nnx.Param(1, use_ref=True) + self.a = nnx.Param(1, has_ref=True) self.b = self.a m = Foo() @@ -241,7 +219,7 @@ def __init__(self): ref_map = nnx.graph.RefMap() graphdef, state = nnx.graph.flatten(m, ref_index=ref_map) self.assertLen(state, 1) - self.assertLen(ref_map, 3) # 1 Foo + 1 Param + 1 ArrayRef + self.assertLen(ref_map, 3) # 1 Foo + 1 Param + 1 Ref m1 = nnx.merge(graphdef, state) self.assertIs(m1.a, m1.b) @@ -251,27 +229,27 @@ def test_mutable_array_split_merge_in_variable_shared_array(self): class Foo(nnx.Module): def __init__(self): m_array = 1 - self.a = nnx.Param(m_array, use_ref=True) - self.b = nnx.Param(m_array, use_ref=True) + self.a = nnx.Param(m_array, has_ref=True) + self.b = nnx.Param(m_array, has_ref=True) m = Foo() - self.assertIsNot(m.a.raw_value, m.b.raw_value) ref_map = nnx.graph.RefMap() graphdef, state = nnx.graph.flatten(m, ref_index=ref_map) self.assertLen(state, 2) - self.assertLen(ref_map, 5) # 1 Foo + 2 Param + 2 ArrayRefs + self.assertLen(ref_map, 5) # 1 Foo + 2 Param + 2 Ref m1 = nnx.merge(graphdef, state) # Each variable will own its own array and ref. 
- self.assertIsNot(m1.a.raw_value, m1.b.raw_value) self.assertIsInstance(m1.a, nnx.Param) def test_mutable_example(self): - tree = [jnp.array(1.0), jax.new_ref(jnp.array(2.0))] - mutable_tree = nnx.to_refs(tree) - assert isinstance(mutable_tree[0], jax.Ref) - assert isinstance(mutable_tree[1], jax.Ref) + tree = [nnx.Variable(1.0), nnx.Variable(2.0, has_ref=True)] + assert tree[0].has_ref == False + assert tree[1].has_ref == True + mutable_tree = nnx.as_ref_vars(tree) + assert isinstance(mutable_tree[0].get_raw_value(), jax.Ref) + assert isinstance(mutable_tree[1].get_raw_value(), jax.Ref) def test_mutable_array_split_freeze(self): class Foo(nnx.Module): @@ -283,15 +261,15 @@ def __init__(self): ref_map = nnx.graph.RefMap() graphdef, state = nnx.graph.flatten(m, ref_index=ref_map) - state = nnx.to_arrays(state) + state = nnx.immutable(state) self.assertLen(state, 1) - m1 = nnx.merge(graphdef, nnx.to_refs(state)) + m1 = nnx.merge(graphdef, nnx.as_hijax_vars(state)) self.assertIs(m1.a, m1.b) self.assertIsInstance(m1.a, jax.Ref) def test_update_context(self): - m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m1 = nnx.as_ref_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0))) with nnx.update_context('example'): with nnx.split_context('example') as ctx: graphdef, state = ctx.split(m1) @@ -299,18 +277,22 @@ def test_update_context(self): with nnx.merge_context('example', True) as ctx: m2 = ctx.merge(graphdef, state) - m_out1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m_out1 = nnx.as_ref_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0))) with nnx.split_context('example') as ctx: graphdef_out, state_out = ctx.split((m2, m_out1, m2)) - self.assertIsInstance(state_out[0]['kernel'].value, nnx.graph.NoUpdate) - self.assertIsInstance(state_out[0]['bias'].value, nnx.graph.NoUpdate) self.assertIsInstance( - state_out[1]['kernel'].value, nnx.graph.ArrayRefOutput + state_out[0]['kernel'].get_value(), nnx.graph.NoUpdate + ) + self.assertIsInstance( + state_out[0]['bias'].get_value(), nnx.graph.NoUpdate ) self.assertIsInstance( - state_out[1]['bias'].value, nnx.graph.ArrayRefOutput + state_out[1]['kernel'].get_value(), nnx.graph.ArrayRefOutput + ) + self.assertIsInstance( + state_out[1]['bias'].get_value(), nnx.graph.ArrayRefOutput ) # 2 ArrayRefOutput + 2 NoUpdate, however, NoUpdate are empty nodes self.assertLen(jax.tree.leaves(state_out), 2) @@ -322,7 +304,7 @@ def test_update_context(self): self.assertIsNot(m_out2, m_out1) def test_update_context_flatten(self): - m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m1 = nnx.as_ref_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0))) with nnx.update_context('example'): with nnx.split_context('example') as ctx: graphdef, state = ctx.flatten(m1) @@ -330,7 +312,7 @@ def test_update_context_flatten(self): with nnx.merge_context('example', True) as ctx: m2 = ctx.merge(graphdef, state) - m_out1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m_out1 = nnx.as_ref_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0))) with nnx.split_context('example') as ctx: graphdef_out, state_out = ctx.flatten((m2, m_out1, m2)) @@ -338,16 +320,16 @@ def test_update_context_flatten(self): state_out_dict = dict(state_out) self.assertIsInstance( - state_out_dict[(0, 'kernel')].value, nnx.graph.NoUpdate + state_out_dict[(0, 'kernel')].get_value(), nnx.graph.NoUpdate ) self.assertIsInstance( - state_out_dict[(0, 'bias')].value, nnx.graph.NoUpdate + state_out_dict[(0, 'bias')].get_value(), nnx.graph.NoUpdate ) self.assertIsInstance( - state_out_dict[(1, 'kernel')].value, nnx.graph.ArrayRefOutput + state_out_dict[(1, 'kernel')].get_value(), 
nnx.graph.ArrayRefOutput ) self.assertIsInstance( - state_out_dict[(1, 'bias')].value, nnx.graph.ArrayRefOutput + state_out_dict[(1, 'bias')].get_value(), nnx.graph.ArrayRefOutput ) # 2 ArrayRefOutput + 2 NoUpdate, however, NoUpdate are empty nodes self.assertLen(jax.tree.leaves(state_out), 2) @@ -359,29 +341,29 @@ def test_update_context_flatten(self): self.assertIsNot(m_out2, m_out1) def test_update_context_to_tree1(self): - m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m1 = nnx.as_ref_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0))) with nnx.update_context('example'): m1_tree = nnx.to_tree((m1,), ctxtag='example') (m2,) = nnx.from_tree(m1_tree, ctxtag='example', is_inner=True) - m_out1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m_out1 = nnx.as_ref_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0))) # with nnx.split_context('example') as ctx: # graphdef_out, state_out = ctx.split((m2, m_out1)) out_tree = nnx.to_tree(((m2,), m_out1, m2), ctxtag='example') self.assertIsInstance( - out_tree[0][0].states[0]['kernel'].value, nnx.graph.NoUpdate + out_tree[0][0].states[0]['kernel'].get_value(), nnx.graph.NoUpdate ) self.assertIsInstance( - out_tree[0][0].states[0]['bias'].value, nnx.graph.NoUpdate + out_tree[0][0].states[0]['bias'].get_value(), nnx.graph.NoUpdate ) self.assertIsInstance( - out_tree[1].states[0]['kernel'].value, nnx.graph.ArrayRefOutput + out_tree[1].states[0]['kernel'].get_value(), nnx.graph.ArrayRefOutput ) self.assertIsInstance( - out_tree[1].states[0]['bias'].value, nnx.graph.ArrayRefOutput + out_tree[1].states[0]['bias'].get_value(), nnx.graph.ArrayRefOutput ) self.assertEmpty(out_tree[2].states[0]) # Repeated m2 State @@ -398,29 +380,29 @@ def test_update_context_to_tree1(self): self.assertIsNot(m_out2, m_out1) def test_update_context_to_tree2(self): - m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m1 = nnx.as_ref_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0))) with nnx.update_context('example') as ctx: m1_tree = nnx.to_tree((m1,), ctxtag='example') (m2,) = nnx.from_tree(m1_tree, ctxtag='example', is_inner=True) - m_out1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m_out1 = nnx.as_ref_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0))) # with nnx.split_context('example') as ctx: # graphdef_out, state_out = ctx.split((m2, m_out1)) out_tree = nnx.to_tree(((m2,), m_out1, m2), ctxtag='example') self.assertIsInstance( - out_tree[0][0].states[0]['kernel'].value, nnx.graph.NoUpdate + out_tree[0][0].states[0]['kernel'].get_value(), nnx.graph.NoUpdate ) self.assertIsInstance( - out_tree[0][0].states[0]['bias'].value, nnx.graph.NoUpdate + out_tree[0][0].states[0]['bias'].get_value(), nnx.graph.NoUpdate ) self.assertIsInstance( - out_tree[1].states[0]['kernel'].value, nnx.graph.ArrayRefOutput + out_tree[1].states[0]['kernel'].get_value(), nnx.graph.ArrayRefOutput ) self.assertIsInstance( - out_tree[1].states[0]['bias'].value, nnx.graph.ArrayRefOutput + out_tree[1].states[0]['bias'].get_value(), nnx.graph.ArrayRefOutput ) self.assertEmpty(out_tree[2].states[0]) # Repeated m2 State @@ -437,29 +419,29 @@ def test_update_context_to_tree2(self): self.assertIsNot(m_out2, m_out1) def test_update_context_to_tree_trivial_prefix(self): - m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m1 = nnx.as_ref_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0))) with nnx.update_context('example'): m1_tree = nnx.to_tree((m1,), ctxtag='example', prefix=0) (m2,) = nnx.from_tree(m1_tree, ctxtag='example', is_inner=True, prefix=0) - m_out1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m_out1 = nnx.as_ref_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0))) # with nnx.split_context('example') as ctx: 
# graphdef_out, state_out = ctx.split((m2, m_out1)) out_tree = nnx.to_tree(((m2,), m_out1, m2), ctxtag='example', prefix=0) self.assertIsInstance( - out_tree[0][0].states[0]['kernel'].value, nnx.graph.NoUpdate + out_tree[0][0].states[0]['kernel'].get_value(), nnx.graph.NoUpdate ) self.assertIsInstance( - out_tree[0][0].states[0]['bias'].value, nnx.graph.NoUpdate + out_tree[0][0].states[0]['bias'].get_value(), nnx.graph.NoUpdate ) self.assertIsInstance( - out_tree[1].states[0]['kernel'].value, nnx.graph.ArrayRefOutput + out_tree[1].states[0]['kernel'].get_value(), nnx.graph.ArrayRefOutput ) self.assertIsInstance( - out_tree[1].states[0]['bias'].value, nnx.graph.ArrayRefOutput + out_tree[1].states[0]['bias'].get_value(), nnx.graph.ArrayRefOutput ) self.assertEmpty(out_tree[2].states[0]) # Repeated m2 State @@ -475,33 +457,21 @@ def test_update_context_to_tree_trivial_prefix(self): self.assertIs(m3, m1) self.assertIsNot(m_out2, m_out1) - - -class TestMutableArrayNNXTransforms(absltest.TestCase): - @classmethod - def setUpClass(cls): - cls.using_refs = nnx.using_refs() - nnx.use_refs(True) - - @classmethod - def tearDownClass(cls): - nnx.use_refs(cls.using_refs) - def test_simple_jit(self): - m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m1 = nnx.as_ref_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0))) m_out1 = None @nnx.jit def f(m2): nonlocal m_out1 - m_out1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m_out1 = nnx.as_ref_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0))) return m_out1 m_out2 = f(m1) self.assertIsNot(m_out1, m_out2) self.assertIsInstance(m_out2.kernel, nnx.Param) - self.assertIsInstance(m_out2.kernel.raw_value, jax.Ref) + self.assertIsInstance(m_out2.kernel[...], jax.Array) def test_jit_mutable(self): @dataclasses.dataclass @@ -520,18 +490,6 @@ def f(m2: Foo): self.assertIs(m_out1, m1) self.assertIsInstance(m_out1.a, jax.Ref) - - -class TestMutableArray(absltest.TestCase): - @classmethod - def setUpClass(cls): - cls.using_refs = nnx.using_refs() - nnx.use_refs(True) - - @classmethod - def tearDownClass(cls): - nnx.use_refs(cls.using_refs) - def test_static(self): class C(nnx.Module): def __init__(self, meta): @@ -554,12 +512,13 @@ def f(x): assert n == 2 def test_variable_creation(self): - v = nnx.Variable(1) + v = nnx.Variable(jnp.array(1), has_ref=True) self.assertEqual(v[...], 1) self.assertTrue(v.has_ref) + self.assertIsInstance(v.get_raw_value(), jax.Ref) def test_variable_metadata(self): - v = nnx.Variable(1, a=2, b=3) + v = nnx.Variable(jnp.array(1), a=2, b=3) self.assertEqual(v.a, 2) self.assertEqual(v.b, 3) @@ -568,10 +527,10 @@ class Params(nnx.Pytree): def __init__(self, din: int, dout: int): self.w = nnx.Param(jnp.zeros((din, dout), jnp.float32)) self.b = nnx.Param(jnp.zeros((dout,), jnp.float32)) - self.count = nnx.Variable(0) + self.count = nnx.Variable(jnp.array(0)) - params: Params params = Params(3, 4) + params = nnx.as_ref_vars(params) paths_leaves, treedef = jax.tree.flatten_with_path(params) paths, leaves = zip(*paths_leaves) @@ -687,12 +646,9 @@ def __call__(self, x): x = jax.random.normal(jax.random.key(0), (5, 2)) y = jnp.ones((5, 4)) - with nnx.use_refs(False): - wrt = lambda path, x: path[-1] == 'w' - model = Model(nnx.Rngs(1)) - optimizer = nnx.Optimizer( - model, tx=optax.adam(1e-3), wrt=wrt - ) + wrt = lambda path, x: path[-1] == 'w' + model = Model(nnx.Rngs(1)) + optimizer = nnx.Optimizer(model, tx=optax.adam(1e-3), wrt=wrt) @jax.jit def train_step(model, optimizer, x, y): @@ -714,11 +670,12 @@ def loss_fn(params): self.assertEqual(model.count[...], 1) 
self.assertEqual(optimizer.step[...], 1) - def test_optimize_mutable_arrays(self): + @nnx.use_hijax(True) + def test_optimize_hijax(self): class Model(nnx.Module): def __init__(self, rngs): - self.w = jax.new_ref(jax.random.uniform(rngs(), (2, 4))) - self.count = jax.new_ref(jnp.array(0)) + self.w = nnx.Variable(jax.random.uniform(rngs(), (2, 4))) + self.count = nnx.Variable(jnp.array(0)) def __call__(self, x): self.count[...] += 1 @@ -727,10 +684,9 @@ def __call__(self, x): x = jax.random.normal(jax.random.key(0), (5, 2)) y = jnp.ones((5, 4)) - with nnx.use_refs(True): - wrt = lambda path, x: path[-1] == 'w' - model = Model(nnx.Rngs(1)) - optimizer = nnx.Optimizer(model, tx=optax.adam(1e-3), wrt=wrt) + wrt = lambda path, x: path[-1] == 'w' + model = Model(nnx.Rngs(1)) + optimizer = nnx.Optimizer(model, tx=optax.adam(1e-3), wrt=wrt) @jax.jit def train_step(model, optimizer, x, y): @@ -740,7 +696,7 @@ def loss_fn(params): model = nnx.merge(graphdef, params, nondiff) return jnp.mean((model(x) - y) ** 2) - loss, grads = jax.value_and_grad(loss_fn)(nnx.to_arrays(params)) + loss, grads = jax.value_and_grad(loss_fn)(nnx.immutable(params)) optimizer.update(params, grads) return loss @@ -748,6 +704,202 @@ def loss_fn(params): self.assertNotEqual(loss, 0.0) +class TestHijaxVariables(absltest.TestCase): + def test_variable_to_hijax(self): + v_low = nnx.Param(jnp.array(1), a='hi') + v_hi = nnx.as_hijax_vars(v_low) + + self.assertTrue(v_hi.is_hijax) + self.assertEqual(v_hi[...], 1) + self.assertIsInstance(v_hi, nnx.Param) + + v_hi[...] = 2 + self.assertEqual(v_hi[...], 2) + + @jax.jit + def set(v_hi, a): + self.assertIsInstance(v_hi, nnx.Param) + v_hi[...] = a + self.assertEqual(v_hi.a, 'hi') + self.assertTrue(v_hi.is_hijax) + v_hi[...] += 5 + return v_hi + 2 + + y = set(v_hi, 10) + self.assertEqual(v_hi[...], 15) + self.assertEqual(y, 17) + + v_low = nnx.immutable(v_hi) + self.assertFalse(v_low.is_mutable) + self.assertIsInstance(v_low, nnx.Param) + + def test_from_metadata(self): + value = 1 + metadata = { + 'a': 'hi', + 'is_hijax': False, + 'has_ref': False, + 'is_mutable': True, + } + v_low = nnx.Param.from_metadata(value, metadata) + self.assertIsInstance(v_low, nnx.Param) + self.assertFalse(v_low.is_hijax) + + metadata['is_hijax'] = True + v_hi = nnx.Param.from_metadata(value, metadata) + self.assertIsInstance(v_hi, nnx.Param) + self.assertTrue(v_hi.is_hijax) + + def test_variable_to_hijax_clean(self): + v_low = nnx.Param(jnp.array([1]), tag='hello') + print() + print(v_low) + assert not v_low.is_hijax + v_hi = nnx.as_hijax_vars(v_low) + v_hi[...] = jnp.array([2]) + assert v_hi.is_hijax + print(v_hi) + assert v_hi[...] == 2 + + @jax.jit + def set(v_hi, a): + v_hi[...] = a + print(v_hi) + assert v_hi.tag == 'hello' + + set(v_hi, 10) + + assert v_hi[...] == 10 + + v_low = nnx.immutable(v_hi) + + assert not v_low.is_hijax and not v_low.is_mutable + assert v_low[...] == 10 + + def test_immutable_variable(self): + v_imm = nnx.Param(jnp.array([1]), is_mutable=False) + assert not v_imm.is_mutable + + with self.assertRaisesRegex( + flax.errors.ImmutableVariableError, + 'Cannot mutate Param as it is marked as immutable', + ): + v_imm[...] 
= 1 + + def test_pytree_value(self): + v = nnx.Variable({'a': jnp.array(0), 'b': jnp.array(2)}, is_hijax=True) + + @jax.jit + def inc_and_double(v): + v['a'] += 1 + v['b'] *= 2 + + inc_and_double(v) + + self.assertEqual(v['a'], 1) + self.assertEqual(v['b'], 4) + + def test_hijax_dynamic_structure(self): + x = jnp.ones((4, 5)) + metrics = nnx.Variable({}, is_hijax=True) + + @jax.jit + def f(x, metrics: nnx.Variable): + metrics['x_sum'] = jnp.sum(x) + + self.assertEmpty(metrics) + f(x, metrics) + self.assertIn('x_sum', metrics) + self.assertEqual(metrics['x_sum'], 20) + + def test_hijax_and_pytree(self): + class Foo(nnx.Pytree): + def __init__(self, din, dout, rngs: nnx.Rngs): + self.w = nnx.Param(rngs.uniform((din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + self.count = nnx.Variable(0) + + foo = Foo(2, 4, nnx.Rngs(1)) + assert not foo.w.is_hijax + assert not foo.b.is_hijax + + foo = nnx.as_hijax_vars(foo) + + assert foo.w.is_hijax + assert foo.b.is_hijax + + @jax.jit + def forward(foo, x): + foo.count[...] += 1 + return x @ foo.w + foo.b[None] + + x = jnp.ones((1, 2)) + y = forward(foo, x) + assert y.shape == (1, 4) + assert foo.count[...] == 1 + + def test_use_hijax(self): + v_low = nnx.Param(1, a='hi') + self.assertFalse(v_low.is_hijax) + + v_hi = nnx.Param(1, a='hi', is_hijax=True) + self.assertTrue(v_hi.is_hijax) + + with nnx.use_hijax(True): + v2 = nnx.Param(1, a='hi') + self.assertIs(type(v2), nnx.variablelib.HijaxVariable) + self.assertTrue(v2.is_hijax) + + @nnx.use_hijax(True) + def test_hijax_rngs(self): + rngs = nnx.Rngs(0) + self.assertIs(type(rngs.default.key), nnx.variablelib.HijaxVariable) + self.assertIs(type(rngs.default.count), nnx.variablelib.HijaxVariable) + + @jax.jit + def f(rngs: nnx.Rngs): + return rngs() + + k1 = f(rngs) + k2 = f(rngs) + + assert k1 != k2 + + @pytest.mark.skip(reason='not yet supported') + def test_return_hijax_from_transform(self): + @jax.jit + def create_var(): + return nnx.Param(1, is_hijax=True) + + v = create_var() + self.assertTrue(v.is_hijax) + + @pytest.mark.skip(reason='not yet supported') + @nnx.use_hijax(True) + def test_lower(self): + v = nnx.Param(jnp.ones((2, 3))) + + @jax.jit + def f(v): + v[...] += 1 + return v[...] + + e = f.lower(v) + y = e.out_info[2] + self.assertEqual(y.shape, ()) + + @nnx.use_hijax(True) + def test_eval_shape(self): + v = nnx.Param(jnp.array(0)) + + def f(v): + v[...] += 1 + return v[...] 
+ + y = jax.eval_shape(f, v) + + self.assertEqual(y.shape, ()) + if __name__ == '__main__': absltest.main() diff --git a/tests/nnx/nn/embed_test.py b/tests/nnx/nn/embed_test.py index 8991c8a35..3120ad84d 100644 --- a/tests/nnx/nn/embed_test.py +++ b/tests/nnx/nn/embed_test.py @@ -59,7 +59,7 @@ def test_nnx_linen_equivalence( NUM_EMBEDDINGS, IN_FEATURES, dtype=dtype, param_dtype=param_dtype ) variables = model.init(key, x) - model_nnx.embedding.value = variables['params']['embedding'] + model_nnx.embedding.set_value(variables['params']['embedding']) out_nnx = model_nnx(x) out = model.apply(variables, x) diff --git a/tests/nnx/nn/linear_test.py b/tests/nnx/nn/linear_test.py index a69c37765..5b7e89749 100644 --- a/tests/nnx/nn/linear_test.py +++ b/tests/nnx/nn/linear_test.py @@ -121,9 +121,9 @@ def test_nnx_linear_equivalence( dot_general=dot_general, ) variables = model.init(key, x) - model_nnx.kernel.value = variables['params']['kernel'] + model_nnx.kernel.set_value(variables['params']['kernel']) if use_bias: - model_nnx.bias.value = variables['params']['bias'] + model_nnx.bias.set_value(variables['params']['bias']) out_nnx = model_nnx(x) out = model.apply(variables, x) @@ -184,10 +184,10 @@ def test_nnx_einsum_equivalence( np.testing.assert_array_equal(out, out_nnx) variables = model.init(key, x) - model_nnx.kernel.value = variables['params']['kernel'] + model_nnx.kernel.set_value(variables['params']['kernel']) if bias_shape is not None: assert model_nnx.bias is not None - model_nnx.bias.value = variables['params']['bias'] + model_nnx.bias.set_value(variables['params']['bias']) out_nnx = model_nnx(x) out = model.apply(variables, x) assert isinstance(out, jax.Array) diff --git a/tests/nnx/nn/normalization_test.py b/tests/nnx/nn/normalization_test.py index cbc7077e4..9ee8c4bd2 100644 --- a/tests/nnx/nn/normalization_test.py +++ b/tests/nnx/nn/normalization_test.py @@ -241,8 +241,8 @@ def __call__(self, x, *, mask=None): use_fast_variance=use_fast_variance, rngs=rngs, ) - nnx_model.linear.kernel.value = variables['params']['linear']['kernel'] - nnx_model.linear.bias.value = variables['params']['linear']['bias'] + nnx_model.linear.kernel.set_value(variables['params']['linear']['kernel']) + nnx_model.linear.bias.set_value(variables['params']['linear']['bias']) nnx_out = nnx_model(x, mask=mask) assert isinstance(linen_out, jax.Array) @@ -535,20 +535,20 @@ def __call__(self, x): ) # Setup the same weights and batch stats var_params_seq_0 = variables['params']['seq']['layers_0'] - nnx_model.seq.layers[0].kernel.value = var_params_seq_0['kernel'] - nnx_model.seq.layers[0].bias.value = var_params_seq_0['bias'] + nnx_model.seq.layers[0].kernel.set_value(var_params_seq_0['kernel']) + nnx_model.seq.layers[0].bias.set_value(var_params_seq_0['bias']) var_params_seq_2 = variables['params']['seq']['layers_2'] - nnx_model.seq.layers[2].scale.value = var_params_seq_2['scale'] - nnx_model.seq.layers[2].bias.value = var_params_seq_0['bias'] + nnx_model.seq.layers[2].scale.set_value(var_params_seq_2['scale']) + nnx_model.seq.layers[2].bias.set_value(var_params_seq_0['bias']) var_norm_layer = variables['batch_stats']['norm_layer'] nnx_model.norm_layer.batch_stats[ ('layers', 0, 'kernel', 'u') - ].value = var_norm_layer['seq/layers_0/kernel/u'] + ].set_value(var_norm_layer['seq/layers_0/kernel/u']) nnx_model.norm_layer.batch_stats[ ('layers', 0, 'kernel', 'sigma') - ].value = var_norm_layer['seq/layers_0/kernel/sigma'] + ].set_value(var_norm_layer['seq/layers_0/kernel/sigma']) linen_out = 
linen_model.apply(variables, x, mutable=['batch_stats']) nnx_out = nnx_model(x) diff --git a/tests/nnx/optimizer_test.py b/tests/nnx/optimizer_test.py index c05f9dd4d..d9e3b8faa 100644 --- a/tests/nnx/optimizer_test.py +++ b/tests/nnx/optimizer_test.py @@ -87,7 +87,7 @@ def test_sharding_propagation(self): self.assertEqual(state['opt_state'][0]['mu']['kernel'].sharding_names, ('a', 'b')) self.assertEqual( - partition_spec['opt_state'][0]['mu']['kernel'].value, + partition_spec['opt_state'][0]['mu']['kernel'].get_value(), jax.sharding.PartitionSpec('a', 'b'), ) diff --git a/tests/nnx/rngs_test.py b/tests/nnx/rngs_test.py index aef3f6a4b..fc8efba20 100644 --- a/tests/nnx/rngs_test.py +++ b/tests/nnx/rngs_test.py @@ -60,7 +60,7 @@ def test_rng_trace_level_constraints(self): def f(): with self.assertRaisesRegex( errors.TraceContextError, - 'Cannot mutate RngStream from a different trace level', + 'Cannot mutate RngCount from a different trace level', ): rngs.params() @@ -78,7 +78,7 @@ def h(): self.assertIsInstance(rngs1, nnx.Rngs) with self.assertRaisesRegex( errors.TraceContextError, - 'Cannot mutate RngStream from a different trace level', + 'Cannot mutate RngCount from a different trace level', ): rngs1.params() diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py index 4138a9922..646af30d3 100644 --- a/tests/nnx/spmd_test.py +++ b/tests/nnx/spmd_test.py @@ -175,9 +175,9 @@ def __call__(self, x: jax.Array): self.assertEqual(badds, [(0, 'layers'), (0, 'layers')]) self.assertEqual(bremoves, [(0, 'layers')]) - @parameterized.product(use_ref=[True, False]) - def test_logical_rules(self, use_ref): - self.enter_context(nnx.use_refs(use_ref)) + @parameterized.product(use_hijax=[True, False]) + def test_logical_rules(self, use_hijax): + self.enter_context(nnx.use_hijax(use_hijax)) class Foo(nnx.Module): def __init__(self): diff --git a/tests/nnx/summary_test.py b/tests/nnx/summary_test.py index c6c748acc..c8b3b7e49 100644 --- a/tests/nnx/summary_test.py +++ b/tests/nnx/summary_test.py @@ -278,7 +278,7 @@ def __init__(self): self.custom_param.set_metadata('custom_obj', Custom()) def __call__(self, x): - return jnp.dot(x, self.hooked_param.value) + self.custom_param.sum() + return jnp.dot(x, self.hooked_param[...]) + self.custom_param.sum() module = Model() # Should not raise yaml.representer.RepresenterError diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 420ccf8d8..de9500842 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -440,13 +440,13 @@ class TestEvalShape(absltest.TestCase): def test_eval_shape(self): abs_model = nnx.eval_shape(lambda: nnx.Linear(1, 2, rngs=nnx.Rngs(0))) self.assertIsInstance(abs_model, nnx.Linear) - self.assertIsInstance(abs_model.kernel.value, jax.ShapeDtypeStruct) + self.assertIsInstance(abs_model.kernel.get_value(), jax.ShapeDtypeStruct) def test_eval_shape_mutable_array(self): - with nnx.use_refs(True): + with nnx.use_hijax(True): abs_model = nnx.eval_shape(lambda: nnx.Linear(1, 2, rngs=nnx.Rngs(0))) self.assertIsInstance(abs_model, nnx.Linear) - self.assertIsInstance(abs_model.kernel.value, jax.ShapeDtypeStruct) + self.assertIsInstance(abs_model.kernel.get_value(), jax.ShapeDtypeStruct) self.assertEqual(abs_model.kernel.shape, (1, 2)) class TestShardMap(absltest.TestCase): @@ -800,7 +800,7 @@ def loss_fn(l1: list[nnx.Linear], l2: list[nnx.Linear]): loss = jnp.mean(l1[0].kernel * l2[0].kernel) + jnp.mean( l1[0].bias * l2[0].bias ) - l1[0].kernel.value = jnp.array(-1.0) + 
l1[0].kernel.set_value(jnp.array(-1.0)) m3 = nnx.Linear(2, 3, rngs=nnx.Rngs(2)) return loss, m3 diff --git a/tests/nnx/variable_test.py b/tests/nnx/variable_test.py index 42156f291..791139c40 100644 --- a/tests/nnx/variable_test.py +++ b/tests/nnx/variable_test.py @@ -25,12 +25,12 @@ class TestVariable(absltest.TestCase): def test_pytree(self): r1 = nnx.Param(1) - self.assertEqual(r1.value, 1) + self.assertEqual(r1.get_value(), 1) r2 = jax.tree.map(lambda x: x + 1, r1) - self.assertEqual(r1.value, 1) - self.assertEqual(r2.value, 2) + self.assertEqual(r1.get_value(), 1) + self.assertEqual(r2.get_value(), 2) self.assertIsNot(r1, r2) def test_overloads_module(self): @@ -94,38 +94,47 @@ def test_binary_ops(self): self.assertEqual(v1[...], 5) def test_mutable_array_context(self): - with nnx.use_refs(False): + initial_mode = nnx.using_hijax() + with nnx.use_hijax(False): v = nnx.Variable(jnp.array(1.0)) - self.assertFalse(nnx.using_refs()) - self.assertNotIsInstance(v.raw_value, jax.Ref) + self.assertEqual(nnx.using_hijax(), False) + self.assertNotIsInstance(v[...], jax.Ref) - with nnx.use_refs(True): + with nnx.use_hijax(True): v = nnx.Variable(jnp.array(1.0)) - self.assertTrue(nnx.using_refs()) - self.assertIsInstance(v.raw_value, jax.Ref) + self.assertEqual(nnx.using_hijax(), True) + self.assertIsInstance(v[...], jax.Array) v = nnx.Variable(jnp.array(2.0)) - self.assertNotIsInstance(v.raw_value, jax.Ref) - self.assertFalse(nnx.using_refs()) + self.assertIsInstance(v[...], jax.Array) + self.assertEqual(nnx.using_hijax(), False) - nnx.use_refs(True) + nnx.use_hijax(True) v = nnx.Variable(jnp.array(0.0)) - self.assertTrue(nnx.using_refs()) - self.assertIsInstance(v.raw_value, jax.Ref) + self.assertEqual(nnx.using_hijax(), True) + self.assertIsInstance(v[...], jax.Array) v = nnx.Variable(jnp.array(1.0)) - self.assertFalse(nnx.using_refs()) - self.assertNotIsInstance(v.raw_value, jax.Ref) + self.assertEqual(nnx.using_hijax(), initial_mode) + self.assertIsInstance(v[...], jax.Array) def test_get_set_metadata(self): v = nnx.Variable(jnp.array(1.0)) - self.assertEqual(v.get_metadata(), {}) + self.assertEqual( + v.get_metadata(), + {'is_hijax': False, 'has_ref': False, 'is_mutable': True}, + ) v.set_metadata(a=1, b=2) self.assertEqual(v.get_metadata('a'), 1) self.assertEqual(v.get_metadata('b'), 2) - v.set_metadata({'b': 3, 'c': 4}) - self.assertEqual(v.get_metadata(), {'b': 3, 'c': 4}) + v.set_metadata( + {'b': 3, 'c': 4, 'is_hijax': False, 'has_ref': False, 'is_mutable': True} + ) + self.assertEqual( + v.get_metadata(), + {'b': 3, 'c': 4, 'is_hijax': False, 'has_ref': False, 'is_mutable': True}, + ) self.assertEqual(v.get_metadata('b'), 3) self.assertEqual(v.get_metadata('c'), 4) c = v.get_metadata('c') @@ -140,17 +149,30 @@ def __init__(self): self.p = nnx.Param(jnp.array(1.0)) m = Module() - self.assertTrue('foo' not in m.v.get_metadata()) - self.assertTrue('foo' not in m.p.get_metadata()) + self.assertNotIn('foo', m.v.get_metadata()) + self.assertNotIn('foo', m.p.get_metadata()) nnx.set_metadata(m, foo='bar') - self.assertTrue(m.v.get_metadata() == {'foo': 'bar'}) - self.assertTrue(m.p.get_metadata() == {'foo': 'bar'}) - - self.assertTrue('differentiable' not in m.v.get_metadata()) - self.assertTrue('differentiable' not in m.p.get_metadata()) + # Check that foo was added but the default metadata is still there + v_metadata = m.v.get_metadata() + p_metadata = m.p.get_metadata() + self.assertEqual(v_metadata['foo'], 'bar') + self.assertEqual(p_metadata['foo'], 'bar') + # Check that default metadata 
is preserved + self.assertIn('is_hijax', v_metadata) + self.assertIn('has_ref', v_metadata) + self.assertIn('is_mutable', v_metadata) + + self.assertNotIn('differentiable', m.v.get_metadata()) + self.assertNotIn('differentiable', m.p.get_metadata()) nnx.set_metadata(m, differentiable=False, only=nnx.Param) - self.assertTrue(m.v.get_metadata() == {'foo': 'bar'}) - self.assertTrue(m.p.get_metadata() == {'foo': 'bar', 'differentiable': False}) + # Check that v still has foo but not differentiable + v_metadata = m.v.get_metadata() + self.assertEqual(v_metadata['foo'], 'bar') + self.assertNotIn('differentiable', v_metadata) + # Check that p has both foo and differentiable + p_metadata = m.p.get_metadata() + self.assertEqual(p_metadata['foo'], 'bar') + self.assertEqual(p_metadata['differentiable'], False) if __name__ == '__main__': absltest.main()