4 changes: 4 additions & 0 deletions .gitignore
@@ -17,6 +17,9 @@ flaxlib_src/build
flaxlib_src/builddir
flaxlib_src/dist
flaxlib_src/subprojects
.venv
venv/
venv.bak/

# used by direnv
.envrc
@@ -26,3 +29,4 @@ uv.lock

# custom
/tmp-files
.env
97 changes: 97 additions & 0 deletions examples/lm1b_nnx/models_test.py
@@ -291,6 +291,103 @@ def test_forward_decode(self):
    for output_nnx, output_linen in zip(outputs_nnx, outputs_linen):
      assert jnp.allclose(output_nnx, output_linen, atol=1e-5)

  def test_forward_eval_set_mode(self):
    _, config = get_transformer_config(
      axis_rules=default.MeshRules(
        embed='model',
        mlp='data',
        kv=None,
        vocab=None,
      ),
      deterministic=True,
      decode=False,
    )
    # Set dropout rates to zero to avoid creating dropout state.
    config.dropout_rate = 0.0
    config.attention_dropout_rate = 0.0

    model_nnx = nnx.eval_shape(lambda: TransformerLM(config, rngs=nnx.Rngs(0)))
    _, params_nnx = nnx.split(model_nnx, nnx.Param)

    model_linen = TransformerLinen(config)

    sample_inputs = random.randint(random.PRNGKey(0), (1, 3), 0, 20)
    params_linen = model_linen.init(random.key(0), sample_inputs)['params']

    self.transfer_params(config, params_nnx, params_linen)
    nnx.update(model_nnx, params_nnx)

    det_model = nnx.set_mode(model_nnx, deterministic=True, decode=False)
    output_nnx = det_model(sample_inputs)

    output_linen: jax.Array = model_linen.apply(
      {'params': params_linen}, sample_inputs
    )

    assert jnp.allclose(output_nnx, output_linen, atol=1e-5)

  def test_forward_decode_set_mode(self):
    batch_size = 2

    _, config = get_transformer_config(
      axis_rules=default.MeshRules(
        embed='model',
        mlp='data',
        kv=None,
        vocab=None,
      ),
      deterministic=True,
      decode=True,
    )
    # Set dropout rates to zero to avoid creating dropout state.
    config.dropout_rate = 0.0
    config.attention_dropout_rate = 0.0

    model_nnx = nnx.eval_shape(lambda: TransformerLM(config, rngs=nnx.Rngs(0)))
    for _path, m in model_nnx.iter_modules():
      if isinstance(m, HasCache):
        input_shape = (batch_size, config.max_len, config.emb_dim)
        m.init_cache(input_shape, dtype=config.dtype)

    _, params_nnx, cache_nnx = nnx.split(model_nnx, nnx.Param, nnx.Cache)

    model_linen = TransformerLinen(config)

    flax_init_inputs = random.randint(
      random.PRNGKey(0), (batch_size, config.max_len), 0, config.vocab_size
    )
    ar_decode_inputs = random.randint(
      random.PRNGKey(0), (3, batch_size, 1), 0, config.vocab_size
    )
    variables = model_linen.init(random.key(0), flax_init_inputs)
    params_linen = variables['params']
    cache_linen = variables['cache']

    self.transfer_params(config, params_nnx, params_linen)
    self.transfer_cache(config, cache_nnx, cache_linen)
    nnx.update(model_nnx, params_nnx, cache_nnx)
    det_model = nnx.set_mode(model_nnx, deterministic=True, decode=True)

    outputs_nnx = []
    outputs_linen = []

    for inputs in ar_decode_inputs:
      output_nnx = det_model(inputs)
      outputs_nnx.append(output_nnx)

    output_linen: jax.Array
    for inputs in ar_decode_inputs:
      output_linen, updates = model_linen.apply(
        {'params': params_linen, 'cache': cache_linen},
        inputs,
        mutable=['cache'],
      )
      cache_linen = updates['cache']
      outputs_linen.append(output_linen)

    for output_nnx, output_linen in zip(outputs_nnx, outputs_linen):
      assert jnp.allclose(output_nnx, output_linen, atol=1e-5)


if __name__ == '__main__':
  absltest.main()
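Aside: the pattern these two tests exercise can be reduced to a few lines. A minimal sketch, assuming only the `nnx.set_mode` call shown above; the `Toy` module and its layer sizes are hypothetical:

from flax import nnx

class Toy(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(4, 4, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)

  def __call__(self, x):
    return self.dropout(self.linear(x))

model = Toy(nnx.Rngs(0))
# As in the tests, set_mode returns a model configured with the given
# mode flags (e.g. Dropout's `deterministic`); the tests use the
# returned model rather than the original.
eval_model = nnx.set_mode(model, deterministic=True)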
6 changes: 6 additions & 0 deletions flax/nnx/__init__.py
@@ -47,6 +47,11 @@
from .helpers import TrainState as TrainState
from .module import M as M
from .module import Module as Module
from .module import set_mode as set_mode
from .module import set_mode_info as set_mode_info
from .module import train_mode as train_mode
from .module import eval_mode as eval_mode
from .module import set_attributes as set_attributes
from .module import iter_children as iter_children, iter_modules as iter_modules
from .graph import merge as merge
from .graph import UpdateContext as UpdateContext
@@ -59,6 +64,7 @@
from .graph import state as state
from .graph import graphdef as graphdef
from .graph import iter_graph as iter_graph
from .graph import recursive_map as recursive_map
from .graph import find_duplicates as find_duplicates
from .graph import call as call
from .graph import SplitContext as SplitContext
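The newly exported names make mode toggling a first-class API. A hedged sketch of how they might be called; the exact semantics of `train_mode`/`eval_mode` are not shown in this diff and are assumed here to be convenience wrappers over `set_mode`:

from flax import nnx

dropout = nnx.Dropout(rate=0.5, rngs=nnx.Rngs(0))
# Assumption: eval_mode/train_mode flip the usual flags, roughly
# equivalent to set_mode(module, deterministic=True/False).
eval_dropout = nnx.eval_mode(dropout)
train_dropout = nnx.train_mode(dropout)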
52 changes: 51 additions & 1 deletion flax/nnx/graph.py
@@ -2592,7 +2592,7 @@ def pop(
  return states


def clone(node: Node, variables: bool = True) -> Node:
def clone(node: Node, /, *, variables: bool = True) -> Node:
  """Create a deep copy of the given graph node.

  Example usage::
@@ -2940,6 +2940,52 @@ def _iter_graph(

  yield path_parts, node

def recursive_map(f: tp.Callable[[PathParts, tp.Any], tp.Any], node: tp.Any, /):
  # Operate on a copy so the input graph is left untouched.
  node = clone(node, variables=False)
  path_parts: PathParts = ()
  visited: set[int] = set()
  results: dict[int, tp.Any] = {}
  return _recursive_map(f, node, path_parts, visited, results)


def _recursive_map(
  f: tp.Callable[[PathParts, tp.Any], tp.Any],
  node: tp.Any,
  path: PathParts,
  visited: set[int],
  results: dict[int, tp.Any],
) -> tp.Any:
  node_id = id(node)
  if node_id in visited:
    # Memoize so shared nodes map to the same replacement.
    if node_id in results:
      return results[node_id]
    path_str = '/'.join(map(str, path))
    raise ValueError(
      f"Found cycle in the graph at path '{path_str}'. Node of type"
      f' {type(node)} has already been visited but has not been returned yet.'
    )
  node_impl = get_node_impl(node)
  if (
    type(node_impl) is GraphNodeImpl
    or isinstance(node, Variable)
    or is_array_ref(node)
  ):
    visited.add(node_id)
  if node_impl is not None:
    # Visit children first so replacements propagate bottom-up.
    for key, value in node_impl.node_dict(node).items():
      new_value = _recursive_map(f, value, (*path, key), visited, results)
      if new_value is not value:
        if node_impl.set_key is not None and value is not new_value:
          node_impl.set_key(node, key, new_value)
        else:
          raise ValueError(
            f"Cannot update key '{key}' for node of type '{type(node)}'"
            ' because the node does not support mutation.'
          )

  new_node = f(path, node)
  results[node_id] = new_node
  return new_node
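A usage sketch for `recursive_map`, inferred from the implementation above: `f` is applied children-first to every graph node, `Variable`, and array ref in a cloned graph, and any non-identical return value is written back through the parent's `set_key` hook. The `cast_params` callback and the assumption that `Variable.replace` accepts a positional value are illustrative, not part of this diff:

import jax.numpy as jnp
from flax import nnx

def cast_params(path, node):
  # Illustrative: replace each Param with a bfloat16 copy; everything
  # else is returned unchanged (identity means no write-back occurs).
  if isinstance(node, nnx.Param):
    return node.replace(jnp.asarray(node.value, jnp.bfloat16))
  return node

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
# Children are visited before parents, so a parent always sees its
# already-updated children.
bf16_model = nnx.recursive_map(cast_params, model)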

def find_duplicates(node: tp.Any, /, *, only: filterlib.Filter = ...) -> list[list[PathParts]]:
  """Finds duplicate nodes or node leaves in the given node.
@@ -3110,12 +3156,16 @@ def _unflatten_pytree(
  pop_key=None,
)

def _list_set_key(x: list[tp.Any], key: int, value: tp.Any):
  x[key] = value

# common pytrees
# list
register_pytree_node_type(
  list,
  flatten=lambda x: (list(enumerate(x)), None),
  unflatten=lambda nodes, _: [value for _, value in nodes],  # type: ignore
  set_key=_list_set_key,
)
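For illustration, the same write-back hook can be supplied when registering other container types; it is what lets `_recursive_map` mutate a container's entries in place. A sketch under the assumption that `register_pytree_node_type` accepts keyword arguments exactly as in the `list` call above; `MyList` is hypothetical:

class MyList(list):
  pass

def _mylist_set_key(x: MyList, key: int, value):
  x[key] = value  # write the replacement child back into the container

register_pytree_node_type(
  MyList,
  flatten=lambda x: (list(enumerate(x)), None),
  unflatten=lambda nodes, _: MyList(value for _, value in nodes),
  set_key=_mylist_set_key,
)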
# tuple
register_pytree_node_type(