diff --git a/.gitignore b/.gitignore index 3778842d2..f818ef8e2 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,9 @@ flaxlib_src/build flaxlib_src/builddir flaxlib_src/dist flaxlib_src/subprojects +.venv +venv/ +venv.bak/ # used by direnv .envrc @@ -26,3 +29,4 @@ uv.lock # custom /tmp-files +.env # test diff --git a/examples/lm1b_nnx/models_test.py b/examples/lm1b_nnx/models_test.py index 47fa2fed6..b9733a9d5 100644 --- a/examples/lm1b_nnx/models_test.py +++ b/examples/lm1b_nnx/models_test.py @@ -291,6 +291,103 @@ def test_forward_decode(self): for output_nnx, output_linen in zip(outputs_nnx, outputs_linen): assert jnp.allclose(output_nnx, output_linen, atol=1e-5) + def test_forward_eval_set_mode(self): + _, config = get_transformer_config( + axis_rules=default.MeshRules( + embed='model', + mlp='data', + kv=None, + vocab=None, + ), + deterministic=True, + decode=False, + ) + # Set dropout rates to zero to avoid creating dropout states + config.dropout_rate = 0.0 + config.attention_dropout_rate = 0.0 + + model_nnx = nnx.eval_shape(lambda: TransformerLM(config, rngs=nnx.Rngs(0))) + _, params_nnx = nnx.split(model_nnx, nnx.Param) + + model_linen = TransformerLinen(config) + + sample_inputs = random.randint(random.PRNGKey(0), (1, 3), 0, 20) + params_linen = model_linen.init(random.key(0), sample_inputs)['params'] + + self.transfer_params(config, params_nnx, params_linen) + nnx.update(model_nnx, params_nnx) + + det_model = nnx.set_mode(model_nnx, deterministic=True, decode=False) + output_nnx = det_model(sample_inputs) + + output_linen: jax.Array = model_linen.apply( + {'params': params_linen}, sample_inputs + ) + + assert jnp.allclose(output_nnx, output_linen, atol=1e-5) + + def test_forward_decode_set_mode(self): + batch_size = 2 + + _, config = get_transformer_config( + axis_rules=default.MeshRules( + embed='model', + mlp='data', + kv=None, + vocab=None, + ), + deterministic=True, + decode=True, + ) + # Set dropout rates to zero to avoid creating dropout states + config.dropout_rate = 0.0 + config.attention_dropout_rate = 0.0 + + model_nnx = nnx.eval_shape(lambda: TransformerLM(config, rngs=nnx.Rngs(0))) + for _path, m in model_nnx.iter_modules(): + if isinstance(m, HasCache): + input_shape = (batch_size, config.max_len, config.emb_dim) + m.init_cache(input_shape, dtype=config.dtype) + + _, params_nnx, cache_nnx = nnx.split(model_nnx, nnx.Param, nnx.Cache) + + model_linen = TransformerLinen(config) + + flax_init_inputs = random.randint( + random.PRNGKey(0), (batch_size, config.max_len), 0, config.vocab_size + ) + ar_decode_inputs = random.randint( + random.PRNGKey(0), (3, batch_size, 1), 0, config.vocab_size + ) + variables = model_linen.init(random.key(0), flax_init_inputs) + params_linen = variables['params'] + cache_linen = variables['cache'] + + self.transfer_params(config, params_nnx, params_linen) + self.transfer_cache(config, cache_nnx, cache_linen) + nnx.update(model_nnx, params_nnx, cache_nnx) + det_model = nnx.set_mode(model_nnx, deterministic=True, decode=True) + + outputs_nnx = [] + outputs_linen = [] + + for inputs in ar_decode_inputs: + output_nnx = det_model(inputs) + outputs_nnx.append(output_nnx) + + output_linen: jax.Array + for inputs in ar_decode_inputs: + output_linen, updates = model_linen.apply( + {'params': params_linen, 'cache': cache_linen}, + inputs, + mutable=['cache'], + ) + cache_linen = updates['cache'] + outputs_linen.append(output_linen) + + for output_nnx, output_linen in zip(outputs_nnx, outputs_linen): + assert jnp.allclose(output_nnx, output_linen, atol=1e-5) + if
__name__ == '__main__': absltest.main() diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 3ef14ce4c..f93133a43 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -47,6 +47,11 @@ from .helpers import TrainState as TrainState from .module import M as M from .module import Module as Module +from .module import set_mode as set_mode +from .module import set_mode_info as set_mode_info +from .module import train_mode as train_mode +from .module import eval_mode as eval_mode +from .module import set_attributes as set_attributes from .module import iter_children as iter_children, iter_modules as iter_modules from .graph import merge as merge from .graph import UpdateContext as UpdateContext @@ -59,6 +64,7 @@ from .graph import state as state from .graph import graphdef as graphdef from .graph import iter_graph as iter_graph +from .graph import recursive_map as recursive_map from .graph import find_duplicates as find_duplicates from .graph import call as call from .graph import SplitContext as SplitContext diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 45c0dede5..ba0d1d0cd 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -2592,7 +2592,7 @@ def pop( return states -def clone(node: Node, variables: bool = True) -> Node: +def clone(node: Node, /, *, variables: bool = True) -> Node: """Create a deep copy of the given graph node. Example usage:: @@ -2940,6 +2940,52 @@ def _iter_graph( yield path_parts, node +def recursive_map(f: tp.Callable[[PathParts, tp.Any], tp.Any], node: tp.Any, /): + node = clone(node, variables=False) + path_parts: PathParts = () + visited: set[int] = set() + results: dict[int, tp.Any] = {} + return _recursive_map(f, node, path_parts, visited, results) + + +def _recursive_map( + f: tp.Callable[[PathParts, tp.Any], tp.Any], + node: tp.Any, + path: PathParts, + visited: set[int], + results: dict[int, tp.Any], +) -> tp.Any: + node_id = id(node) + if node_id in visited: + if node_id in results: + return results[node_id] + path_str = '/'.join(map(str, path)) + raise ValueError( + f"Found cycle in the graph at path '{path_str}'. Node of type" + f' {type(node)} has already been visited but has not been returned yet.' + ) + node_impl = get_node_impl(node) + if ( + type(node_impl) is GraphNodeImpl + or isinstance(node, Variable) + or is_array_ref(node) + ): + visited.add(node_id) + if node_impl is not None: + for key, value in node_impl.node_dict(node).items(): + new_value = _recursive_map(f, value, (*path, key), visited, results) + if new_value is not value: + if node_impl.set_key is not None and value is not new_value: + node_impl.set_key(node, key, new_value) + else: + raise ValueError( + f"Cannot update key '{key}' for node of type '{type(node)}'" + ' because the node does not support mutation.' + ) + + new_node = f(path, node) + results[node_id] = new_node + return new_node def find_duplicates(node: tp.Any, /, *, only: filterlib.Filter = ...) -> list[list[PathParts]]: """Finds duplicate nodes or node leaves in the given node. 
@@ -3110,12 +3156,16 @@ def _unflatten_pytree( pop_key=None, ) +def _list_set_key(x: list[tp.Any], key: int, value: tp.Any): + x[key] = value + # common pytrees # list register_pytree_node_type( list, flatten=lambda x: (list(enumerate(x)), None), unflatten=lambda nodes, _: [value for _, value in nodes], # type: ignore + set_key=_list_set_key, ) # tuple register_pytree_node_type( diff --git a/flax/nnx/module.py b/flax/nnx/module.py index aa32a7edf..8015a2c1c 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -14,6 +14,7 @@ from __future__ import annotations +import inspect import typing as tp import jax @@ -428,6 +429,261 @@ def eval(self, **attributes): ) +def set_mode(node: A, /, *, only: filterlib.Filter = ..., raise_if_not_found: bool = True, **kwargs) -> A: + """Creates a new node with static attributes updated according to + ``**kwargs``. The new node contains references to the JAX arrays in the original + node. If a kwarg is not found in any module, this method raises a ValueError. + A Module's ``set_mode`` method should return any unused kwargs. + + Example:: + + >>> from flax import nnx + ... + >>> class Block(nnx.Module): + ... def __init__(self, din, dout, *, rngs: nnx.Rngs): + ... self.linear = nnx.Linear(din, dout, rngs=rngs) + ... self.dropout = nnx.Dropout(0.5, deterministic=False) + ... self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs) + ... + >>> block = Block(2, 5, rngs=nnx.Rngs(0)) + >>> block.dropout.deterministic, block.batch_norm.use_running_average + (False, False) + >>> new_block = nnx.set_mode(block, deterministic=True, use_running_average=True) + >>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average + (True, True) + + ``Filter``'s can be used to set the attributes of specific Modules:: + + >>> block = Block(2, 5, rngs=nnx.Rngs(0)) + >>> new_block = nnx.set_mode(block, only=nnx.Dropout, deterministic=True) + >>> # Only the dropout will be modified + >>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average + (True, False) + + Args: + node: the object to create a copy of. + only: Filters to select the Modules to set the attributes of. + raise_if_not_found: if True (the default), raise a ValueError when a kwarg is not consumed by any Module's ``set_mode``. + **kwargs: The attributes to set. + """ + predicate = filterlib.to_predicate(only) + + counts = {k: 0 for k in kwargs} + counts["_set_mode_calls"] = 0 + + def _set_mode_fn(path, node): + if hasattr(node, 'set_mode') and predicate(path, node): + counts["_set_mode_calls"] += 1 + unused = node.set_mode(**kwargs) + for k in unused: + counts[k] += 1 + return node + + out = graph.recursive_map(_set_mode_fn, node) + + if raise_if_not_found: + set_mode_calls = counts.pop("_set_mode_calls") + unused_keys = [k for k, v in counts.items() if v == set_mode_calls] + if unused_keys: + raise ValueError(f"Unused keys found in set_mode: {unused_keys}") + + return out + +def set_mode_individual_info(cls: type[Module]) -> tuple[str, str]: + """Provides info about ``set_mode`` arguments for an individual module without + its submodules. Returns type information followed by the docstring. + """ + return str(inspect.signature(cls.set_mode)), inspect.getdoc(cls.set_mode) + + +def _set_mode_info_parse_types(s: str): + last_quote_ind = s.find("' =") + if s.startswith("'") and last_quote_ind >= 0: + s = s[1:last_quote_ind] + s[last_quote_ind+1:] + return s + +def set_mode_info(node: A, /, *, only: filterlib.Filter = ...) -> str: + """Provides information about the ``set_mode`` arguments for a module and all + submodules. + + Example:: + + >>> from flax import nnx + ...
+ >>> class CustomModel(nnx.Module): + ... def __init__(self, *, rngs): + ... self.bn = nnx.BatchNorm(10, rngs=rngs) + ... self.mha = nnx.MultiHeadAttention(10, 10, 10, 10, 10, rngs=rngs) + ... self.drop = nnx.Dropout(0.1, rngs=rngs) + ... + >>> model = CustomModel(rngs=nnx.Rngs(0)) + >>> nnx.set_mode_info(model) + BatchNorm: + use_running_average: bool | None = None + if True, the stored batch statistics will be used instead of computing the batch statistics on the input. + Dropout: + deterministic: bool | None = None + if True, disables dropout masking. + MultiHeadAttention: + deterministic: bool | None = None + if True, the module is set to deterministic mode. + decode: bool | None = None + if True, the module is set to decode mode. + batch_size: int | Shape | None = None + the batch size to use for the cache. + max_length: int | None = None + the max length to use for the cache. + + Args: + node: the object to display ``set_mode`` information for. + only: Filters to select the Modules to display information for. + """ + predicate = filterlib.to_predicate(only) + classes: set[Module] = set() + + def _set_mode_fn(path, node): + if hasattr(node, 'set_mode') and predicate(path, node): + classes.add(node.__class__) + return node + + graph.recursive_map(_set_mode_fn, node) + + classes = sorted(list(classes), key=lambda x: x.__qualname__) + out_str = [] + for c in classes: + cls_name = c.__qualname__ + out_str.append(f"{cls_name}:") + sig_str, doc_str = set_mode_individual_info(c) + + # Start with the signature string + last_ind = sig_str.find(", **kwargs") + split_args = sig_str[7:last_ind].split(", ") + + # Look at the docstring + args_ind = doc_str.find("Args:") + if args_ind == -1: + args_ind = doc_str.find("Arguments:") + lines = doc_str[args_ind:].split("\n")[1:] + + if args_ind != -1: + # reconstruct the docstring descriptions + doc_dict, curr_str_list, varname = dict(), [], None + for l in lines: + # Check if tabbed over twice + if not l.startswith(" "*4): + doc_dict[varname] = " ".join(curr_str_list) + curr_str_list = [] + colon_ind = l.find(": ") + varname = l[2:colon_ind] + curr_str_list.append(l[colon_ind+2:]) + else: + curr_str_list.append(l[4:]) + doc_dict[varname] = " ".join(curr_str_list) + + for arg in split_args: + # Retrieve the var name and type + varname, vartype = arg.split(": ") + vartype = _set_mode_info_parse_types(vartype) + out_str.append(f" {varname}: {vartype}") + + # Retrieve the docstring + if args_ind != -1: + out_str.append(" "*4 + doc_dict[varname]) + + return "\n".join(out_str) + + +def train_mode(node: A, /, *, only: filterlib.Filter = ..., **kwargs) -> A: + """Creates a new node set to training mode. + + ``train_mode`` uses ``set_mode`` to recursively set attributes ``deterministic=False`` + and ``use_running_average=False`` of all nested Modules that have these attributes. + It's primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm`` + Modules. + + Example:: + + >>> from flax import nnx + ... + >>> class Block(nnx.Module): + ... def __init__(self, din, dout, *, rngs: nnx.Rngs): + ... self.linear = nnx.Linear(din, dout, rngs=rngs) + ... # initialize Dropout and BatchNorm in eval mode + ... self.dropout = nnx.Dropout(0.5, deterministic=True) + ... self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs) + ...
+ >>> block = Block(2, 5, rngs=nnx.Rngs(0)) + >>> block.dropout.deterministic, block.batch_norm.use_running_average + (True, True) + >>> train_block = nnx.train_mode(block) + >>> train_block.dropout.deterministic, train_block.batch_norm.use_running_average + (False, False) + + Args: + **kwargs: additional attributes passed to ``set_mode``. + """ + return set_mode( + node, + only=only, + raise_if_not_found=False, + deterministic=False, + use_running_average=False, + **kwargs, + ) + + +def eval_mode(node: A, /, *, only: filterlib.Filter = ..., **kwargs) -> A: + """Creates a new node set to evaluation mode. + + ``eval_mode`` uses ``set_mode`` to recursively set attributes ``deterministic=True`` + and ``use_running_average=True`` of all nested Modules that have these attributes. + It's primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm`` + Modules. + + Example:: + + >>> from flax import nnx + ... + >>> class Block(nnx.Module): + ... def __init__(self, din, dout, *, rngs: nnx.Rngs): + ... self.linear = nnx.Linear(din, dout, rngs=rngs) + ... self.dropout = nnx.Dropout(0.5) + ... self.batch_norm = nnx.BatchNorm(10, rngs=rngs) + ... + >>> block = Block(2, 5, rngs=nnx.Rngs(0)) + >>> block.dropout.deterministic, block.batch_norm.use_running_average + (False, False) + >>> eval_block = nnx.eval_mode(block) + >>> eval_block.dropout.deterministic, eval_block.batch_norm.use_running_average + (True, True) + + Args: + **kwargs: additional attributes passed to ``set_mode``. + """ + return set_mode( + node, + only=only, + raise_if_not_found=False, + deterministic=True, + use_running_average=True, + **kwargs, + ) + + +def set_attributes( + node: A, /, *, only: filterlib.Filter = ..., **attributes +) -> A: + """Creates a new node with the given ``**attributes`` set on every sub-node + selected by ``only`` that already has an attribute of that name. + """ + predicate = filterlib.to_predicate(only) + + def _set_attributes_fn(path, node): + if predicate(path, node): + for name, value in attributes.items(): + if hasattr(node, name): + setattr(node, name, value) + return node + + return graph.recursive_map(_set_attributes_fn, node) + + def first_from(*args: tp.Optional[A], error_msg: str) -> A: """Return the first non-None argument. diff --git a/flax/nnx/nn/attention.py b/flax/nnx/nn/attention.py index 34e386e34..fcec9b19c 100644 --- a/flax/nnx/nn/attention.py +++ b/flax/nnx/nn/attention.py @@ -604,6 +604,57 @@ def __call__( out = self.out(x) return out + def set_mode( + self, + deterministic: bool | None = None, + decode: bool | None = None, + batch_size: int | Shape | None = None, + max_length: int | None = None, + **kwargs, + ) -> dict: + """ + Args: + deterministic: if True, the module is set to deterministic mode. + decode: if True, the module is set to decode mode. + batch_size: the batch size to use for the cache. + max_length: the max length to use for the cache. + """ + if deterministic is not None: + self.deterministic = deterministic + + if decode is not None: + self.decode = decode + if ( + not hasattr(self, 'cached_key') + or not hasattr(self, 'cached_value') + or not hasattr(self, 'cache_index') + ): + if batch_size is None: + raise TypeError( + "'batch_size' must be provided when initializing cache." + ) + if max_length is None: + raise TypeError( + "'max_length' must be provided when initializing cache."
+ ) + self.init_cache2(batch_size, max_length, dtype=self.dtype) + return kwargs + + def init_cache2( + self, batch_size: int | Shape, max_length: int, dtype: Dtype | None = None + ): + if dtype is None: + dtype = self.dtype + if isinstance(batch_size, int): + batch_size = (batch_size,) + + cache_shape = (*batch_size, max_length, self.num_heads, self.head_dim) + self.cached_key = nnx.Cache(jnp.zeros(cache_shape, dtype)) + self.cached_value = nnx.Cache(jnp.zeros(cache_shape, dtype)) + self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32)) + + def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32): """Initializes cache for fast autoregressive decoding. When ``decode=True``, this method must be called first before performing diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py index 17a164ebf..01b869fbe 100644 --- a/flax/nnx/nn/normalization.py +++ b/flax/nnx/nn/normalization.py @@ -392,6 +392,20 @@ def __call__( self.epsilon, ) + def set_mode( + self, + use_running_average: bool | None = None, + **kwargs, + ) -> dict: + """ + Args: + use_running_average: if True, the stored batch statistics will be + used instead of computing the batch statistics on the input. + """ + if use_running_average is not None: + self.use_running_average = use_running_average + return kwargs + class LayerNorm(Module): """Layer normalization (https://arxiv.org/abs/1607.06450). diff --git a/flax/nnx/nn/stochastic.py b/flax/nnx/nn/stochastic.py index ab365ef5c..a538ae94b 100644 --- a/flax/nnx/nn/stochastic.py +++ b/flax/nnx/nn/stochastic.py @@ -73,7 +73,7 @@ def __init__( rate: float, *, broadcast_dims: Sequence[int] = (), - deterministic: bool = False, + deterministic: bool | None = None, rng_collection: str = 'dropout', rngs: rnglib.Rngs | rnglib.RngStream | None = None, ): @@ -153,3 +153,16 @@ def __call__( mask = random.bernoulli(key, p=keep_prob, shape=broadcast_shape) mask = jnp.broadcast_to(mask, inputs.shape) return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) + + def set_mode( + self, + deterministic: bool | None = None, + **kwargs, + ) -> dict: + """ + Args: + deterministic: if True, disables dropout masking. + """ + if deterministic is not None: + self.deterministic = deterministic + return kwargs diff --git a/tests/nnx/bridge/module_test.py b/tests/nnx/bridge/module_test.py index d2fcff79a..f250e26c8 100644 --- a/tests/nnx/bridge/module_test.py +++ b/tests/nnx/bridge/module_test.py @@ -276,7 +276,7 @@ def test_pure_nnx_submodule(self): class NNXLayer(nnx.Module): def __init__(self, dim, dropout, rngs): self.linear = nnx.Linear(dim, dim, use_bias=False, rngs=rngs) - self.dropout = nnx.Dropout(dropout, rngs=rngs) + self.dropout = nnx.Dropout(dropout, deterministic=False, rngs=rngs) self.count = nnx.Intermediate(jnp.array([0.])) def __call__(self, x): # Required check to avoid state update in `init()`. Can this be avoided? 
diff --git a/tests/nnx/bridge/wrappers_test.py b/tests/nnx/bridge/wrappers_test.py index 8e827bd24..652839706 100644 --- a/tests/nnx/bridge/wrappers_test.py +++ b/tests/nnx/bridge/wrappers_test.py @@ -270,7 +270,7 @@ def test_nnx_to_linen_multiple_rngs(self): class NNXInner(nnx.Module): def __init__(self, din, dout, rngs): self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout))) - self.dropout = nnx.Dropout(rate=0.5, rngs=rngs) + self.dropout = nnx.Dropout(rate=0.5, deterministic=False, rngs=rngs) def __call__(self, x): return self.dropout(x @ self.w[...]) @@ -423,7 +423,7 @@ def test_nnx_to_linen_pytree_structure_consistency(self): class NNXInner(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout))) - self.dropout = nnx.Dropout(rate=0.5, rngs=rngs) + self.dropout = nnx.Dropout(rate=0.5, deterministic=False, rngs=rngs) def __call__(self, x): return self.dropout(x @ self.w) @@ -476,7 +476,7 @@ def __init__(self, din, dout, dropout_rate, rngs): self.w = nnx.Param( nnx.with_partitioning(nnx.initializers.lecun_normal(), sharding=('in', 'out') )(rngs.params(), (din, dout))) - self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs) + self.dropout = nnx.Dropout(rate=dropout_rate, deterministic=False, rngs=rngs) def __call__(self, x): return self.dropout(x @ self.w) diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index 523ff31fa..6f681fba3 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -1107,6 +1107,52 @@ def __init__(self, rngs: nnx.Rngs): self.assertLen(duplicates, 1) self.assertEqual(duplicates[0], [('a',), ('c',)]) + def test_recursive_map(self): + class Foo(nnx.Pytree): + def __init__(self, d): + self.d = d + + foo1 = Foo(10) + foo2 = Foo(20) + bar = [foo1, foo2, foo1] + n = 0 + + def inc_d(path, node): + nonlocal n + if isinstance(node, Foo): + n += 1 + node.d += 1 + return node + + bar2 = nnx.recursive_map(inc_d, bar) + self.assertIs(bar2[0], bar2[2]) + self.assertEqual(bar2[0].d, 11) + self.assertEqual(bar2[1].d, 21) + self.assertEqual(n, 2) + + def test_recursive_map_replace(self): + class Foo(nnx.Pytree): + def __init__(self, d): + self.d = d + + foo1 = Foo(10) + foo2 = Foo(20) + bar = [foo1, foo2, foo1] + n = 0 + + def swap(path, node): + nonlocal n + if isinstance(node, Foo): + n += 1 + node = Foo(-node.d) + return node + + bar2 = nnx.recursive_map(swap, bar) + self.assertIs(bar2[0], bar2[2]) + self.assertEqual(bar2[0].d, -10) + self.assertEqual(bar2[1].d, -20) + self.assertEqual(n, 2) + def test_graphdef_hash_with_sequential(self): rngs = nnx.Rngs(0) net = nnx.Sequential( diff --git a/tests/nnx/integration_test.py b/tests/nnx/integration_test.py index 67eaa71c5..bb02ad8a0 100644 --- a/tests/nnx/integration_test.py +++ b/tests/nnx/integration_test.py @@ -28,6 +28,59 @@ class TestIntegration(absltest.TestCase): + + def test_basic_example(self): + class Model(nnx.Module): + + def __init__(self, din, dmid, dout, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dmid, rngs=rngs) + self.bn = nnx.BatchNorm(dmid, rngs=rngs) + self.dropout = nnx.Dropout(0.2, rngs=rngs) + self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) + + def __call__(self, x): + x = nnx.relu(self.dropout(self.bn(self.linear(x)))) + return self.linear_out(x) + + model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization + train_model = nnx.set_mode( + model, deterministic=False, use_running_average=False + ) + eval_model = nnx.set_mode(
+ model, deterministic=True, use_running_average=True + ) + optimizer = nnx.Optimizer(train_model, optax.adam(1e-3), wrt=nnx.Param) + + self.assertEqual(train_model.dropout.deterministic, False) + self.assertEqual(train_model.bn.use_running_average, False) + self.assertEqual(eval_model.dropout.deterministic, True) + self.assertEqual(eval_model.bn.use_running_average, True) + self.assertIs(train_model.dropout.rngs.count, eval_model.dropout.rngs.count) + + @nnx.jit # automatic state management for JAX transforms + def train_step(model, optimizer, x, y): + def loss_fn(model): + y_pred = model(x) + return jnp.mean((y_pred - y) ** 2) + + loss, grads = nnx.value_and_grad(loss_fn)(model) + optimizer.update(model, grads) # in-place updates + + return loss + + @nnx.jit + def eval_step(model, x, y): + y_pred = model(x) + return jnp.mean((y_pred - y) ** 2) + + x = jax.random.normal(jax.random.key(0), (8, 2)) + y = jax.random.normal(jax.random.key(1), (8, 3)) + + train_step(train_model, optimizer, x, y) + self.assertEqual(train_model.dropout.rngs.count.value, 1) + eval_step(eval_model, x, y) + self.assertEqual(train_model.dropout.rngs.count.value, 1) + def test_shared_modules(self): class Block(nnx.Module): def __init__(self, linear: nnx.Linear, *, rngs): @@ -78,6 +131,56 @@ def loss_fn(model: Model): assert model.block1.linear.bias is not None assert model.block1.bn is not model.block2.bn + def test_shared_modules_set_mode(self): + class Block(nnx.Module): + def __init__(self, linear: nnx.Linear, *, rngs): + self.linear = linear + self.bn = nnx.BatchNorm(2, rngs=rngs) + + def __call__(self, x): + x = self.linear(x) + x = self.bn(x) + return nnx.relu(x) + + class Model(nnx.Module): + def __init__(self, *, rngs): + shared = nnx.Linear(2, 2, rngs=rngs) + self.block1 = Block(shared, rngs=rngs) + self.block2 = Block(shared, rngs=rngs) + + def __call__(self, x): + x = self.block1(x) + x = self.block2(x) + return x + + @nnx.jit + def train_step(model: Model, x, y): + @nnx.grad + def loss_fn(model: Model): + y_pred = model(x) + return jnp.mean((y - y_pred) ** 2) + + grads = loss_fn(model) + nnx.update( + model, + jax.tree.map( + lambda w, g: w - 0.1 * g, nnx.state(model, nnx.Param), grads + ), + ) + + model = Model(rngs=nnx.Rngs(0)) + + x = np.random.uniform(size=(4, 2)) + y = np.random.uniform(size=(4, 2)) + new_model = nnx.set_mode(model, use_running_average=False) + + for _i in range(3): + train_step(model, x, y) + + assert new_model.block1.linear is new_model.block2.linear + assert new_model.block1.linear.bias is not None + assert new_model.block1.bn is not new_model.block2.bn + def test_shared_modules_pure(self): class Block(nnx.Module): def __init__(self, linear: nnx.Linear, *, rngs: nnx.Rngs): @@ -137,6 +240,65 @@ def loss_fn(model: Model): assert model.block1.linear.bias is model.block2.linear.bias assert model.block1.bn is not model.block2.bn + def test_shared_modules_pure_set_mode(self): + class Block(nnx.Module): + def __init__(self, linear: nnx.Linear, *, rngs: nnx.Rngs): + self.linear = linear + self.bn = nnx.BatchNorm(2, rngs=rngs) + + def __call__(self, x): + x = self.linear(x) + x = self.bn(x) + return nnx.relu(x) + + class Model(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + shared = nnx.Linear(2, 2, rngs=rngs) + self.block1 = Block(shared, rngs=rngs) + self.block2 = Block(shared, rngs=rngs) + + def __call__(self, x): + x = self.block1(x) + x = self.block2(x) + return x + + @jax.jit + def train_step(state: nnx.State, graphdef: nnx.GraphDef[Model], x, y): + model = 
nnx.merge(graphdef, state) + new_model = nnx.set_mode(model, use_running_average=False) + + @nnx.grad + def loss_fn(model: Model): + y_pred = model(x) + return jnp.mean((y - y_pred) ** 2) + + grads = loss_fn(new_model) + nnx.update( + new_model, + jax.tree.map( + lambda w, g: w - 0.1 * g, nnx.state(new_model, nnx.Param), grads + ), + ) + + return nnx.split(new_model) + + graphdef: nnx.GraphDef[Model] + graphdef, state = nnx.split(Model(rngs=nnx.Rngs(0))) + + x = np.random.uniform(size=(4, 2)) + y = np.random.uniform(size=(4, 2)) + + for _ in range(3): + graphdef, state = train_step(state, graphdef, x, y) + + model = nnx.merge(graphdef, state) + + assert model.block1.linear.bias is not None + assert model.block2.linear.bias is not None + assert model.block1.linear.kernel is model.block2.linear.kernel + assert model.block1.linear.bias is model.block2.linear.bias + assert model.block1.bn is not model.block2.bn + def test_stateful_example(self): class State(nnx.Variable[A]): pass @@ -303,7 +465,7 @@ class Model(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dmid, rngs=rngs) self.bn = nnx.BatchNorm(dmid, rngs=rngs) - self.dropout = nnx.Dropout(0.2, rngs=rngs) + self.dropout = nnx.Dropout(0.2, deterministic=False, rngs=rngs) self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x): @@ -319,7 +481,7 @@ def train_step(x, y): graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) def loss_fn(params): model = nnx.merge(graphdef, params, nondiff) - return ((model(x) - y) ** 2).mean() # call methods directly + return ((model(x) - y) ** 2).mean() loss, grads = jax.value_and_grad(loss_fn)(nnx.to_arrays(params)) optimizer.update(model, grads) # in-place updates diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index 64bf1dda2..0261a5e37 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -649,6 +649,49 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs): raise_if_not_found=False, ) + def test_set_mode(self): + class Block(nnx.Module): + def __init__(self, din, dout, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dout, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False) + self.batch_norm = nnx.BatchNorm( + 10, use_running_average=False, rngs=rngs + ) + + block = Block(2, 5, rngs=nnx.Rngs(0)) + assert block.dropout.deterministic == False + assert block.batch_norm.use_running_average == False + + new_block = nnx.set_mode(block, deterministic=True, use_running_average=True) + assert new_block.dropout.deterministic == True + assert new_block.batch_norm.use_running_average == True + assert new_block.linear.kernel is block.linear.kernel + + block = Block(2, 5, rngs=nnx.Rngs(0)) + new_block = nnx.set_mode(block, only=nnx.Dropout, deterministic=True) + # Only the dropout will be modified + assert new_block.dropout.deterministic == True + assert new_block.batch_norm.use_running_average == False + + def test_set_mode_error(self): + class Block(nnx.Module): + def __init__(self, din, dout, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dout, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False) + self.batch_norm = nnx.BatchNorm( + 10, use_running_average=False, rngs=rngs + ) + #TODO: Get this done + block = Block(2, 5, rngs=nnx.Rngs(0)) + + with self.assertRaisesRegex( + ValueError, + ( + "Unused keys found in set_mode: \\['unknown'\\]" + ), + ): + nnx.set_mode(block, deterministic=True, use_running_average=True, unknown=True) + def test_cloud_pickle(self): class 
Model(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): diff --git a/tests/nnx/nn/attention_test.py b/tests/nnx/nn/attention_test.py index c8a9d55a7..61440056a 100644 --- a/tests/nnx/nn/attention_test.py +++ b/tests/nnx/nn/attention_test.py @@ -89,6 +89,50 @@ def __call__(self, x, sow_weights=False): intermediates = nnx.pop(module, nnx.Intermediate) assert not intermediates # empty + def test_multihead_sow_attention_weights_set_mode(self): + class Model(nnx.Module): + attention_kwargs: dict + + def __init__(self, attention_kwargs, rng): + self.attention_layers = nnx.data([ + nnx.MultiHeadAttention(**attention_kwargs, rngs=rng) for i in range(3) + ]) + + def __call__(self, x, sow_weights=False): + x = self.attention_layers[0](x, sow_weights=sow_weights) + x = self.attention_layers[1](x) + x = self.attention_layers[2](x, sow_weights=sow_weights) + return x + + rng = nnx.Rngs(0) + x = jnp.ones((4, 6, 8)) + + module = Model( + dict( + in_features=8, + num_heads=8, + kernel_init=nnx.initializers.ones_init(), + bias_init=nnx.initializers.zeros_init(), + deterministic=False, + ), + rng, + ) + new_module = nnx.set_mode(module, decode=False) + + _ = new_module(x, True) + intermediates = nnx.pop(new_module, nnx.Intermediate) + assert intermediates['attention_layers'][0]['attention_weights'][ + 0 + ].shape == (4, 8, 6, 6) + assert 1 not in intermediates['attention_layers'] + assert intermediates['attention_layers'][2]['attention_weights'][ + 0 + ].shape == (4, 8, 6, 6) + + _ = new_module(x) + intermediates = nnx.pop(new_module, nnx.Intermediate) + assert not intermediates # empty + def test_autoregressive_decode_with_x64(self): with enable_x64(): x = jnp.ones((1, 4, 4)) diff --git a/tests/nnx/nn/recurrent_test.py b/tests/nnx/nn/recurrent_test.py index 034baff2f..9747d4ac6 100644 --- a/tests/nnx/nn/recurrent_test.py +++ b/tests/nnx/nn/recurrent_test.py @@ -589,7 +589,7 @@ def __init__( **kwargs, ) self.recurrent_dropout = nnx.Dropout( - rate=dropout_rate, rng_collection='recurrent_dropout', rngs=rngs + rate=dropout_rate, deterministic=False, rng_collection='recurrent_dropout', rngs=rngs ) def __call__(self, carry, x): @@ -615,7 +615,7 @@ def __init__( dropout_rate=recurrent_dropout_rate, ) self.lstm = nnx.RNN(cell, broadcast_rngs='recurrent_dropout') - self.dropout = nnx.Dropout(dropout_rate, rngs=rngs) + self.dropout = nnx.Dropout(dropout_rate, deterministic=False, rngs=rngs) self.dense = nnx.Linear( in_features=hidden_features, out_features=1, rngs=rngs ) diff --git a/tests/nnx/nn/stochastic_test.py b/tests/nnx/nn/stochastic_test.py index a9a13e51f..97d28de0f 100644 --- a/tests/nnx/nn/stochastic_test.py +++ b/tests/nnx/nn/stochastic_test.py @@ -87,3 +87,26 @@ def test_dropout_arg_override(self): match='`deterministic` is False, but no `rngs` argument was provided to Dropout', ): m(x) + + def test_dropout_arg_override_set_mode(self): + m = nnx.Dropout(rate=0.5) + x = jnp.ones((1, 10)) + + # deterministic call arg provided + m(x, deterministic=True) + # deterministic constructor arg provided + new_m = nnx.set_mode(m, deterministic=True) + y = new_m(x) + # both deterministic call and constructor arg provided + with pytest.raises(AssertionError): + np.testing.assert_allclose( + y, new_m(x, deterministic=False, rngs=nnx.Rngs(dropout=0)) + ) + # no rng arg provided + # m.set_attributes(deterministic=False) + new_m = nnx.set_mode(m, deterministic=False) + with pytest.raises( + ValueError, + match='`deterministic` is False, but no `rngs` argument was provided to Dropout', + ): + new_m(x) 
diff --git a/tests/nnx/rngs_test.py b/tests/nnx/rngs_test.py index aef3f6a4b..6c1bc3110 100644 --- a/tests/nnx/rngs_test.py +++ b/tests/nnx/rngs_test.py @@ -176,7 +176,7 @@ def test_reseed(self): class Model(nnx.Module): def __init__(self, rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) - self.dropout = nnx.Dropout(0.5, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) def __call__(self, x): return self.dropout(self.linear(x)) diff --git a/tests/nnx/summary_test.py b/tests/nnx/summary_test.py index 73e8eb9f8..ec5722fc8 100644 --- a/tests/nnx/summary_test.py +++ b/tests/nnx/summary_test.py @@ -26,7 +26,7 @@ class Block(nnx.Module): def __init__(self, din, dout, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.bn = nnx.BatchNorm(dout, rngs=rngs) - self.dropout = nnx.Dropout(0.2, rngs=rngs) + self.dropout = nnx.Dropout(0.2, deterministic=False, rngs=rngs) def forward(self, x): return nnx.relu(self.dropout(self.bn(self.linear(x)))) diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index d79aa35c8..93e3d099a 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -1583,6 +1583,38 @@ def __call__(self, x: jax.Array): assert y.shape == (1, 3) + def test_complex_set_mode(self): + state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) + + class MLP(nnx.Module): + @nnx.split_rngs(splits=5) + @nnx.vmap(in_axes=(state_axes, state_axes)) + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + @nnx.scan(in_axes=(state_axes, nnx.Carry)) + def __call__(self, x: jax.Array): + x = self.linear(x) + x = self.bn(x) + x = self.dropout(x) + x = nnx.gelu(x) + return x, None + + module = MLP(rngs=nnx.Rngs(0)) + new_module = nnx.set_mode(module, deterministic=False, use_running_average=False) + + assert new_module.linear.kernel.shape == (5, 3, 3) + assert new_module.linear.bias.shape == (5, 3) + assert new_module.node.shape == (2,) + + x = jnp.ones((1, 3)) + y, _ = new_module(x) + + assert y.shape == (1, 3) + def test_complex_broadcast_dropout(self): state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None}) @@ -1616,6 +1648,39 @@ def __call__(self, x: jax.Array): assert y.shape == (1, 3) + def test_complex_broadcast_dropout_set_mode(self): + state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None}) + + class MLP(nnx.Module): + @nnx.split_rngs(splits=5, only='params') + @nnx.vmap(in_axes=(state_axes, state_axes)) + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + @nnx.split_rngs(splits=5, only='params') + @nnx.scan(in_axes=(state_axes, nnx.Carry)) + def __call__(self, x: jax.Array): + x = self.linear(x) + x = self.bn(x) + x = self.dropout(x) + x = nnx.gelu(x) + return x, None + + module = MLP(rngs=nnx.Rngs(params=0, dropout=1)) + new_module = nnx.set_mode(module, deterministic=False, use_running_average=False) + + assert new_module.linear.kernel.shape == (5, 3, 3) + assert new_module.linear.bias.shape == (5, 3) + assert new_module.node.shape == (2,) + + x = jnp.ones((1, 3)) + y, _ = new_module(x) + + assert y.shape == (1, 3) + def test_complex_decorator(self): state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) @@ -1651,6 +1716,41 @@ def __call__(self, x: 
jax.Array): assert y.shape == (1, 3) assert out is None + def test_complex_decorator_set_mode(self): + state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) + + class Block(nnx.Module): + @nnx.split_rngs(splits=5) + @nnx.vmap(in_axes=(state_axes, state_axes), axis_size=5) + def __init__(self, rngs: nnx.Rngs): + self.d = 3 + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + @nnx.scan(in_axes=(state_axes, nnx.Carry)) + def __call__(self, x: jax.Array): + x = self.linear(x) + x = self.bn(x) + x = self.dropout(x) + x = nnx.gelu(x) + return x, None + + module = Block(rngs=nnx.Rngs(0)) + new_module = nnx.set_mode(module, deterministic=False, use_running_average=False) + + assert new_module.d == 3 + assert new_module.linear.kernel.shape == (5, 3, 3) + assert new_module.linear.bias.shape == (5, 3) + assert new_module.node.shape == (2,) + + x = jnp.ones((1, 3)) + y, out = new_module(x) + + assert y.shape == (1, 3) + assert out is None + def test_scan_with_sharding(self): test = self state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) @@ -1781,7 +1881,7 @@ def __init__(self, input_size, hidden_size, rngs): self.linear = nnx.Linear( hidden_size + input_size, hidden_size, rngs=rngs ) - self.drop = nnx.Dropout(0.1, rngs=rngs) + self.drop = nnx.Dropout(0.1, deterministic=False, rngs=rngs) self.hidden_size = hidden_size def __call__(self, carry, x) -> tuple[jax.Array, jax.Array]: @@ -2402,7 +2502,7 @@ def test_example(self): class Model(nnx.Module): def __init__(self, din, dout, *, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) - self.dropout = nnx.Dropout(0.5, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) self.bn = nnx.BatchNorm(dout, rngs=rngs) def __call__(self, x):