diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index cfce0a6e..84998132 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -21,7 +21,7 @@ jobs: strategy: matrix: - python-version: [3.9, '3.10', 3.11] + python-version: ['3.10', 3.11, 3.12] JAX_ENABLE_X64: [0, 1] runs-on: ubuntu-latest @@ -32,10 +32,10 @@ jobs: uses: styfle/cancel-workflow-action@0.12.1 with: access_token: ${{ github.token }} - - uses: actions/checkout@v4.1.1 + - uses: actions/checkout@v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5.0.0 + uses: actions/setup-python@v5.3.0 with: python-version: ${{ matrix.python-version }} @@ -54,7 +54,7 @@ jobs: JAX_ENABLE_X64=${{ matrix.JAX_ENABLE_X64 }} PYTHONHASHSEED=0 pytest -n auto --cov=neural_tangents --cov-report=xml --cov-report=term - name: Upload coverage to Codecov - uses: codecov/codecov-action@v4.0.1 + uses: codecov/codecov-action@v4.6.0 with: file: ./coverage.xml diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml index 75885a54..a9cbf565 100644 --- a/.github/workflows/macos.yml +++ b/.github/workflows/macos.yml @@ -21,7 +21,7 @@ jobs: strategy: matrix: - python-version: [3.9, '3.10', 3.11] + python-version: ['3.10', 3.11, 3.12] JAX_ENABLE_X64: [0] runs-on: macos-latest @@ -32,10 +32,10 @@ jobs: uses: styfle/cancel-workflow-action@0.12.1 with: access_token: ${{ github.token }} - - uses: actions/checkout@v4.1.1 + - uses: actions/checkout@v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5.0.0 + uses: actions/setup-python@v5.3.0 with: python-version: ${{ matrix.python-version }} @@ -54,7 +54,7 @@ jobs: JAX_ENABLE_X64=${{ matrix.JAX_ENABLE_X64 }} PYTHONHASHSEED=0 pytest -n auto --cov=neural_tangents --cov-report=xml --cov-report=term - name: Upload coverage to Codecov - uses: codecov/codecov-action@v4.0.1 + uses: codecov/codecov-action@v4.6.0 with: file: ./coverage.xml diff --git a/.github/workflows/pytype.yml b/.github/workflows/pytype.yml index 766f8d1e..d409a5c7 100644 --- a/.github/workflows/pytype.yml +++ b/.github/workflows/pytype.yml @@ -27,10 +27,10 @@ jobs: uses: styfle/cancel-workflow-action@0.12.1 with: access_token: ${{ github.token }} - - uses: actions/checkout@v4.1.1 + - uses: actions/checkout@v4.2.2 - name: Set up Python 3.10 - uses: actions/setup-python@v5.0.0 + uses: actions/setup-python@v5.3.0 with: python-version: '3.10' diff --git a/.github/workflows/sketching.yml b/.github/workflows/sketching.yml index a6ccc6e1..27f2ded0 100644 --- a/.github/workflows/sketching.yml +++ b/.github/workflows/sketching.yml @@ -31,10 +31,10 @@ jobs: uses: styfle/cancel-workflow-action@0.12.1 with: access_token: ${{ github.token }} - - uses: actions/checkout@v4.1.1 + - uses: actions/checkout@v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5.0.0 + uses: actions/setup-python@v5.3.0 with: python-version: ${{ matrix.python-version }} @@ -53,7 +53,7 @@ jobs: JAX_ENABLE_X64=${{ matrix.JAX_ENABLE_X64 }} PYTHONHASHSEED=0 pytest experimental/tests/ -n auto --cov=experimental/ --cov-report=xml --cov-report=term - name: Upload coverage to Codecov - uses: codecov/codecov-action@v4.0.1 + uses: codecov/codecov-action@v4.6.0 with: file: ./coverage.xml diff --git a/README.md b/README.md index 70b12951..820b087c 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,7 @@ from jax.example_libraries import stax init_fn, apply_fn = stax.serial( stax.Dense(512), stax.Relu, stax.Dense(512), stax.Relu, - stax.Dense(1) 
+ stax.Dense(1), ) key = random.PRNGKey(1) @@ -123,7 +123,7 @@ from neural_tangents import stax init_fn, apply_fn, kernel_fn = stax.serial( stax.Dense(512), stax.Relu(), stax.Dense(512), stax.Relu(), - stax.Dense(1) + stax.Dense(1), ) key1, key2 = random.split(random.PRNGKey(1)) diff --git a/examples/datasets.py b/examples/datasets.py index f48cea88..7588922c 100644 --- a/examples/datasets.py +++ b/examples/datasets.py @@ -44,7 +44,7 @@ def get_dataset( permute_train=False, do_flatten_and_normalize=True, data_dir=None, - input_key='image' + input_key='image', ): """Download, parse and process a dataset to unit scale and one-hot labels.""" # Need this following http://cl/378185881 to prevent GPU test breakages. @@ -133,9 +133,10 @@ def embed_glove(xs, glove_path, max_sentence_length=1000, mask_constant=1000.): xs = list(map(_decode, xs)) tokenizer = tf.keras.preprocessing.text.Tokenizer() tokenizer.fit_on_texts(np.concatenate(xs)) - glove_embedding_layer = _get_glove_embedding_layer(tokenizer, - glove_path, - max_sentence_length) + glove_embedding_layer = _get_glove_embedding_layer( + tokenizer, + glove_path, + ) def embed(x): # Replace strings with sequences of integer tokens. @@ -147,7 +148,8 @@ def embed(x): x_tok, max_sentence_length, padding='post', - truncating='post') + truncating='post', + ) # Replace integer tokens with word embeddings. x_emb = glove_embedding_layer(x_tok).numpy() @@ -160,7 +162,7 @@ def embed(x): return map(embed, xs) -def _get_glove_embedding_layer(tokenizer, glove_path, max_sentence_length): +def _get_glove_embedding_layer(tokenizer, glove_path): """Get a Keras embedding layer for a given GloVe embeddings. Adapted from https://keras.io/examples/pretrained_word_embeddings/. @@ -172,9 +174,6 @@ def _get_glove_embedding_layer(tokenizer, glove_path, max_sentence_length): glove_path: path to the GloVe embedding file. - max_sentence_length: - pad/truncate embeddings to this length. - Returns: Keras embedding layer for a given GloVe embeddings. """ @@ -212,8 +211,8 @@ def _get_glove_embedding_layer(tokenizer, glove_path, max_sentence_length): embedding_layer = tf.keras.layers.Embedding( num_words, embedding_dim, embeddings_initializer=tf.keras.initializers.Constant(emb_mat), - input_length=max_sentence_length, - trainable=False) + trainable=False, + ) return embedding_layer diff --git a/examples/function_space.py b/examples/function_space.py index 7953449a..c9d95070 100644 --- a/examples/function_space.py +++ b/examples/function_space.py @@ -57,7 +57,7 @@ def main(unused_argv): opt_init, opt_apply, get_params = optimizers.sgd(_LEARNING_RATE) state = opt_init(params) - # Create an mse loss function and a gradient function. + # Create a mse loss function and a gradient function. 
loss = lambda fx, y_hat: 0.5 * jnp.mean((fx - y_hat) ** 2) grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y))) diff --git a/examples/imdb.py b/examples/imdb.py index 6daa55c8..bd1a8059 100644 --- a/examples/imdb.py +++ b/examples/imdb.py @@ -99,7 +99,7 @@ def main(*args, use_dummy_data: bool = False, **kwargs) -> None: def _get_dummy_data( - mask_constant: float + mask_constant: float, ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Return dummy data for when downloading embeddings is not feasible.""" n_train, n_test = 6, 6 diff --git a/neural_tangents/_src/batching.py b/neural_tangents/_src/batching.py index 3998d6d8..13822a43 100644 --- a/neural_tangents/_src/batching.py +++ b/neural_tangents/_src/batching.py @@ -42,7 +42,7 @@ """ from functools import partial -from typing import Any, Callable, Iterable, Optional, TypeVar, Union +from typing import Any, Callable, Iterable, TypeVar import warnings import jax @@ -57,7 +57,6 @@ from jax.tree_util import tree_all from jax.tree_util import tree_flatten -from jax.tree_util import tree_map from jax.tree_util import tree_unflatten import numpy as np @@ -77,7 +76,7 @@ def batch( kernel_fn: _KernelFn, batch_size: int = 0, device_count: int = -1, - store_on_device: bool = True + store_on_device: bool = True, ) -> _KernelFn: """Returns a function that computes a kernel in batches over all devices. @@ -114,7 +113,7 @@ def batch( kernel by batching over the dataset in parallel with the specified `batch_size` using `device_count` devices. """ - # TODO(romann): find a way to avoid reading requirements. + # TODO: find a way to avoid reading requirements. input_req = getattr(kernel_fn, 'input_req', {}) dropout_in_analytic_kernel = input_req.get('use_dropout', False) use_multidevice = device_count > 0 or (device_count == -1 and @@ -143,13 +142,13 @@ def batch( def _scan( f: Callable[[_Carry, _Input], tuple[_Carry, _Output]], init: _Carry, - xs: Iterable[_Input] + xs: Iterable[_Input], ) -> tuple[_Carry, _Output]: """Implements an unrolled version of scan. Based on :obj:`jax.lax.scan` and has a similar API. - TODO(schsam): We introduce this function because lax.scan currently has a + TODO: We introduce this function because lax.scan currently has a higher peak memory usage than the unrolled version. We will aim to swap this out for lax.scan when issue #1273 and related have been resolved. """ @@ -161,13 +160,13 @@ def _scan( carry, y = f(carry, x) ys += [y] - return carry, tree_map(lambda *y: jnp.stack(y), *ys) + return carry, jax.tree.map(lambda *y: jnp.stack(y), *ys) def _flatten_batch_dimensions( k: jnp.ndarray, is_parallel: bool, - discard_axis: Optional[int] = None + discard_axis: int | None = None, ) -> jnp.ndarray: """Takes a kernel that has been evaluated in batches and flattens.""" @@ -195,7 +194,7 @@ def _flatten_batch_dimensions( def _flatten_kernel_dict( k: dict[str, Any], x2_is_none: bool, - is_parallel: bool + is_parallel: bool, ) -> dict[str, Any]: if 'nngp' in k: # We only use `batch_size` to compute `shape1` and `shape2` for the batch. 
@@ -247,10 +246,10 @@ def _flatten_kernel_dict( @utils.nt_tree_fn(nargs=1) def _flatten_kernel( - k: Kernel, + k: Kernel | np.ndarray | jnp.ndarray, x2_is_none: bool, - is_parallel: bool -) -> Kernel: + is_parallel: bool, +) -> Kernel | np.ndarray | jnp.ndarray: """Flattens a kernel array or a `Kernel` along the batch dimension.""" if hasattr(k, '_asdict') and hasattr(k, '_replace'): @@ -263,15 +262,17 @@ def _flatten_kernel( elif isinstance(k, (np.ndarray, jnp.ndarray)): return _flatten_batch_dimensions(k, is_parallel) - raise TypeError(f'Expected kernel to be either a namedtuple, `Kernel`, or ' - f'`jnp.ndarray`, got {type(k)}.') + raise TypeError( + f'Expected kernel to be either a namedtuple, `Kernel`, or ' + f'`jnp.ndarray`, got {type(k)}.' + ) @utils.nt_tree_fn(nargs=1) def _reshape_kernel_for_pmap( k: Kernel, device_count: int, - n1_per_device: int + n1_per_device: int, ) -> Kernel: cov2 = k.cov2 if cov2 is None: @@ -288,7 +289,8 @@ def _reshape_kernel_for_pmap( nngp, ntk, cov1 = [ jnp.reshape(x, (device_count, n1_per_device,) + x.shape[1:]) for x in - (k.nngp, k.ntk, k.cov1)] + (k.nngp, k.ntk, k.cov1) + ] return k.replace( nngp=nngp, @@ -307,14 +309,14 @@ def _reshape_kernel_for_pmap( @utils.nt_tree_fn() def _set_cov2_to_none(k: _ArrayOrKernel) -> _ArrayOrKernel: if isinstance(k, Kernel): - k = k.replace(cov2=None) # pytype: disable=attribute-error # jax-ndarray + k = k.replace(cov2=None) return k def _serial( kernel_fn: _KernelFn, batch_size: int, - store_on_device: bool = True + store_on_device: bool = True, ) -> _KernelFn: """Returns a function that computes a kernel in batches serially. @@ -360,14 +362,14 @@ def kernel_fn(x1, x2=None, *args, **kwargs): def serial_fn_x1( x1: NTTree[jnp.ndarray], - x2: Optional[NTTree[Optional[jnp.ndarray]]] = None, + x2: NTTree[jnp.ndarray | None] | None = None, *args, **kwargs ) -> NTTree[Kernel]: x2_is_none = utils.all_none(x2) if x2_is_none: - # TODO(schsam): Only compute the upper triangular part of the kernel. + # TODO: Only compute the upper triangular part of the kernel. x2 = x1 @utils.nt_tree_fn(reduce=lambda x: x[0]) @@ -425,12 +427,12 @@ def serial_fn_kernel(k: NTTree[Kernel], *args, **kwargs) -> NTTree[Kernel]: def get_n1_n2(k: NTTree[Kernel]) -> tuple[int, ...]: if utils.is_list_or_tuple(k): - # TODO(schsam): We might want to check for consistency here, but I can't + # TODO: We might want to check for consistency here, but I can't # imagine a case where we could get inconsistent kernels. return get_n1_n2(k[0]) if isinstance(k, Kernel): - return k.nngp.shape[:2] # pytype: disable=attribute-error + return k.nngp.shape[:2] raise TypeError(type(Kernel), Kernel) @@ -466,7 +468,7 @@ def row_fn(_, n1): return _, _scan(col_fn, n1, (n2s, kwargs_np2))[1] def col_fn(n1, n2): - # NOTE(schsam): If we end up wanting to enable jit-of-batch then we will + # NOTE: If we end up wanting to enable jit-of-batch then we will # probably have to change this to dynamic slicing. 
n1, kwargs1 = n1 n2, kwargs2 = n2 @@ -486,10 +488,12 @@ def col_fn(n1, n2): return flatten(k, cov2_is_none) @utils.wraps(kernel_fn) - def serial_fn(x1_or_kernel: Union[NTTree[jnp.ndarray], NTTree[Kernel]], - x2: Optional[NTTree[Optional[jnp.ndarray]]] = None, - *args, - **kwargs) -> NTTree[Kernel]: + def serial_fn( + x1_or_kernel: NTTree[jnp.ndarray] | NTTree[Kernel], + x2: NTTree[jnp.ndarray | None] | None = None, + *args, + **kwargs, + ) -> NTTree[Kernel]: if utils.is_nt_tree_of(x1_or_kernel, (np.ndarray, jnp.ndarray)): return serial_fn_x1(x1_or_kernel, x2, *args, **kwargs) elif utils.is_nt_tree_of(x1_or_kernel, Kernel): @@ -575,7 +579,7 @@ def _get_n_per_device(n1): def parallel_fn_x1(x1, x2=None, *args, **kwargs): x2_is_none = utils.all_none(x2) if x2_is_none: - # TODO(schsam): Only compute the upper triangular part of the kernel. + # TODO: Only compute the upper triangular part of the kernel. x2 = x1 def get_batch_size(x): @@ -648,9 +652,9 @@ def _get_n_batches_and_batch_sizes( n1: int, n2: int, batch_size: int, - device_count: int + device_count: int, ) -> tuple[int, int, int, int]: - # TODO(romann): if dropout batching works for different batch sizes, relax. + # TODO: if dropout batching works for different batch sizes, relax. max_serial_batch_size = np.gcd(n1, n2) // device_count n2_batch_size = min(batch_size, max_serial_batch_size) @@ -662,7 +666,7 @@ def _get_n_batches_and_batch_sizes( n1_batch_size = n2_batch_size * device_count n1_batches, ragged = divmod(n1, n1_batch_size) if ragged: - # TODO(schsam): Relax this constraint. + # TODO: Relax this constraint. msg = ('Number of rows of kernel must divide batch size. Found n1 = {} ' 'and batch size = {}.').format(n1, n1_batch_size) if device_count > 1: @@ -672,7 +676,7 @@ def _get_n_batches_and_batch_sizes( n2_batches, ragged = divmod(n2, n2_batch_size) if ragged: - # TODO(schsam): Relax this constraint. + # TODO: Relax this constraint. raise ValueError(('Number of columns of kernel must divide batch ' 'size. Found n2 = {} ' 'and batch size = {}').format(n2, n2_batch_size)) @@ -682,7 +686,7 @@ def _get_n_batches_and_batch_sizes( def _is_np_ndarray(x) -> bool: if x is None: return False - return tree_all(tree_map( + return tree_all(jax.tree.map( lambda y: isinstance(y, (np.ndarray, jnp.ndarray)), x)) @@ -719,7 +723,7 @@ def jit_or_pmap_broadcast(f: Callable, device_count: int = -1) -> Callable: Raises: An error if `kwargs` have a `jnp.ndarray`. - TODO(romann): treat `jnp.ndarray`s in `kwargs` when JAX allows it. See + TODO: treat `jnp.ndarray`s in `kwargs` when JAX allows it. See https://github.com/google/jax/issues/912 """ key = (f, device_count) @@ -727,18 +731,18 @@ def jit_or_pmap_broadcast(f: Callable, device_count: int = -1) -> Callable: if device_count == -1: device_count = jax.local_device_count() - # TODO(romann): adapt this when JAX allows `axis_in` for `pmap`. + # TODO: adapt this when JAX allows `axis_in` for `pmap`. def broadcast(arg: jnp.ndarray) -> jnp.ndarray: if device_count == 0: return arg return jnp.broadcast_to(arg, (device_count,) + arg.shape) @utils.wraps(f) - def f_pmapped(x_or_kernel: Union[jnp.ndarray, Kernel], *args, **kwargs): + def f_pmapped(x_or_kernel: jnp.ndarray | Kernel, *args, **kwargs): args_np, args_np_idxs = [], [] args_other = {} - # TODO(romann): treat `jnp.ndarray`s in `kwargs` when JAX allows it. + # TODO: treat `jnp.ndarray`s in `kwargs` when JAX allows it. # https://github.com/google/jax/issues/912 # Filter out `jnp.ndarray`s from other arguments. 
for i, arg in enumerate(args): @@ -776,7 +780,7 @@ def _f(_x_or_kernel, *_args_np, **_kwargs_np): cache[_key] = _f # Broadcast `jnp.ndarray` arguments and apply the new function to them. - args_np = tree_map(broadcast, args_np) + args_np = jax.tree.map(broadcast, args_np) return _f(x_or_kernel, *args_np, **kwargs_np) return f_pmapped diff --git a/neural_tangents/_src/empirical.py b/neural_tangents/_src/empirical.py index 1c41abf6..18241547 100644 --- a/neural_tangents/_src/empirical.py +++ b/neural_tangents/_src/empirical.py @@ -104,7 +104,7 @@ import enum import functools import operator -from typing import Callable, Iterable, KeysView, Optional, TypeVar, Union +from typing import Callable, Iterable, KeysView, TypeVar import warnings import jax @@ -134,7 +134,6 @@ import jax.numpy as jnp from jax.tree_util import tree_flatten -from jax.tree_util import tree_map from jax.tree_util import tree_reduce from jax.tree_util import tree_structure from jax.tree_util import tree_transpose @@ -150,6 +149,7 @@ from .utils.typing import Axes from .utils.typing import EmpiricalGetKernelFn from .utils.typing import EmpiricalKernelFn +from .utils.typing import Get from .utils.typing import PyTree from .utils.typing import VMapAxes from .utils.typing import VMapAxisTriple @@ -243,7 +243,7 @@ def f_tayl(p, *args, **kwargs): def empirical_nngp_fn( f: ApplyFn, trace_axes: Axes = (-1,), - diagonal_axes: Axes = () + diagonal_axes: Axes = (), ) -> EmpiricalKernelFn: """Returns a function to draw a single sample the NNGP of a given network `f`. @@ -260,7 +260,7 @@ def empirical_nngp_fn( `diagonal_axes=()` to obtain the kernel exactly of shape `zip(f(x1).shape, f(x2).shape)`. - For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principal + For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principle the empirical kernels will have terms measuring the covariance between the outputs. Here, we ignore these cross-terms and consider each output separately. Please raise an issue if this feature is important to you. @@ -306,7 +306,7 @@ def empirical_nngp_fn( """ def nngp_fn( x1: PyTree, - x2: Optional[PyTree], + x2: PyTree | None, params: PyTree, **apply_fn_kwargs ) -> PyTree: @@ -352,7 +352,7 @@ def contract(out1: jnp.ndarray, out2: jnp.ndarray) -> jnp.ndarray: dot = _dot_general(out1, out2, trace_axes, diagonal_axes) return dot / utils.size_at(out1, trace_axes) - return tree_map(contract, out1, out2) + return jax.tree.map(contract, out1, out2) return nngp_fn @@ -371,7 +371,7 @@ class NtkImplementation(enum.IntEnum): (or `0`) evaluates FLOPs of all other methods at compilation time, and selects the fastest method. However, at the time it only works correctly on TPUs, and on CPU/GPU can return wrong results, which is why - it is not the default. TODO(romann): revisit based on http://b/202218145. + it is not the default. TODO: revisit based on http://b/202218145. JACOBIAN_CONTRACTION: (or `1`) computes the NTK as the outer product of two Jacobians, each @@ -462,7 +462,7 @@ class NtkImplementation(enum.IntEnum): """ -_DEFAULT_NTK_FWD: Optional[bool] = None +_DEFAULT_NTK_FWD: bool | None = None """Says whether to use forward mode in `STRUCTURED_DERIVATIVES` (`3`) Jacobians. Useful for debugging and testing, but for best performance should be set to @@ -476,15 +476,15 @@ def _empirical_auto_ntk_fn(**kwargs) -> EmpiricalGetKernelFn: Returns wrong FLOPS on CPU and GPU when JITting. - TODO(romann): revisit based on http://b/202218145. + TODO: revisit based on http://b/202218145. 
""" cache = {} def ntk_fn( x1: PyTree, - x2: Optional[PyTree], + x2: PyTree | None, params: PyTree, - **apply_fn_kwargs + **apply_fn_kwargs, ) -> jnp.ndarray: """Computes a single sample of the automatic empirical NTK. @@ -515,7 +515,7 @@ def ntk_fn( 2) `diagonal_axes` are present only once. All other axes are present twice. """ - shapes = tree_map(jnp.shape, (x1, x2, params, apply_fn_kwargs)) + shapes = jax.tree.map(jnp.shape, (x1, x2, params, apply_fn_kwargs)) shapes = _to_tuple_tree(shapes) if shapes not in cache: @@ -544,7 +544,7 @@ def _jacobian_contraction_ntk_fn( trace_axes: Axes, diagonal_axes: Axes, vmap_axes: VMapAxes, - **kwargs + **kwargs, ) -> EmpiricalKernelFn: """Compute NTK by directly instantiating Jacobians and contracting.""" @@ -560,13 +560,13 @@ def contract(x, y): contract_axes = _trace_axes + param_axes return _dot_general(x, y, contract_axes, _diagonal_axes) / size - return tree_reduce(operator.add, tree_map(contract, j1, j2)) + return tree_reduce(operator.add, jax.tree.map(contract, j1, j2)) def ntk_fn( x1: PyTree, - x2: Optional[PyTree], + x2: PyTree | None, params: PyTree, - **apply_fn_kwargs + **apply_fn_kwargs, ) -> jnp.ndarray: """Computes a single sample of the empirical NTK (jacobian outer product). @@ -612,7 +612,7 @@ def j_fn(x, *args): j1 = j_fn(x1, *args1) j2 = j_fn(x2, *args2) if not utils.all_none(x2) else j1 - ntk = tree_map(sum_and_contract, fx1, j1, j2) + ntk = jax.tree.map(sum_and_contract, fx1, j1, j2) return ntk return ntk_fn @@ -623,13 +623,13 @@ def _ntk_vector_products_ntk_fn( trace_axes: Axes, diagonal_axes: Axes, vmap_axes: VMapAxes, - **kwargs + **kwargs, ) -> EmpiricalKernelFn: """Compute NTK via NTK-vector products.""" def ntk_fn( x1: PyTree, - x2: Optional[PyTree], + x2: PyTree | None, params: PyTree, **apply_fn_kwargs ) -> jnp.ndarray: @@ -676,7 +676,7 @@ def delta_vjp(delta): fx1, fx2 = eval_shape(f1, params), eval_shape(f2, params) eye = _std_basis(fx1) ntk = vmap(linear_transpose(delta_vjp_jvp, fx2))(eye) - ntk = tree_map(lambda fx12: _unravel_array_into_pytree(fx1, 0, fx12), ntk) + ntk = jax.tree.map(lambda fx12: _unravel_array_into_pytree(fx1, 0, fx12), ntk) ntk = _diagonal(ntk, fx1) return ntk @@ -694,8 +694,10 @@ def delta_vjp(delta): _add(fx_axis, _ndim(fx1))) ntk = get_ntk(x1, x2, *args1, *args2) - ntk = tree_map(lambda x: _trace_and_diagonal(x, trace_axes, diagonal_axes), - ntk) + ntk = jax.tree.map( + lambda x: _trace_and_diagonal(x, trace_axes, diagonal_axes), + ntk, + ) return ntk return ntk_fn @@ -708,7 +710,7 @@ def _structured_derivatives_ntk_fn( vmap_axes: VMapAxes, _j_rules: bool, _s_rules: bool, - _fwd: Optional[bool] + _fwd: bool | None, ) -> EmpiricalKernelFn: """Compute NTK by using structured derivatives.""" @@ -716,8 +718,8 @@ def sum_and_contract( fx1: jnp.ndarray, fx2: jnp.ndarray, fx_axis, - df_dys_1: list[Union[jnp.ndarray, Zero]], - df_dys_2: list[Union[jnp.ndarray, Zero]], + df_dys_1: list[jnp.ndarray | Zero], + df_dys_2: list[jnp.ndarray | Zero], dy_dws_1: list[tuple[jnp.ndarray, rules.Structure]], dy_dws_2: list[tuple[jnp.ndarray, rules.Structure]], dtype: jnp.dtype @@ -782,7 +784,7 @@ def contract(df_dys_1, df_dys_2, dy_dws_1, dy_dws_2): for i, (id_1, id_2) in enumerate(zip(s1.out_diagonal, s2.out_diagonal)): - # TODO(romann): compute based on array dimensions. + # TODO: compute based on array dimensions. axis_shift = -100_000 # Huge axis shift to ensure unique axis ids. 
axis_id = (-axis_shift -df_dy_1.ndim - df_dy_2.ndim - dy_dw_1.ndim @@ -826,7 +828,7 @@ def contract(df_dys_1, df_dys_2, dy_dws_1, dy_dws_2): ntk = tree_reduce( operator.add, - tree_map( + jax.tree.map( contract, df_dys_1, df_dys_2, dy_dws_1, dy_dws_2, is_leaf= @@ -841,9 +843,9 @@ def contract(df_dys_1, df_dys_2, dy_dws_1, dy_dws_2): def ntk_fn( x1: PyTree, - x2: Optional[PyTree], + x2: PyTree | None, params: PyTree, - **apply_fn_kwargs + **apply_fn_kwargs, ) -> jnp.ndarray: """Computes a single sample of the structured derivatives NTK. @@ -894,7 +896,7 @@ def j_fn(x, *args): df_dys_1, dy_dws_1) fx_axis, dtype = _get_fx_axis_and_dtype(fx1, fx_axis, params) - ntk = tree_map( + ntk = jax.tree.map( functools.partial( sum_and_contract, dy_dws_1=dy_dws_1, @@ -925,10 +927,10 @@ def empirical_ntk_fn( trace_axes: Axes = (-1,), diagonal_axes: Axes = (), vmap_axes: VMapAxes = None, - implementation: Union[NtkImplementation, int] = DEFAULT_NTK_IMPLEMENTATION, + implementation: NtkImplementation | int = DEFAULT_NTK_IMPLEMENTATION, _j_rules: bool = _DEFAULT_NTK_J_RULES, _s_rules: bool = _DEFAULT_NTK_S_RULES, - _fwd: Optional[bool] = _DEFAULT_NTK_FWD, + _fwd: bool | None = _DEFAULT_NTK_FWD, ) -> EmpiricalKernelFn: r"""Returns a function to draw a single sample the NTK of a given network `f`. @@ -952,7 +954,7 @@ def empirical_ntk_fn( `diagonal_axes=()` to obtain the kernel exactly of shape `zip(f(x1).shape, f(x2).shape)`. - For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principal + For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principle the empirical kernels will have terms measuring the covariance between the outputs. Here, we ignore these cross-terms and consider each output separately. Please raise an issue if this feature is important to you. @@ -1076,10 +1078,10 @@ def empirical_kernel_fn( trace_axes: Axes = (-1,), diagonal_axes: Axes = (), vmap_axes: VMapAxes = None, - implementation: Union[NtkImplementation, int] = DEFAULT_NTK_IMPLEMENTATION, + implementation: NtkImplementation | int = DEFAULT_NTK_IMPLEMENTATION, _j_rules: bool = _DEFAULT_NTK_J_RULES, _s_rules: bool = _DEFAULT_NTK_S_RULES, - _fwd: Optional[bool] = _DEFAULT_NTK_FWD, + _fwd: bool | None = _DEFAULT_NTK_FWD, ) -> EmpiricalGetKernelFn: r"""Returns a function that computes single draws from NNGP and NT kernels. @@ -1093,7 +1095,7 @@ def empirical_kernel_fn( `diagonal_axes=()` to obtain the kernel exactly of shape `zip(f(x1).shape, f(x2).shape)`. - For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principal + For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principle the empirical kernels will have terms measuring the covariance between the outputs. Here, we ignore these cross-terms and consider each output separately. Please raise an issue if this feature is important to you. @@ -1224,10 +1226,10 @@ def empirical_kernel_fn( @utils.get_namedtuple('EmpiricalKernel') def kernel_fn( x1: PyTree, - x2: Optional[PyTree], - get: Union[None, str, tuple[str, ...]], + x2: PyTree | None, + get: Get, params: PyTree, - **apply_fn_kwargs + **apply_fn_kwargs, ) -> PyTree: """Computes a single sample of the empirical kernel of type `get`. @@ -1284,9 +1286,9 @@ def kernel_fn( def empirical_ntk_vp_fn( f: ApplyFn, x1: PyTree, - x2: Optional[PyTree], + x2: PyTree | None, params: PyTree, - **apply_fn_kwargs + **apply_fn_kwargs, ) -> Callable[[PyTree], PyTree]: """Returns an NTK-vector product function. 
@@ -1388,7 +1390,7 @@ def ntk_vp_fn(cotangents: PyTree) -> PyTree: def _trace_and_diagonal( ntk: jnp.ndarray, trace_axes: Axes, - diagonal_axes: Axes + diagonal_axes: Axes, ) -> jnp.ndarray: """Extract traces and diagonals along respective pairs of axes from the `ntk`. @@ -1441,12 +1443,12 @@ def _trace_and_diagonal( def _dict_of_tree_to_tree_of_dict( out_dict: dict[str, PyTree], - get: tuple[str, ...] + get: tuple[str, ...], ) -> PyTree: # If the elements of an output dict are tuples then change the representation # to be a tuple of dicts instead. This occurs when the output of a network is # a parallel layer. - return tree_map(lambda *x: dict((g, v) for g, v in zip(get, x)), + return jax.tree.map(lambda *x: dict((g, v) for g, v in zip(get, x)), *[out_dict[g] for g in get]) @@ -1456,7 +1458,7 @@ def _get_f_params( x_axis: PyTree, fx_axis: PyTree, kw_axes: dict[str, PyTree], - **apply_fn_kwargs + **apply_fn_kwargs, ) -> Callable[[PyTree], PyTree]: x = _expand_dims(x, x_axis) @@ -1478,7 +1480,7 @@ def _get_args( params: PyTree, vmap_axes: VMapAxes, x1: PyTree, - x2: PyTree + x2: PyTree, ): kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2) @@ -1501,7 +1503,7 @@ def _get_f1_f2( kw_axes: dict[str, PyTree], args: tuple, x1: PyTree, - x2: Optional[PyTree] + x2: PyTree | None, ) -> tuple[Callable[[PyTree], PyTree], Callable[[PyTree], PyTree]]: args1, args2 = args[:len(args) // 2], args[len(args) // 2:] _kwargs1 = {k: v for k, v in zip(keys, args1)} @@ -1517,7 +1519,7 @@ def _get_f1_f2( def _check_einsum_no_broadcast( arrays: list[jnp.ndarray], - dims: list[list[int]] + dims: list[list[int]], ): """Check that all matching einsum contracting axis sizes are equal. @@ -1555,35 +1557,35 @@ def expand(x: jnp.ndarray) -> jnp.ndarray: def _expand_dims( - x: Union[None, PyTree, UndefinedPrimal], - axis: Optional[PyTree] -) -> Optional[PyTree]: + x: None | PyTree | UndefinedPrimal, + axis: PyTree | None, +) -> PyTree | None: if axis is None or x is None or isinstance(x, UndefinedPrimal): return x - return tree_map(_expand_dims_array, x, axis) + return jax.tree.map(_expand_dims_array, x, axis) -def _add(x: Optional[PyTree], y: Optional[PyTree]) -> Optional[PyTree]: +def _add(x: PyTree | None, y: PyTree | None) -> PyTree | None: if x is None or y is None: return None - return tree_map(operator.add, x, y) + return jax.tree.map(operator.add, x, y) def _sub(x: PyTree, y: PyTree) -> PyTree: - return tree_map(operator.sub, x, y) + return jax.tree.map(operator.sub, x, y) def _div(x: PyTree, y: int) -> PyTree: - return tree_map(lambda x: x / y, x) + return jax.tree.map(lambda x: x / y, x) -def _squeeze(x: PyTree, axis: Optional[PyTree]) -> PyTree: +def _squeeze(x: PyTree, axis: PyTree | None) -> PyTree: if axis is None: return x def squeeze( x: jnp.ndarray, - axis: Union[None, int, tuple[int, ...]] + axis: None | int | tuple[int, ...], ) -> jnp.ndarray: """`np.squeeze` analog working with 0-sized axes.""" if isinstance(axis, int): @@ -1606,20 +1608,20 @@ def squeeze( return jnp.squeeze(x, non_zero_axes) - return tree_map(squeeze, x, axis) + return jax.tree.map(squeeze, x, axis) def _ndim(x: PyTree) -> PyTree: - return tree_map(lambda x: x.ndim, x) + return jax.tree.map(lambda x: x.ndim, x) def _mod( - x: Optional[PyTree], - y: PyTree + x: PyTree | None, + y: PyTree, ) -> PyTree: if x is None: return None - return tree_map(operator.mod, x, y) + return jax.tree.map(operator.mod, x, y) def _diagonal(ntk: PyTree, fx: PyTree) -> PyTree: @@ -1631,10 +1633,10 @@ def _diagonal(ntk: PyTree, fx: PyTree) -> 
PyTree: def _canonicalize_axes( - vmap_axes: Optional[VMapAxes], + vmap_axes: VMapAxes, x: PyTree, fx: PyTree, - **kwargs + **kwargs, ) -> VMapAxisTriple: if isinstance(vmap_axes, tuple) and len(vmap_axes) == 3: x_axis, fx_axis, kw_axes = vmap_axes @@ -1642,13 +1644,13 @@ def _canonicalize_axes( x_axis, fx_axis, kw_axes = vmap_axes, vmap_axes, {} if isinstance(x_axis, int): - x_axis = tree_map(lambda _: x_axis, x) + x_axis = jax.tree.map(lambda _: x_axis, x) if isinstance(fx_axis, int): - fx_axis = tree_map(lambda _: fx_axis, fx) + fx_axis = jax.tree.map(lambda _: fx_axis, fx) if isinstance(kw_axes, int): - kw_axes = tree_map(lambda _: kw_axes, kwargs) + kw_axes = jax.tree.map(lambda _: kw_axes, kwargs) x_axis = _mod(x_axis, _ndim(x)) fx_axis = _mod(fx_axis, _ndim(fx)) @@ -1690,7 +1692,7 @@ def _get_dims( df_dy_2: jnp.ndarray, ndim: int, trace_axes: Axes, - diagonal_axes: Axes + diagonal_axes: Axes, ) -> tuple[list[int], list[int], list[int]]: df_dy_dims_1 = list(range(df_dy_1.ndim)) df_dy_dims_2 = list(range(df_dy_1.ndim, df_dy_1.ndim + df_dy_2.ndim)) @@ -1720,15 +1722,15 @@ def _is_abstract_array(x) -> bool: def _vmap(f: Callable, in_axes, out_axes, squeeze_out: bool = True) -> Callable: """An expand-then-squeeze `vmap` for `f` expecting/returning batch dims.""" - in_axes_plus_1 = tree_map(lambda x: x if x in (None, -1) else x + 1, in_axes) + in_axes_plus_1 = jax.tree.map(lambda x: x if x in (None, -1) else x + 1, in_axes) @utils.wraps(f) def f_vmapped(*args): - args = tree_map( + args = jax.tree.map( _expand_dims, args, in_axes_plus_1, is_leaf=_is_abstract_array) out = vmap(f, in_axes, out_axes)(*args) if squeeze_out: - out_axes_plus_1 = tree_map( + out_axes_plus_1 = jax.tree.map( lambda x: x if x in (None, -1) else x + 1, out_axes) out = _squeeze(out, out_axes_plus_1) return out @@ -1738,9 +1740,9 @@ def f_vmapped(*args): def _get_fx_axis_and_dtype(fx, fx_axis, params: PyTree): if fx_axis is None: - fx_axis = tree_map(lambda x: None, fx) + fx_axis = jax.tree.map(lambda x: None, fx) # Set the default type to be the least common type ancestor. - dtypes, _ = tree_flatten(tree_map(jnp.dtype, params)) + dtypes, _ = tree_flatten(jax.tree.map(jnp.dtype, params)) if not dtypes: dtype = None else: @@ -1749,16 +1751,16 @@ def _get_fx_axis_and_dtype(fx, fx_axis, params: PyTree): def _unravel_dfs(dfs: PyTree, params: PyTree, y: PyTree) -> PyTree: - dfs = tree_map(functools.partial(_unravel_array_into_pytree, y, 0), dfs) + dfs = jax.tree.map(functools.partial(_unravel_array_into_pytree, y, 0), dfs) if tree_structure(dfs).num_leaves > 0: - dfs = tree_transpose(tree_structure(tree_map(lambda x, y: [x] * len(y), + dfs = tree_transpose(tree_structure(jax.tree.map(lambda x, y: [x] * len(y), params, dfs)), tree_structure(y), dfs) if tree_structure(dfs).num_leaves == 0: - dfs = tree_map(lambda x: dfs, y) + dfs = jax.tree.map(lambda x: dfs, y) return dfs @@ -1773,7 +1775,7 @@ def _get_df_dys_and_dy_dws( params: PyTree, _j_rules: bool, _s_rules: bool, - _fwd: Optional[bool] + _fwd: bool | None, ) -> tuple[PyTree, PyTree]: """Computes primitive output cotangents (`df/dy`) and Jacobians (`dy/dw`).""" def primals_out_and_pullback(mode: _MODE) -> PyTree: @@ -1796,8 +1798,8 @@ def _get_primals_out_and_pullback( mode: _MODE, _j_rules: bool, _s_rules: bool, - _fwd: Optional[bool], - *primals_in: PyTree + _fwd: bool | None, + *primals_in: PyTree, ) -> tuple[PyTree, Callable]: """Adapted from `jax.interpreters.ad`. 
@@ -1809,7 +1811,7 @@ def _get_primals_out_and_pullback( fn_flat, out_tree = jax.api_util.flatten_fun_nokwargs( lu.wrap_init(fn), in_tree) - # TODO(romann): handle call primitives more gracefully. + # TODO: handle call primitives more gracefully. with jax.disable_jit(): outs = ad.linearize(fn_flat, *primals_in_flat, has_aux=False) @@ -1836,9 +1838,11 @@ def _backward_pass( cotangents_in: tuple[jnp.ndarray, ...], _j_rules: bool, _s_rules: bool, - _fwd: Optional[bool] -) -> Union[list[list[Union[jnp.ndarray, Zero]]], - list[list[tuple[jnp.ndarray, rules.Structure]]]]: + _fwd: bool | None, +) -> ( + list[list[jnp.ndarray | Zero]] | + list[list[tuple[jnp.ndarray, rules.Structure]]] +): """Similar to and adapted from `jax.interpreters.ad.backward_pass`. Traverses the computational graph in the same order as the above, but collects @@ -1856,7 +1860,7 @@ def _backward_pass( the NTK. """ - def read_cotangent(v: Var) -> Union[jnp.ndarray, Zero]: + def read_cotangent(v: Var) -> jnp.ndarray | Zero: return ct_env.pop(v, Zero(v.aval)) primal_env: dict[Var, jnp.ndarray] = {} @@ -2027,9 +2031,9 @@ def _backprop_step( eqn: JaxprEqn, primal_env: dict[Var, jnp.ndarray], ct_env: dict[Var, jnp.ndarray], - read_cotangent: Callable[[Var], Union[jnp.ndarray, Zero]], - do_write_cotangents: bool = True -) -> tuple[Union[jnp.ndarray, Zero], list[Union[jnp.ndarray, UndefinedPrimal]]]: + read_cotangent: Callable[[Var], jnp.ndarray | Zero], + do_write_cotangents: bool = True, +) -> tuple[jnp.ndarray | Zero, list[jnp.ndarray | UndefinedPrimal]]: """Adapted from `jax.interpreters.ad`.""" invals = map(functools.partial(_read_primal, primal_env), eqn.invars) cts_in = map(read_cotangent, eqn.outvars) @@ -2053,7 +2057,7 @@ def _backprop_step( def _trim_cotangents( cts_in: ShapedArray, - structure: rules.Structure + structure: rules.Structure, ) -> ShapedArray: cts_in = _trim_axis( cts_in, @@ -2063,9 +2067,9 @@ def _trim_cotangents( def _trim_invals( - invals: list[Union[jnp.ndarray, UndefinedPrimal]], + invals: list[jnp.ndarray | UndefinedPrimal], structure: rules.Structure, -) -> list[Union[jnp.ndarray, UndefinedPrimal]]: +) -> list[jnp.ndarray | UndefinedPrimal]: trimmed_invals = list(invals) for i in structure.in_trace_idxs: @@ -2086,14 +2090,14 @@ def _trim_invals( if isinstance(trimmed_invals[i], UndefinedPrimal): trimmed_invals[i] = _trim_axis(trimmed_invals[i], in_d) - return trimmed_invals # pytype: disable=bad-return-type # jax-ndarray + return trimmed_invals def _trim_eqn( eqn: JaxprEqn, idx: int, - trimmed_invals: list[Union[jnp.ndarray, UndefinedPrimal]], - trimmed_cts_in: ShapedArray + trimmed_invals: list[jnp.ndarray | UndefinedPrimal], + trimmed_cts_in: ShapedArray, ) -> JaxprEqn: if eqn.primitive in rules.EQN_PARAMS_RULES: # Copy the equation parameters to modify. @@ -2103,7 +2107,7 @@ def _trim_eqn( params=dict(eqn.params), idx=idx, trimmed_invals=trimmed_invals_e, - trimmed_cts_in=trimmed_cts_in + trimmed_cts_in=trimmed_cts_in, ) eqn = eqn.replace(params=params) @@ -2111,9 +2115,9 @@ def _trim_eqn( def _trim_axis( - x: Union[UndefinedPrimal, ShapedArray, jnp.ndarray], - axis: Union[int, tuple[int, ...]], -) -> Union[UndefinedPrimal, ShapedArray]: + x: UndefinedPrimal | ShapedArray | jnp.ndarray, + axis: int | tuple[int, ...], +) -> UndefinedPrimal | ShapedArray: """Trim `axis` of `x` to be of length `1`.
`x` is only used for shape.""" if isinstance(axis, int): axis = (axis,) @@ -2129,10 +2133,10 @@ def _trim_axis( def _eqn_jvp_fn( - eqn: Optional[JaxprEqn], + eqn: JaxprEqn | None, idx: int, tangents: jnp.ndarray, - *invals + *invals, ) -> jnp.ndarray: """Perform a JVP for `eqn`.""" if eqn is None: @@ -2166,9 +2170,9 @@ def _eqn_jvp_fn( def _eqn_vjp_fn( - eqn: Optional[JaxprEqn], + eqn: JaxprEqn | None, cts_in: jnp.ndarray, - *invals + *invals, ) -> tuple[jnp.ndarray, ...]: """Perform a VJP for `eqn`. Adapted from `jax.interpreters.ad`.""" if eqn is None: @@ -2195,13 +2199,13 @@ def _eqn_vjp_fn( def _get_jacobian( - eqn: Optional[JaxprEqn], + eqn: JaxprEqn | None, cts_in: ShapedArray, - invals: list[Union[jnp.ndarray, UndefinedPrimal]], + invals: list[jnp.ndarray | UndefinedPrimal], idx: int, _j_rules: bool, - _fwd: Optional[bool], -) -> Union[jnp.ndarray, Zero]: + _fwd: bool | None, +) -> jnp.ndarray | Zero: """Get the (structured) `eqn` output Jacobian wrt `eqn.invars[idx]`.""" if eqn is None: primitive = None @@ -2222,7 +2226,7 @@ def _get_jacobian( else: # Vanilla Jacobian evaluation. - if _get_fwd(_fwd, cts_in_shape, inval_shape): # pytype: disable=wrong-arg-types # always-use-return-annotations + if _get_fwd(_fwd, cts_in_shape, inval_shape): # Forward mode. out_axes = -1 inputs = invals[idx].aval @@ -2244,7 +2248,7 @@ def jac_fn(cotangents): else: dy_dw = dy_dw.reshape(dy_dw_shape) - dy_dw_shape_ = dy_dw.aval.shape if isinstance(dy_dw, Zero) else dy_dw.shape # pytype:disable=attribute-error + dy_dw_shape_ = dy_dw.aval.shape if isinstance(dy_dw, Zero) else dy_dw.shape assert dy_dw_shape_ == dy_dw_shape, (dy_dw_shape_, dy_dw_shape) return dy_dw @@ -2253,7 +2257,7 @@ def _write_cotangent( prim: core.Primitive, ct_env: dict[Var, jnp.ndarray], v: Var, - ct: Union[jnp.ndarray, Zero] + ct: jnp.ndarray | Zero, ): """Adapted from `jax.interpreters.ad`.""" assert ct is not Zero, (prim, v.aval) @@ -2274,8 +2278,8 @@ def _write_cotangent( def _read_primal( env: dict[Var, jnp.ndarray], - v: Union[Var, Literal], -) -> Union[jnp.ndarray, UndefinedPrimal]: + v: Var | Literal, +) -> jnp.ndarray | UndefinedPrimal: if type(v) is Literal: return v.val @@ -2289,16 +2293,16 @@ def _read_primal( def _write_primal( env: dict[Var, jnp.ndarray], v: Var, - val: Union[jnp.ndarray, UndefinedPrimal] + val: jnp.ndarray | UndefinedPrimal, ): if not ad.is_undefined_primal(val): - env[v] = val # pytype: disable=container-type-mismatch # jax-ndarray + env[v] = val def _get_fwd( - _fwd: Optional[bool], + _fwd: bool | None, cts_in_shape: tuple[int, ...], - inval_shape: tuple[int, ...] + inval_shape: tuple[int, ...], ) -> bool: if _fwd is None: out_size = np.prod(cts_in_shape) @@ -2328,7 +2332,7 @@ def _std_basis(pytree: PyTree) -> PyTree: def _unravel_array_into_pytree( pytree: PyTree, axis: int, - arr: jnp.ndarray + arr: jnp.ndarray, ) -> PyTree: """Similar to `jax.api._unravel_array_into_pytree` without host-side ops.""" leaves, treedef = tree_flatten(pytree) @@ -2343,7 +2347,7 @@ def _unravel_array_into_pytree( def _get_res_batch_dims( contracting_dims: Iterable[int], - batch_dims: Iterable[int] + batch_dims: Iterable[int], ) -> list[int]: res_batch_dims = [2 * b - i for i, b in enumerate(batch_dims)] for i, b in enumerate(batch_dims): @@ -2358,7 +2362,7 @@ def _dot_general( rhs: jnp.ndarray, contracting_dims: Axes, batch_dims: Axes, - precision=None + precision=None, ) -> jnp.ndarray: """`jax.lax.dot_general` with preserved dims order and shared lhs / rhs dims. 
diff --git a/neural_tangents/_src/monte_carlo.py b/neural_tangents/_src/monte_carlo.py index c933980b..c1e4a57a 100644 --- a/neural_tangents/_src/monte_carlo.py +++ b/neural_tangents/_src/monte_carlo.py @@ -28,12 +28,11 @@ from functools import partial import operator -from typing import Generator, Iterable, Optional, Union +from typing import Generator, Iterable import jax from jax import random import jax.numpy as jnp -from jax.tree_util import tree_map from .batching import batch @@ -61,20 +60,23 @@ def _sample_once_kernel_fn( init_fn: InitFn, batch_size: int = 0, device_count: int = -1, - store_on_device: bool = True + store_on_device: bool = True, ): - @partial(batch, - batch_size=batch_size, - device_count=device_count, - store_on_device=store_on_device) + @partial( + batch, + batch_size=batch_size, + device_count=device_count, + store_on_device=store_on_device, + ) def kernel_fn_sample_once( x1: NTTree[jnp.ndarray], - x2: Optional[NTTree[jnp.ndarray]], + x2: NTTree[jnp.ndarray] | None, key: jax.Array, get: Get, - **apply_fn_kwargs): + **apply_fn_kwargs, + ): init_key, dropout_key = random.split(key, 2) - shape = tree_map(lambda x: x.shape, x1) + shape = jax.tree.map(lambda x: x.shape, x1) _, params = init_fn(init_key, shape) return kernel_fn(x1, x2, get, params, rng=dropout_key, **apply_fn_kwargs) return kernel_fn_sample_once @@ -84,15 +86,17 @@ def _sample_many_kernel_fn( kernel_fn_sample_once, key: jax.Array, n_samples: set[int], - get_generator: bool): + get_generator: bool, +): def normalize(sample: PyTree, n: int) -> PyTree: - return tree_map(lambda sample: sample / n, sample) + return jax.tree.map(lambda sample: sample / n, sample) def get_samples( x1: NTTree[jnp.ndarray], - x2: Optional[NTTree[jnp.ndarray]], + x2: NTTree[jnp.ndarray] | None, get: Get, - **apply_fn_kwargs): + **apply_fn_kwargs, + ): _key = key ker_sampled = None for n in range(1, max(n_samples) + 1): @@ -101,7 +105,7 @@ def get_samples( if ker_sampled is None: ker_sampled = one_sample else: - ker_sampled = tree_map(operator.add, ker_sampled, one_sample) + ker_sampled = jax.tree.map(operator.add, ker_sampled, one_sample) yield n, ker_sampled if get_generator: @@ -109,9 +113,9 @@ def get_samples( def get_sampled_kernel( x1: jnp.ndarray, x2: jnp.ndarray, - get: Optional[Get] = None, - **apply_fn_kwargs - ) -> Generator[Union[jnp.ndarray, tuple[jnp.ndarray, ...]], None, None]: + get: Get = None, + **apply_fn_kwargs, + ) -> Generator[jnp.ndarray | tuple[jnp.ndarray, ...], None, None]: for n, sample in get_samples(x1, x2, get, **apply_fn_kwargs): if n in n_samples: yield normalize(sample, n) @@ -120,9 +124,9 @@ def get_sampled_kernel( def get_sampled_kernel( x1: jnp.ndarray, x2: jnp.ndarray, - get: Optional[Get] = None, - **apply_fn_kwargs - ) -> Union[jnp.ndarray, tuple[jnp.ndarray, ...]]: + get: Get = None, + **apply_fn_kwargs, + ) -> jnp.ndarray | tuple[jnp.ndarray, ...]: for n, sample in get_samples(x1, x2, get, **apply_fn_kwargs): pass return normalize(sample, n) @@ -134,17 +138,17 @@ def monte_carlo_kernel_fn( init_fn: InitFn, apply_fn: ApplyFn, key: jax.Array, - n_samples: Union[int, Iterable[int]], + n_samples: int | Iterable[int], batch_size: int = 0, device_count: int = -1, store_on_device: bool = True, trace_axes: Axes = (-1,), diagonal_axes: Axes = (), - vmap_axes: Optional[VMapAxes] = None, - implementation: Union[int, NtkImplementation] = DEFAULT_NTK_IMPLEMENTATION, + vmap_axes: VMapAxes = None, + implementation: int | NtkImplementation = DEFAULT_NTK_IMPLEMENTATION, _j_rules: bool = 
_DEFAULT_NTK_J_RULES, _s_rules: bool = _DEFAULT_NTK_S_RULES, - _fwd: Optional[bool] = _DEFAULT_NTK_FWD, + _fwd: bool | None = _DEFAULT_NTK_FWD, ) -> MonteCarloKernelFn: r"""Return a Monte Carlo sampler of NTK and NNGP kernels of a given function. @@ -156,12 +160,12 @@ def monte_carlo_kernel_fn( Args: init_fn: a function initializing parameters of the neural network. From - :obj:`jax.example_libraries.stax`: "takes an rng key and an input shape + :obj:`jax.example_libraries.stax`: "takes a rng key and an input shape and returns an `(output_shape, params)` pair". apply_fn: a function computing the output of the neural network. - From :obj:`jax.example_libraries.stax`: "takes params, inputs, and an + From :obj:`jax.example_libraries.stax`: "takes params, inputs, and a rng key and applies the layer". key: @@ -317,7 +321,7 @@ def monte_carlo_kernel_fn( implementation=implementation, _s_rules=_s_rules, _j_rules=_j_rules, - _fwd=_fwd + _fwd=_fwd, ) kernel_fn = empirical_kernel_fn(**kwargs) @@ -327,7 +331,7 @@ def monte_carlo_kernel_fn( init_fn=init_fn, batch_size=batch_size, device_count=device_count, - store_on_device=store_on_device + store_on_device=store_on_device, ) n_samples, get_generator = _canonicalize_n_samples(n_samples) @@ -335,13 +339,14 @@ def monte_carlo_kernel_fn( kernel_fn_sample_once=kernel_fn_sample_once, key=key, n_samples=n_samples, - get_generator=get_generator + get_generator=get_generator, ) return kernel_fn def _canonicalize_n_samples( - n_samples: Union[int, Iterable[int]]) -> tuple[set[int], bool]: + n_samples: int | Iterable[int], +) -> tuple[set[int], bool]: get_generator = True if isinstance(n_samples, int): get_generator = False diff --git a/neural_tangents/_src/predict.py b/neural_tangents/_src/predict.py index d10f0950..9e12f257 100644 --- a/neural_tangents/_src/predict.py +++ b/neural_tangents/_src/predict.py @@ -29,7 +29,7 @@ import collections from functools import lru_cache -from typing import Any, Callable, Generator, Iterable, NamedTuple, Optional, Protocol, Union +from typing import Any, Callable, Generator, Iterable, NamedTuple, Protocol import jax from jax import grad @@ -37,7 +37,6 @@ import jax.numpy as jnp import jax.scipy as jsp from jax.tree_util import tree_all -from jax.tree_util import tree_map import numpy as np import scipy as sp @@ -51,7 +50,7 @@ PyTree = Any -ArrayOrScalar = Union[None, int, float, jnp.ndarray] +ArrayOrScalar = None | int | float | jnp.ndarray """Alias for optional arrays or scalars.""" @@ -60,11 +59,11 @@ class PredictFn(Protocol): def __call__( self, - t: Optional[ArrayOrScalar] = None, + t: ArrayOrScalar = None, fx_train_0: ArrayOrScalar = 0., - fx_test_0: Optional[ArrayOrScalar] = None, - k_test_train: Optional[jnp.ndarray] = None - ) -> Union[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray]]: + fx_test_0: ArrayOrScalar = None, + k_test_train: jnp.ndarray | None = None, + ) -> jnp.ndarray | tuple[jnp.ndarray, jnp.ndarray]: ... @@ -74,7 +73,7 @@ def gradient_descent_mse( learning_rate: float = 1., diag_reg: float = 0., diag_reg_absolute_scale: bool = False, - trace_axes: Axes = (-1,) + trace_axes: Axes = (-1,), ) -> PredictFn: r"""Predicts the outcome of function space gradient descent training on MSE.
@@ -229,11 +228,11 @@ return predict_fn_finite def predict_fn( - t: Optional[ArrayOrScalar] = None, + t: ArrayOrScalar = None, fx_train_0: ArrayOrScalar = 0., - fx_test_0: Optional[ArrayOrScalar] = None, - k_test_train: Optional[jnp.ndarray] = None - ) -> Union[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray]]: + fx_test_0: ArrayOrScalar = None, + k_test_train: jnp.ndarray | None = None, + ) -> jnp.ndarray | tuple[jnp.ndarray, jnp.ndarray]: """Return output predictions on train [and test] set[s] at time[s] `t`. Args: @@ -295,10 +294,10 @@ class ODEState: qx_test: test set auxiliary state variable (e.g. momentum). """ - fx_train: Optional[jnp.ndarray] = None - fx_test: Optional[jnp.ndarray] = None - qx_train: Optional[jnp.ndarray] = None - qx_test: Optional[jnp.ndarray] = None + fx_train: jnp.ndarray | None = None + fx_test: jnp.ndarray | None = None + qx_train: jnp.ndarray | None = None + qx_test: jnp.ndarray | None = None class PredictFnODE(Protocol): @@ -306,11 +305,11 @@ class PredictFnODE(Protocol): def __call__( self, - t: Optional[ArrayOrScalar] = None, - fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0., - fx_test_0: Optional[ArrayOrScalar] = None, - k_test_train: Optional[jnp.ndarray] = None - ) -> Union[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray], ODEState]: + t: ArrayOrScalar = None, + fx_train_or_state_0: ArrayOrScalar | ODEState = 0., + fx_test_0: ArrayOrScalar = None, + k_test_train: jnp.ndarray | None = None, + ) -> jnp.ndarray | tuple[jnp.ndarray, jnp.ndarray] | ODEState: ... @@ -319,8 +318,8 @@ def gradient_descent( k_train_train: jnp.ndarray, y_train: jnp.ndarray, learning_rate: float = 1., - momentum: Optional[float] = None, - trace_axes: Axes = (-1,) + momentum: float | None = None, + trace_axes: Axes = (-1,), ) -> PredictFnODE: r"""Predicts the outcome of function space training using gradient descent. @@ -467,11 +466,11 @@ def dstate_dt(state_t: ODEState, unused_t) -> ODEState: return dstate_dt def predict_fn( - t: Optional[ArrayOrScalar] = None, - fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0., - fx_test_0: Optional[ArrayOrScalar] = None, - k_test_train: Optional[jnp.ndarray] = None - ) -> Union[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray], ODEState]: + t: ArrayOrScalar = None, + fx_train_or_state_0: ArrayOrScalar | ODEState = 0., + fx_test_0: ArrayOrScalar = None, + k_test_train: jnp.ndarray | None = None, + ) -> jnp.ndarray | tuple[jnp.ndarray, jnp.ndarray] | ODEState: """Return output predictions on train [and test] set[s] at time[s] `t`. Args: @@ -530,7 +529,7 @@ def predict_fn( # Remove the added `t0`. trim = lambda x: x[1:].reshape(t_shape + x.shape[1:]) - trim_tree = lambda tree: tree_map(trim, tree) + trim_tree = lambda tree: jax.tree.map(trim, tree) state_t = trim_tree(state_t) # `ODEState` -> `ODEState` @@ -568,7 +567,8 @@ def gp_inference( y_train: jnp.ndarray, diag_reg: float = 0., diag_reg_absolute_scale: bool = False, - trace_axes: Axes = (-1,)): + trace_axes: Axes = (-1,), +): r"""Compute the mean and variance of the 'posterior' of NNGP/NTK/NTKGP. NNGP - the exact posterior of an infinitely wide Bayesian NN. NTK - exact @@ -636,10 +636,10 @@ def k_inv_y(g: str): @utils.get_namedtuple('Gaussians') def predict_fn( - get: Optional[Get] = None, + get: Get = None, k_test_train=None, - k_test_test=None - ) -> dict[str, Union[jnp.ndarray, Gaussian]]: + k_test_test=None, + ) -> dict[str, jnp.ndarray | Gaussian]: """`test`-set posterior given respective covariance matrices.
Args: @@ -758,7 +758,7 @@ def gradient_descent_mse_ensemble( diag_reg: float = 0.0, diag_reg_absolute_scale: bool = False, trace_axes: Axes = (-1,), - **kernel_fn_train_train_kwargs + **kernel_fn_train_train_kwargs, ): r"""Predicts the gaussian embedding induced by gradient descent on MSE loss. @@ -862,7 +862,7 @@ def get_k_train_train(get: tuple[str, ...]) -> _Kernel: if not any(g in k_dd_cache for g in get): k_dd_cache.update( kernel_fn(x_train, None, get, - **kernel_fn_train_train_kwargs)._asdict()) # pytype: disable=attribute-error # jax-ndarray + **kernel_fn_train_train_kwargs)._asdict()) else: for g in get: if g not in k_dd_cache: @@ -889,9 +889,11 @@ def predict_inf(get: Get): return gp_inference(k_dd, y_train, diag_reg, diag_reg_absolute_scale, trace_axes) - def get_kernels(get: Get, x_test: Optional[jnp.ndarray], - compute_cov: bool, - **kernel_fn_test_test_kwargs): + def get_kernels( + get: Get, x_test: jnp.ndarray | None, + compute_cov: bool, + **kernel_fn_test_test_kwargs, + ): get = _get_dependency(get, compute_cov) k_dd = get_k_train_train(get) if x_test is None: @@ -902,7 +904,7 @@ def get_kernels(get: Get, x_test: Optional[jnp.ndarray], args_test, _ = utils.split_kwargs(kernel_fn_test_test_kwargs, x_test) def is_array(x): - return tree_all(tree_map( + return tree_all(jax.tree.map( lambda x: isinstance(x, (np.ndarray, jnp.ndarray)), x)) kwargs_td = dict(kernel_fn_train_train_kwargs) @@ -943,11 +945,11 @@ def is_array(x): @utils.get_namedtuple('Gaussians') def predict_fn( - t: Optional[ArrayOrScalar] = None, - x_test: Optional[jnp.ndarray] = None, - get: Optional[Get] = None, + t: ArrayOrScalar = None, + x_test: jnp.ndarray | None = None, + get: Get = None, compute_cov: bool = False, - **kernel_fn_test_test_kwargs + **kernel_fn_test_test_kwargs, ) -> dict[str, Gaussian]: """Return output mean and covariance on the test set at time[s] `t`. @@ -1095,16 +1097,16 @@ def reshape_cov(cov): else: out[g] = mean - return out # pytype: disable=bad-return-type # jnp-type + return out return predict_fn def max_learning_rate( ntk_train_train: jnp.ndarray, - y_train_size: Optional[int] = None, + y_train_size: int | None = None, momentum=0., - eps: float = 1e-12 + eps: float = 1e-12, ) -> float: r"""Computes the maximal feasible learning rate for infinite width NNs. @@ -1138,12 +1140,13 @@ def max_learning_rate( The maximal feasible learning rate for infinite width NNs. """ ntk_train_train = utils.make_2d(ntk_train_train) - factor = ntk_train_train.shape[0] if y_train_size is None else y_train_size # pytype: disable=attribute-error # jax-ndarray + factor = ntk_train_train.shape[0] if y_train_size is None else y_train_size if _is_on_cpu(ntk_train_train): - max_eva = sp.linalg.eigvalsh(ntk_train_train, - eigvals=(ntk_train_train.shape[0] - 1, # pytype: disable=attribute-error # jax-ndarray - ntk_train_train.shape[0] - 1))[-1] # pytype: disable=attribute-error # jax-ndarray + max_eva = sp.linalg.eigvalsh( + ntk_train_train, + subset_by_index=(ntk_train_train.shape[0] - 1,) * 2, + )[0] else: max_eva = jnp.linalg.eigvalsh(ntk_train_train)[-1] lr = 2 * (1 + momentum) * factor / (max_eva + eps) @@ -1181,7 +1184,7 @@ def _get_fns_in_eigenbasis( k_train_train: jnp.ndarray, diag_reg: float, diag_reg_absolute_scale: bool, - fns: Iterable[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] + fns: Iterable[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]], ) -> Generator[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], None, None]: """Build functions of a matrix in its eigenbasis. 
@@ -1224,7 +1227,7 @@ def new_fn(y_train, t): def _add_diagonal_regularizer( A: jnp.ndarray, diag_reg: float, - diag_reg_absolute_scale: bool + diag_reg_absolute_scale: bool, ) -> jnp.ndarray: dimension = A.shape[0] if not diag_reg_absolute_scale: @@ -1236,7 +1239,7 @@ def _get_cho_solve( A: jnp.ndarray, diag_reg: float, diag_reg_absolute_scale: bool, - lower: bool = False + lower: bool = False, ) -> Callable[[jnp.ndarray, Axes], jnp.ndarray]: x_non_channel_shape = A.shape[1::2] A = utils.make_2d(A) @@ -1261,7 +1264,7 @@ def cho_solve(b: jnp.ndarray, b_axes: Axes) -> jnp.ndarray: def _get_fx_test_shape( y_train: jnp.ndarray, k_test_train: jnp.ndarray, - y_axes: Axes + y_axes: Axes, ) -> tuple[int, ...]: if k_test_train is None: return y_train.shape @@ -1295,9 +1298,9 @@ def _inv_expm1_fn(evals: jnp.ndarray, t: jnp.ndarray): def _check_inputs( - fx_train_or_state_0: Union[ArrayOrScalar, ODEState], + fx_train_or_state_0: ArrayOrScalar | ODEState, fx_test_0: ArrayOrScalar, - k_test_train: Optional[jnp.ndarray] + k_test_train: jnp.ndarray | None, ): if isinstance(fx_train_or_state_0, ODEState): if fx_test_0 is not None: @@ -1349,7 +1352,7 @@ def _get_attr(k, g: str) -> jnp.ndarray: def _is_on_cpu(x: PyTree) -> bool: def _arr_is_on_cpu(x: jnp.ndarray) -> bool: - # TODO(romann): revisit when https://github.com/google/jax/issues/1431 and + # TODO: revisit when https://github.com/google/jax/issues/1431 and # https://github.com/google/jax/issues/1432 are fixed. if hasattr(x, 'addressable_shards'): # device_buffer is deprecated, so try addressable_shards first. @@ -1362,4 +1365,4 @@ def _arr_is_on_cpu(x: jnp.ndarray) -> bool: raise NotImplementedError(type(x)) - return tree_all(tree_map(_arr_is_on_cpu, x)) + return tree_all(jax.tree.map(_arr_is_on_cpu, x)) diff --git a/neural_tangents/_src/stax/branching.py b/neural_tangents/_src/stax/branching.py index 72fbecbe..7079a762 100644 --- a/neural_tangents/_src/stax/branching.py +++ b/neural_tangents/_src/stax/branching.py @@ -19,7 +19,7 @@ """ import functools -from typing import Callable, Iterable, Optional, Sequence +from typing import Callable, Iterable, Sequence import warnings from jax import numpy as jnp @@ -109,7 +109,7 @@ def kernel_fn(ks: Kernels, **kwargs) -> Kernel: channel_axis=ks[0].channel_axis, mask1=None, mask2=None, - ) # pytype:disable=wrong-keyword-args + ) def mask_fn(mask, input_shape): return _sum_masks(mask) @@ -178,7 +178,7 @@ def _mats_prod(nngps, ntks): channel_axis=ks[0].channel_axis, mask1=None, mask2=None, - ) # pytype:disable=wrong-keyword-args + ) def mask_fn(mask, input_shape): return _sum_masks(mask) @@ -235,8 +235,8 @@ def kernel_fn(ks: Kernels, **kwargs) -> Kernel: 'for the case if all input layers guaranteed to be mean-zero ' 'Gaussian, i.e. having all `is_gaussian` set to `True`.') else: - # TODO(romann): allow nonlinearity after channelwise concatenation. - # TODO(romann): support concatenating different channelwise masks. + # TODO: allow nonlinearity after channelwise concatenation. + # TODO: support concatenating different channelwise masks. 
is_gaussian = False if _axis == batch_axis: @@ -274,22 +274,24 @@ def kernel_fn(ks: Kernels, **kwargs) -> Kernel: ntk = _concat_kernels([k.ntk for k in ks], _axis, False, diagonal_spatial, widths) - return Kernel(cov1=cov1, - cov2=cov2, - nngp=nngp, - ntk=ntk, - x1_is_x2=ks[0].x1_is_x2, - is_gaussian=is_gaussian, - is_reversed=is_reversed, - is_input=ks[0].is_input, - diagonal_batch=diagonal_batch, - diagonal_spatial=diagonal_spatial, - shape1=None, - shape2=None, - batch_axis=batch_axis, - channel_axis=channel_axis, - mask1=None, - mask2=None) # pytype:disable=wrong-keyword-args + return Kernel( + cov1=cov1, + cov2=cov2, + nngp=nngp, + ntk=ntk, + x1_is_x2=ks[0].x1_is_x2, + is_gaussian=is_gaussian, + is_reversed=is_reversed, + is_input=ks[0].is_input, + diagonal_batch=diagonal_batch, + diagonal_spatial=diagonal_spatial, + shape1=None, + shape2=None, + batch_axis=batch_axis, + channel_axis=channel_axis, + mask1=None, + mask2=None, + ) def mask_fn(mask, input_shape): return _concat_masks(mask, input_shape, axis) @@ -304,7 +306,7 @@ def _map_tuples(fn: Callable, tuples: Iterable[tuple]) -> tuple: return tuple(map(fn, zip(*(t for t in tuples)))) -def _sum_masks(masks: list[Optional[jnp.ndarray]]) -> Optional[jnp.ndarray]: +def _sum_masks(masks: list[jnp.ndarray]) -> jnp.ndarray | None: def add_two_masks(mask1, mask2): if mask1 is None: return mask2 @@ -319,10 +321,10 @@ def add_two_masks(mask1, mask2): def _concat_masks( - masks: list[Optional[jnp.ndarray]], + masks: list[jnp.ndarray] | None, input_shapes: Sequence[Sequence[int]], - axis: int -) -> Optional[jnp.ndarray]: + axis: int, +) -> jnp.ndarray | None: """Returns a mask which is a concatenation of `masks`. Since elements of `masks` can have any shapes broadcastable to respective @@ -366,7 +368,7 @@ def _concat_masks( m, max_shape[:axis] + m.shape[axis: axis + 1] + max_shape[axis + 1:]) if m is not None - else jnp.zeros_like(max_shapes[i], dtype=jnp.bool_)) # pytype: disable=wrong-arg-types # jnp-type + else jnp.zeros_like(max_shapes[i], dtype=jnp.bool_)) for i, m in enumerate(masks) ] @@ -412,12 +414,12 @@ def _preprocess_kernels_for_fan_in(ks: Kernels) -> tuple[list[Kernel], bool]: def _concat_kernels( - mats: Sequence[Optional[jnp.ndarray]], + mats: Sequence[jnp.ndarray] | None, axis: int, diagonal_batch: bool, diagonal_spatial: bool, - widths: Sequence[int] -) -> Optional[jnp.ndarray]: + widths: Sequence[int], +) -> jnp.ndarray | None: """Compute the covariance of concatenated activations with given covariances. Args: diff --git a/neural_tangents/_src/stax/combinators.py b/neural_tangents/_src/stax/combinators.py index 1565b164..37b7c940 100644 --- a/neural_tangents/_src/stax/combinators.py +++ b/neural_tangents/_src/stax/combinators.py @@ -203,7 +203,7 @@ def kernel_fn(ks: NTTrees[Kernel], **kwargs) -> NTTrees[Kernel]: def _get_input_req_attr( kernel_fns: list[LayerKernelFn], - fold: Callable[[Diagonal, Diagonal], Diagonal] + fold: Callable[[Diagonal, Diagonal], Diagonal], ) -> dict[str, Any]: """Gets requirements of the combined layer based on individual requirements. diff --git a/neural_tangents/_src/stax/elementwise.py b/neural_tangents/_src/stax/elementwise.py index 58a05a37..245ae831 100644 --- a/neural_tangents/_src/stax/elementwise.py +++ b/neural_tangents/_src/stax/elementwise.py @@ -20,7 +20,7 @@ import functools import operator as op -from typing import Callable, Optional, Sequence +from typing import Callable, Sequence import warnings import jax @@ -49,7 +49,7 @@ def Erf( a: float = 1., b: float = 1., - c: float = 0. 
+ c: float = 0., ) -> InternalLayer: """Affine transform of `Erf` nonlinearity, i.e. `a * Erf(b * x) + c`. @@ -83,8 +83,8 @@ def kernel_fn(k: Kernel) -> Kernel: def nngp_ntk_fn( nngp: jnp.ndarray, prod: jnp.ndarray, - ntk: Optional[jnp.ndarray] = None - ) -> tuple[jnp.ndarray, Optional[jnp.ndarray]]: + ntk: jnp.ndarray | None = None, + ) -> tuple[jnp.ndarray, jnp.ndarray | None]: square_root = _sqrt(prod - 4 * nngp**2) nngp = factor * jnp.arctan2(2 * nngp, square_root) @@ -153,8 +153,8 @@ def nngp_ntk_fn( nngp: jnp.ndarray, prod: jnp.ndarray, sum_: jnp.ndarray, - ntk: Optional[jnp.ndarray] = None - ) -> tuple[jnp.ndarray, Optional[jnp.ndarray]]: + ntk: jnp.ndarray | None = None, + ) -> tuple[jnp.ndarray, jnp.ndarray | None]: diff = 4 * (prod - nngp**2) denom = 2 * sum_ + diff + 1 num = sum_ + diff + 2 * nngp @@ -229,8 +229,8 @@ def nngp_ntk_fn( nngp: jnp.ndarray, prod: jnp.ndarray, prod_plus_1: jnp.ndarray, - ntk: Optional[jnp.ndarray] = None - ) -> tuple[jnp.ndarray, Optional[jnp.ndarray]]: + ntk: jnp.ndarray | None = None, + ) -> tuple[jnp.ndarray, jnp.ndarray | None]: delta_squared = prod_plus_1 - nngp**2 delta = _sqrt(delta_squared) angles = jnp.arctan2(nngp, delta) @@ -277,7 +277,7 @@ def nngp_fn_diag(nngp: jnp.ndarray) -> jnp.ndarray: def Sin( a: float = 1., b: float = 1., - c: float = 0. + c: float = 0., ) -> InternalLayer: """Affine transform of `Sin` nonlinearity, i.e. `a sin(b*x + c)`. @@ -331,7 +331,7 @@ def nngp_fn_diag(nngp): def Cos( a: float = 1., b: float = 1., - c: float = 0. + c: float = 0., ) -> InternalLayer: """Affine transform of `Cos` nonlinearity, i.e. `a cos(b*x + c)`. @@ -405,7 +405,7 @@ def nngp_fn_diag(nngp): def ABRelu( a: float, b: float, - do_stabilize: bool = False + do_stabilize: bool = False, ) -> InternalLayer: """ABReLU nonlinearity, i.e. `a * min(x, 0) + b * max(x, 0)`. @@ -618,8 +618,8 @@ def kernel_fn(k: Kernel) -> Kernel: def nngp_ntk_fn( nngp: jnp.ndarray, prod: jnp.ndarray, - ntk: Optional[jnp.ndarray] = None - ) -> tuple[jnp.ndarray, Optional[jnp.ndarray]]: + ntk: jnp.ndarray | None = None, + ) -> tuple[jnp.ndarray, jnp.ndarray | None]: det = _sqrt((prod - factor * nngp**2)) if ntk is not None: @@ -652,7 +652,7 @@ def nngp_fn_diag(nngp: jnp.ndarray) -> jnp.ndarray: def ExpNormalized( gamma: float = 1, shift: float = -1, - do_clip: bool = False + do_clip: bool = False, ) -> InternalLayer: """Simulates the "Gaussian normalized kernel". 
@@ -779,8 +779,8 @@ def kernel_fn(k: Kernel) -> Kernel: def nngp_ntk_fn( nngp: jnp.ndarray, prod: jnp.ndarray, - ntk: Optional[jnp.ndarray] = None - ) -> tuple[jnp.ndarray, Optional[jnp.ndarray]]: + ntk: jnp.ndarray | None = None, + ) -> tuple[jnp.ndarray, jnp.ndarray | None]: def nngp_fn(nngp: jnp.ndarray, degree: int) -> jnp.ndarray: if degree == -1: @@ -885,8 +885,8 @@ def df(theta: jnp.ndarray) -> jnp.ndarray: def nngp_ntk_fn( nngp: jnp.ndarray, prod: jnp.ndarray, - ntk: Optional[jnp.ndarray] = None - ) -> tuple[jnp.ndarray, Optional[jnp.ndarray]]: + ntk: jnp.ndarray | None = None, + ) -> tuple[jnp.ndarray, jnp.ndarray | None]: sqrt_prod = _sqrt(prod) coeff = sqrt_prod**degree / (2 * jnp.pi) @@ -944,7 +944,7 @@ def fn(x): def kernel_fn(k: Kernel) -> Kernel: cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk - def r(n: Optional[jnp.ndarray], l: int) -> Optional[jnp.ndarray]: + def r(n: jnp.ndarray, l: int) -> jnp.ndarray | None: if n is None: return None @@ -979,8 +979,8 @@ def nngp_ntk_fn( nngp: jnp.ndarray, prod: jnp.ndarray, r_prods: Sequence[jnp.ndarray], - ntk: Optional[jnp.ndarray] = None - ) -> tuple[jnp.ndarray, Optional[jnp.ndarray]]: + ntk: jnp.ndarray | None = None, + ) -> tuple[jnp.ndarray, jnp.ndarray | None]: ratio = nngp / _sqrt(prod) if ntk is not None: @@ -995,8 +995,10 @@ def nngp_ntk_fn( return nngp, ntk - def nngp_fn_diag(nngp: jnp.ndarray, - r_prods: Sequence[jnp.ndarray]) -> jnp.ndarray: + def nngp_fn_diag( + nngp: jnp.ndarray, + r_prods: Sequence[jnp.ndarray], + ) -> jnp.ndarray: out = jnp.zeros_like(nngp) for l in range(degree): out += r_prods[l] @@ -1022,9 +1024,9 @@ def nngp_fn_diag(nngp: jnp.ndarray, @layer @supports_masking(remask_kernel=True) def Elementwise( - fn: Optional[Callable[[float], float]] = None, - nngp_fn: Optional[Callable[[float, float, float], float]] = None, - d_nngp_fn: Optional[Callable[[float, float, float], float]] = None + fn: Callable[[float], float] | None = None, + nngp_fn: Callable[[float, float, float], float] | None = None, + d_nngp_fn: Callable[[float, float, float], float] | None = None, ) -> InternalLayer: """Elementwise application of `fn` using provided `nngp_fn`. @@ -1036,8 +1038,8 @@ def Elementwise( numerical integration or `nt.monte_carlo.monte_carlo_kernel_fn` to use Monte Carlo sampling. - If your function is implemented separately (e.g. `nt.stax.Relu` etc) it's best - to use the custom implementation, since it uses symbolically simplified + If your function is implemented separately (e.g. `nt.stax.Relu` etc.) it's + best to use the custom implementation, since it uses symbolically simplified expressions that are more precise and numerically stable. For details, please see "`Fast Neural Kernel Embeddings for General @@ -1134,7 +1136,7 @@ def kernel_fn(k: Kernel) -> Kernel: def ElementwiseNumerical( fn: Callable[[float], float], deg: int, - df: Optional[Callable[[float], float]] = None + df: Callable[[float], float] | None = None, ) -> InternalLayer: """Activation function using numerical integration. 
@@ -1206,7 +1208,7 @@ def nngp_ntk_fn(nngp, q11, q22, ntk=None): q11, q22 = jnp.expand_dims(q11, xy_axes), jnp.expand_dims(q22, xy_axes) def integrate(f): - fvals = f(_sqrt(2 * q11) * x) * f( # pytype: disable=wrong-arg-types # jnp-type + fvals = f(_sqrt(2 * q11) * x) * f( nngp / _sqrt(q11 / 2, 1e-30) * x + _sqrt( 2*(q22 - nngp**2/q11)) * y) return jnp.tensordot(grid, fvals, (xy_axes, xy_axes)) / jnp.pi @@ -1248,9 +1250,9 @@ def nngp_fn_diag(nngp): def _elementwise( - fn: Optional[Callable[[float], float]], + fn: Callable[[float], float] | None, name: str, - kernel_fn: Optional[LayerKernelFn], + kernel_fn: LayerKernelFn | None, ) -> InternalLayer: init_fn = lambda rng, input_shape: (input_shape, ()) @@ -1291,7 +1293,7 @@ def _sqrt_jvp(tol, primals, tangents): @functools.partial(custom_jvp, nondiff_argnums=(2,)) -def _arctan2(x, y, fill_zero: Optional[float] = None): +def _arctan2(x, y, fill_zero: float | None = None): if fill_zero is not None: return jnp.where(jnp.bitwise_and(x == 0., y == 0.), fill_zero, @@ -1314,9 +1316,9 @@ def _vmap_2d( fn: Callable[[float, float, float], float], cov12: jnp.ndarray, var1: jnp.ndarray, - var2: Optional[jnp.ndarray], + var2: jnp.ndarray | None, diagonal_batch: bool, - diagonal_spatial: bool + diagonal_spatial: bool, ) -> jnp.ndarray: """Effectively a "2D vmap" of `fn(cov12, var1, var2)`. diff --git a/neural_tangents/_src/stax/linear.py b/neural_tangents/_src/stax/linear.py index caa05d3b..5572d776 100644 --- a/neural_tangents/_src/stax/linear.py +++ b/neural_tangents/_src/stax/linear.py @@ -18,7 +18,7 @@ import functools import operator as op import string -from typing import Callable, Iterable, Optional, Sequence, Union +from typing import Callable, Iterable, Sequence import warnings import jax @@ -123,10 +123,10 @@ def Identity() -> InternalLayer: @supports_masking(remask_kernel=False) def DotGeneral( *, - lhs: Optional[Union[jnp.ndarray, float]] = None, - rhs: Optional[Union[jnp.ndarray, float]] = None, + lhs: jnp.ndarray | float | None = None, + rhs: jnp.ndarray | float | None = None, dimension_numbers: lax.DotDimensionNumbers = (((), ()), ((), ())), - precision: Optional[lax.Precision] = None, + precision: lax.Precision | None = None, batch_axis: int = 0, channel_axis: int = -1, ) -> InternalLayerMasked: @@ -134,7 +134,7 @@ def DotGeneral( Dot General allows to express any linear transformation on the inputs, including but not limited to matrix multiplication, pooling, convolutions, - permutations, striding, masking etc (but specialized implementations are + permutations, striding, masking etc. (but specialized implementations are typically much more efficient). Returned `apply_fn` is calling @@ -203,7 +203,7 @@ def DotGeneral( batch_axis: batch axis for `inputs`. Defaults to `0`, the leading axis. Can be present in `dimension_numbers`, but contraction along `batch_axis` will not allow - for further layers to be applied afterwards. + for further layers to be applied afterward. channel_axis: channel axis for `inputs`. Defaults to `-1`, the trailing axis. 
For @@ -258,10 +258,10 @@ def mask_fn(mask, input_shape): @layer @supports_masking(remask_kernel=True) def Aggregate( - aggregate_axis: Optional[Axes] = None, + aggregate_axis: Axes | None = None, batch_axis: int = 0, channel_axis: int = -1, - to_dense: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = lambda p: p, + to_dense: Callable[[jnp.ndarray], jnp.ndarray] | None = lambda p: p, implementation: str = AggregateImplementation.DENSE.value, ) -> InternalLayer: r"""Aggregation operator (graphical neural network). @@ -471,7 +471,7 @@ def get_dimension_numbers(ndim: int) -> lax.DotDimensionNumbers: @functools.partial(vmap, in_axes=(0, None)) def make_indices(index_array, agg_shape): index_array = jnp.moveaxis(index_array, -1, 0) - raveled = jnp.ravel_multi_index(index_array, agg_shape, 'wrap') # pytype: disable=wrong-arg-types # jnp-type + raveled = jnp.ravel_multi_index(index_array, agg_shape, 'wrap') # We mask edges where either sender or receiver is negative. return jnp.where(jnp.all(index_array >= 0, axis=0), raveled, -1) @@ -505,7 +505,7 @@ def get_senders_receivers(pattern, batch_size: int, agg_ndim: int): def apply_fn(params, inputs: jnp.ndarray, *, - pattern: Optional[jnp.ndarray] = None, + pattern: jnp.ndarray | None = None, **kwargs): """Compute the transformed tensors after an aggregation layer. @@ -604,14 +604,17 @@ def pass_messages(s, r, inputs): return out - @requires(batch_axis=batch_axis, - channel_axis=channel_axis, - diagonal_spatial=Diagonal(input=Bool.NO, output=Bool.NO)) - def kernel_fn(k: Kernel, - *, - pattern: tuple[Optional[jnp.ndarray], - Optional[jnp.ndarray]] = (None, None), - **kwargs): + @requires( + batch_axis=batch_axis, + channel_axis=channel_axis, + diagonal_spatial=Diagonal(input=Bool.NO, output=Bool.NO), + ) + def kernel_fn( + k: Kernel, + *, + pattern: tuple[jnp.ndarray | None, jnp.ndarray | None] = (None, None), + **kwargs, + ) -> Kernel: """Compute the transformed kernels after an aggregation kernel layer. Specifically, the `nngp`/`ntk` is a `2N+2`-D tensor of shape @@ -762,7 +765,7 @@ def agg(k, diagonal_batch, s1, r1, s2, r2): def Dense( out_dim: int, W_std: float = 1., - b_std: Optional[float] = None, + b_std: float | None = None, batch_axis: int = 0, channel_axis: int = -1, parameterization: str = 'ntk', @@ -918,12 +921,14 @@ def fc(x): ntk += 1. 
cov1, nngp, cov2 = map(fc, (cov1, nngp, cov2)) - return k.replace(cov1=cov1, - nngp=nngp, - cov2=cov2, - ntk=ntk, - is_gaussian=True, - is_input=False) + return k.replace( + cov1=cov1, + nngp=nngp, + cov2=cov2, + ntk=ntk, + is_gaussian=True, + is_input=False, + ) def mask_fn(mask, input_shape): return jnp.all(mask, axis=channel_axis, keepdims=True) @@ -936,11 +941,11 @@ def mask_fn(mask, input_shape): def Conv( out_chan: int, filter_shape: Sequence[int], - strides: Optional[Sequence[int]] = None, + strides: Sequence[int] | None = None, padding: str = Padding.VALID.name, W_std: float = 1.0, - b_std: Optional[float] = None, - dimension_numbers: Optional[tuple[str, str, str]] = None, + b_std: float | None = None, + dimension_numbers: tuple[str, str, str] | None = None, parameterization: str = 'ntk', s: tuple[int, int] = (1, 1), ) -> InternalLayerMasked: @@ -996,11 +1001,11 @@ def Conv( def ConvTranspose( out_chan: int, filter_shape: Sequence[int], - strides: Optional[Sequence[int]] = None, + strides: Sequence[int] | None = None, padding: str = Padding.VALID.name, W_std: float = 1.0, - b_std: Optional[float] = None, - dimension_numbers: Optional[tuple[str, str, str]] = None, + b_std: float | None = None, + dimension_numbers: tuple[str, str, str] | None = None, parameterization: str = 'ntk', s: tuple[int, int] = (1, 1), ) -> InternalLayerMasked: @@ -1056,11 +1061,11 @@ def ConvTranspose( def ConvLocal( out_chan: int, filter_shape: Sequence[int], - strides: Optional[Sequence[int]] = None, + strides: Sequence[int] | None = None, padding: str = Padding.VALID.name, W_std: float = 1.0, - b_std: Optional[float] = None, - dimension_numbers: Optional[tuple[str, str, str]] = None, + b_std: float | None = None, + dimension_numbers: tuple[str, str, str] | None = None, parameterization: str = 'ntk', s: tuple[int, int] = (1, 1), ) -> InternalLayerMasked: @@ -1116,11 +1121,11 @@ def ConvLocal( def _Conv( out_chan: int, filter_shape: Sequence[int], - strides: Optional[Sequence[int]], + strides: Sequence[int] | None, padding: str, W_std: float, - b_std: Optional[float], - dimension_numbers: Optional[tuple[str, str, str]], + b_std: float | None, + dimension_numbers: tuple[str, str, str] | None, parameterization: str, s: tuple[int, int], transpose: bool, @@ -1417,7 +1422,7 @@ def conv(lhs, batch_ndim): is_input=False) # Reorder output spatial dimensions if the finite layer does so. - # TODO(romann): make more efficient / lazy. + # TODO: make more efficient / lazy. 
out_spec_kernel = tuple(c for c in out_spec if c not in ('N', 'C')) in_to_out_permutation = tuple(out_spec_kernel.index(c) for c in input_spec) res = res.transpose(in_to_out_permutation) @@ -1447,7 +1452,7 @@ def mask_fn(mask, input_shape): else: mask = _pool_mask(mask, filter_shape, strides, init_padding, batch_axis, channel_axis) - mask = jnp.transpose(mask, (out_spec.index(c) for c in lhs_spec)) # pytype: disable=wrong-arg-types # jnp-type + mask = jnp.transpose(mask, (out_spec.index(c) for c in lhs_spec)) return mask @@ -1458,7 +1463,7 @@ def mask_fn(mask, input_shape): @supports_masking(remask_kernel=True) def AvgPool( window_shape: Sequence[int], - strides: Optional[Sequence[int]] = None, + strides: Sequence[int] | None = None, padding: str = Padding.VALID.name, normalize_edges: bool = False, batch_axis: int = 0, @@ -1504,7 +1509,7 @@ def AvgPool( @supports_masking(remask_kernel=True) def SumPool( window_shape: Sequence[int], - strides: Optional[Sequence[int]] = None, + strides: Sequence[int] | None = None, padding: str = Padding.VALID.name, batch_axis: int = 0, channel_axis: int = -1, @@ -1542,7 +1547,7 @@ def SumPool( def _Pool( pool_type: _Pooling, window_shape: Sequence[int], - strides: Optional[Sequence[int]], + strides: Sequence[int] | None, padding: str, normalize_edges: bool, batch_axis: int, @@ -1969,12 +1974,12 @@ def GlobalSelfAttention( W_value_std: float = 1.0, W_query_std: float = 1.0, W_out_std: float = 1.0, - b_std: Optional[float] = None, + b_std: float | None = None, attention_mechanism: str = AttentionMechanism.SOFTMAX.name, pos_emb_type: str = PositionalEmbedding.NONE.name, pos_emb_p_norm: float = 2, - pos_emb_decay_fn: Optional[Callable[[float], float]] = None, - n_chan_pos_emb: Optional[int] = None, + pos_emb_decay_fn: Callable[[float], float] | None = None, + n_chan_pos_emb: int | None = None, W_pos_emb_std: float = 1.0, val_pos_emb: bool = False, batch_axis: int = 0, @@ -2107,7 +2112,7 @@ def GlobalSelfAttention( `W_pos_emb_std`, to keep the total output variance fixed. val_pos_emb: - `True` indicates using positional embeddings when computing all of the + `True` indicates using positional embeddings when computing all the keys/queries/values matrices, `False` makes them only used for keys and queries, but not values. Used only if `pos_emb_type != "NONE"`. @@ -2199,7 +2204,7 @@ def init_fn(rng, input_shape): def apply_fn(params: PyTree, inputs: jnp.ndarray, - mask: Optional[jnp.ndarray] = None, + mask: jnp.ndarray | None = None, **kwargs) -> jnp.ndarray: query_matrices, key_matrices, val_matrices, W_out, b, pos_emb = params @@ -2395,8 +2400,7 @@ def _weigh_kernel(mat, G1, G2=None): mat_dims = (0, -1) + mat_dims res_dims = (0, -1) + res_dims - mat = jnp.einsum(G1, G1_dims, mat, mat_dims, G2, G2_dims, res_dims, # pytype: disable=wrong-arg-types # jnp-type - optimize=True) + mat = jnp.einsum(G1, G1_dims, mat, mat_dims, G2, G2_dims, res_dims) return _affine(mat, OV_std, b_std) G1 = _get_weighting(cov1_interp, k.mask1) @@ -2644,7 +2648,7 @@ def kernel_fn_train(k: Kernel, **kwargs): @supports_masking(remask_kernel=True) def ImageResize( shape: Sequence[int], - method: Union[str, jax.image.ResizeMethod], + method: str | jax.image.ResizeMethod, antialias: bool = True, precision: lax.Precision = lax.Precision.HIGHEST, batch_axis: int = 0, @@ -2852,7 +2856,7 @@ def Index( a `slice` object that would result from indexing an array as `x[idx]`. To create this object, use the helper object :obj:`Slice`, i.e. 
pass `idx=stax.Slice[1:10, :, ::-1]` (which is equivalent to passing an - explicit `idx=(slice(1, 10, None), slice(None), slice(None, None, -1)`. + explicit `idx=(slice(1, 10, None), slice(None), slice(None, None, -1)`). batch_axis: batch axis for `inputs`. Defaults to `0`, the leading axis. @@ -2936,10 +2940,10 @@ def __getitem__(self, idx: utils.SliceType) -> utils.SliceType: def _affine( - mat: Optional[jnp.ndarray], + mat: jnp.ndarray | None, W_std: float, - b_std: Optional[float], -) -> Optional[jnp.ndarray]: + b_std: float | None, +) -> jnp.ndarray | None: """Get covariances of affine outputs if inputs have covariances `nngp`. The output is assumed to be `xW + b`, where `x` is the input, `W` is a matrix @@ -3100,7 +3104,7 @@ def _pool_transpose( def _get_dimension_numbers( n: int, - channels_first: bool = True + channels_first: bool = True, ) -> tuple[str, str, str]: spatial_dims = ''.join(c for c in string.ascii_uppercase if c not in ('N', 'C', 'I', 'O'))[:n] @@ -3113,12 +3117,12 @@ def _get_dimension_numbers( def _conv_kernel_full_spatial_shared( - lhs: Optional[jnp.ndarray], + lhs: jnp.ndarray | None, filter_shape: Sequence[int], strides: Sequence[int], padding: Padding, batch_ndim: int, -) -> Optional[jnp.ndarray]: +) -> jnp.ndarray | None: """Compute covariance of the CNN outputs given inputs with covariance `lhs`. Used when `kernel.diagonal_spatial == False`. @@ -3208,12 +3212,12 @@ def get_n_channels(batch_and_channels: int) -> int: def _conv_kernel_full_spatial_unshared( - lhs: Optional[jnp.ndarray], + lhs: jnp.ndarray | None, filter_shape: Sequence[int], strides: Sequence[int], padding: Padding, batch_ndim: int, -) -> Optional[jnp.ndarray]: +) -> jnp.ndarray | None: """Compute covariance of unshared CNN given inputs with covariance `lhs`. Used when `kernel.diagonal_spatial == False`. Has the same outputs on the @@ -3268,12 +3272,12 @@ def _conv_kernel_full_spatial_unshared( def _conv_kernel_full_spatial_transpose( - lhs: Optional[jnp.ndarray], + lhs: jnp.ndarray | None, filter_shape: Sequence[int], strides: Sequence[int], padding: Padding, batch_ndim: int, -) -> Optional[jnp.ndarray]: +) -> jnp.ndarray | None: """Compute covariance of the CNN transpose given inputs with covariance `lhs`. Used when `kernel.diagonal_spatial == False`. @@ -3379,12 +3383,12 @@ def get_rhs(n_channels: int, filter_size: int) -> jnp.ndarray: def _conv_kernel_diagonal_spatial( - lhs: Optional[jnp.ndarray], + lhs: jnp.ndarray | None, filter_shape: Sequence[int], strides: Sequence[int], padding: Padding, batch_ndim: int, -) -> Optional[jnp.ndarray]: +) -> jnp.ndarray | None: """Compute covariance of the CNN outputs given inputs with covariance `lhs`. Used when `kernel.diagonal_spatial == True`. @@ -3438,12 +3442,12 @@ def _conv_kernel_diagonal_spatial( def _conv_kernel_diagonal_spatial_transpose( - lhs: Optional[jnp.ndarray], + lhs: jnp.ndarray | None, filter_shape: Sequence[int], strides: Sequence[int], padding: Padding, batch_ndim: int, -) -> Optional[jnp.ndarray]: +) -> jnp.ndarray | None: """Compute covariance of the CNN transpose given inputs with covariance `lhs`. Used when `kernel.diagonal_spatial == True`. 
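The `Index` docstring touched earlier in this hunk spells out what the `stax.Slice` helper expands to; the same expansion can be sanity-checked with NumPy's analogous `np.s_` helper (a small illustrative check, not part of the library):

import numpy as np

# `x[1:10, :, ::-1]` corresponds to this explicit tuple of `slice` objects,
# exactly as stated in the docstring.
idx = (slice(1, 10, None), slice(None), slice(None, None, -1))
assert np.s_[1:10, :, ::-1] == idx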
@@ -3615,11 +3619,11 @@ def _diag_mul_diagonal_spatial( def _diag_mul( - x: Optional[jnp.ndarray], + x: jnp.ndarray | None, factor: float, diagonal_batch: bool, diagonal_spatial: bool, -) -> Optional[jnp.ndarray]: +) -> jnp.ndarray | None: if x is None: return x @@ -3633,7 +3637,7 @@ def _vmap_2d( fn: Callable[[float, float, float], float], cov12: jnp.ndarray, var1: jnp.ndarray, - var2: Optional[jnp.ndarray], + var2: jnp.ndarray | None, diagonal_batch: bool, diagonal_spatial: bool, ) -> jnp.ndarray: @@ -3793,7 +3797,7 @@ def apply_fun(params, inputs, **kwargs): def _pos_emb_identity(shape: Sequence[int]) -> jnp.ndarray: - size = utils.size_at(shape) # pytype: disable=wrong-arg-types # jax-ndarray + size = utils.size_at(shape) R = jnp.eye(size).reshape(tuple(shape) * 2) R = utils.zip_axes(R) return R @@ -3801,8 +3805,8 @@ def _pos_emb_identity(shape: Sequence[int]) -> jnp.ndarray: def _pos_emb_pdist( shape: Sequence[int], - pos_emb_p_norm: Optional[float], - pos_emb_decay_fn: Optional[Callable[[float], float]], + pos_emb_p_norm: float | None, + pos_emb_decay_fn: Callable[[float], float] | None, ) -> jnp.ndarray: if pos_emb_decay_fn is None: # Identity / one-hot positional embeddings. @@ -3819,16 +3823,16 @@ def _pos_emb_pdist( (1,) * (2 * (ndim - axis - 1))) R += jnp.abs(pd) ** pos_emb_p_norm - R = pos_emb_decay_fn(R) # pytype: disable=wrong-arg-types # jnp-type - return R # pytype: disable=bad-return-type # jax-ndarray + R = pos_emb_decay_fn(R) + return R def _get_all_pos_emb( k: Kernel, pos_emb_type: PositionalEmbedding, pos_emb_p_norm: float, - pos_emb_decay_fn: Optional[Callable[[float], float]], -) -> tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray]]: + pos_emb_decay_fn: Callable[[float], float] | None, +) -> tuple[jnp.ndarray | None, jnp.ndarray | None, jnp.ndarray | None]: if pos_emb_type == PositionalEmbedding.NONE: return None, None, None diff --git a/neural_tangents/_src/stax/requirements.py b/neural_tangents/_src/stax/requirements.py index f436f2e4..d6f7708c 100644 --- a/neural_tangents/_src/stax/requirements.py +++ b/neural_tangents/_src/stax/requirements.py @@ -16,7 +16,7 @@ import dataclasses import enum -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Sequence import warnings import frozendict @@ -26,7 +26,6 @@ from jax import numpy as jnp from jax.core import ShapedArray from jax.tree_util import tree_all -from jax.tree_util import tree_map import numpy as np from ..utils import dataclasses as nt_dataclasses @@ -112,7 +111,7 @@ def new_kernel_fn(k: NTTree[Kernel], **kwargs) -> NTTree[Kernel]: f'`{key} == {v}`.') elif key in ('batch_axis', 'channel_axis'): - ndim = len(k.shape1) # pytype: disable=attribute-error # preserve-union-macros + ndim = len(k.shape1) v_kernel = getattr(k, key) v_pos = v % ndim if v_kernel != v_pos: @@ -199,37 +198,47 @@ def mask_fn(mask, input_shape): return None return _mask_fn(mask, input_shape) - def apply_fn_with_masking(params, inputs, *, - mask_constant=None, **kwargs): - masked_inputs = tree_map( + def apply_fn_with_masking( + params, + inputs, + *, + mask_constant=None, + **kwargs, + ): + masked_inputs = jax.tree.map( lambda x: _get_masked_array(x, mask_constant), inputs, - is_leaf=lambda x: isinstance(x, (jnp.ndarray, MaskedArray))) + is_leaf=lambda x: isinstance(x, (jnp.ndarray, MaskedArray)), + ) is_leaf = lambda x: isinstance(x, MaskedArray) - inputs = tree_map( + inputs = jax.tree.map( lambda x: x.masked_value, masked_inputs, - is_leaf=is_leaf) - mask = tree_map( + 
is_leaf=is_leaf, + ) + mask = jax.tree.map( lambda x: x.mask, masked_inputs, - is_leaf=is_leaf) + is_leaf=is_leaf, + ) outputs = apply_fn(params, inputs, mask=mask, **kwargs) - outputs_mask = mask_fn(mask, - inputs.shape if isinstance(inputs, jnp.ndarray) - else [i.shape for i in inputs]) + outputs_mask = mask_fn( + mask, + inputs.shape if isinstance(inputs, jnp.ndarray) else + [i.shape for i in inputs], + ) if outputs_mask is None: return outputs - return MaskedArray(outputs, outputs_mask) # pytype:disable=wrong-arg-count + return MaskedArray(outputs, outputs_mask) # pytype: disable=wrong-arg-count def kernel_fn_with_masking(k: NTTree[Kernel], **user_reqs): is_leaf = lambda k: isinstance(k, Kernel) - mask1 = tree_map(lambda k: k.mask1, k, is_leaf=is_leaf) - shape1 = tree_map(lambda k: k.shape1, k, is_leaf=is_leaf) - mask2 = tree_map(lambda k: k.mask2, k, is_leaf=is_leaf) - shape2 = tree_map(lambda k: k.shape2, k, is_leaf=is_leaf) + mask1 = jax.tree.map(lambda k: k.mask1, k, is_leaf=is_leaf) + shape1 = jax.tree.map(lambda k: k.shape1, k, is_leaf=is_leaf) + mask2 = jax.tree.map(lambda k: k.mask2, k, is_leaf=is_leaf) + shape2 = jax.tree.map(lambda k: k.shape2, k, is_leaf=is_leaf) mask1, mask2 = mask_fn(mask1, shape1), mask_fn(mask2, shape2) @@ -240,7 +249,7 @@ def kernel_fn_with_masking(k: NTTree[Kernel], **user_reqs): else: remask_fn = lambda k, m1, m2: k.replace(mask1=m1, mask2=m2) - k = tree_map(remask_fn, k, mask1, mask2, is_leaf=is_leaf) + k = jax.tree.map(remask_fn, k, mask1, mask2, is_leaf=is_leaf) return k if _has_req(kernel_fn): @@ -280,10 +289,10 @@ def unmask_fn(fn: ApplyFn) -> ApplyFn: Function of same signature as `fn`, where the output :class:`MaskedArray` is replaced with the :class:`jax.numpy.ndarray` with masked entries zeroed-out. """ - def unmask(x: Union[MaskedArray, jnp.ndarray]) -> jnp.ndarray: + def unmask(x: MaskedArray | jnp.ndarray) -> jnp.ndarray: if isinstance(x, MaskedArray): x = utils.mask(x.masked_value, x.mask) - return x # pytype: disable=bad-return-type # jax-ndarray + return x def is_leaf(x) -> bool: return isinstance(x, (jnp.ndarray, MaskedArray)) @@ -291,7 +300,7 @@ def is_leaf(x) -> bool: @utils.wraps(fn) def fn_no_mask(*args, **kwargs): out = fn(*args, **kwargs) - out = tree_map(unmask, out, is_leaf=is_leaf) + out = jax.tree.map(unmask, out, is_leaf=is_leaf) return out return fn_no_mask @@ -326,8 +335,8 @@ class MaskedArray: def _get_masked_array( - x: Union[None, jnp.ndarray, ShapedArray, MaskedArray], - mask_constant: Optional[float] = None + x: None | jnp.ndarray | ShapedArray | MaskedArray, + mask_constant: float | None = None, ) -> MaskedArray: """Return `x` with entries equal to `mask_constant` zeroed-out, and the mask. 
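The masking hunks above (and several later ones in this patch) move from `jax.tree_util.tree_map` to the namespaced `jax.tree.map`. A minimal equivalence check on a toy pytree, assuming a JAX version recent enough to provide the `jax.tree` module:

import jax
from jax.tree_util import tree_map  # legacy spelling replaced by this patch

pytree = {'a': 1.0, 'b': (2.0, 3.0)}

# Both calls apply the function leaf-wise and return
# {'a': 2.0, 'b': (4.0, 6.0)}; `is_leaf` is supported by both as well.
assert jax.tree.map(lambda x: 2 * x, pytree) == tree_map(lambda x: 2 * x, pytree)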
@@ -373,7 +382,7 @@ def _get_masked_array( def get_req( f: Callable, - default: Optional[frozendict.frozendict] = None + default: frozendict.frozendict | None = None, ) -> frozendict.frozendict: return getattr(f, _INPUT_REQ, default) @@ -393,8 +402,8 @@ def _has_req(f: Callable) -> bool: 'batch_axis': 0, 'use_dropout': False, 'channel_axis': -1, - 'mask_constant': None - } + 'mask_constant': None, + }, ) @@ -518,7 +527,7 @@ def __lshift__(self, other: 'Diagonal') -> 'Diagonal': def _cov_diag_batch_diag_spatial( x: jnp.ndarray, batch_axis: int, - channel_axis: int + channel_axis: int, ) -> jnp.ndarray: ret = jnp.sum(x ** 2, axis=channel_axis) new_batch_axis = batch_axis - (1 if batch_axis > channel_axis else 0) @@ -529,7 +538,7 @@ def _cov_diag_batch_diag_spatial( def _cov_diag_batch_full_spatial( x: jnp.ndarray, batch_axis: int, - channel_axis: int + channel_axis: int, ) -> jnp.ndarray: ret = lax.dot_general(x, x, (((channel_axis,), (channel_axis,)), @@ -543,7 +552,7 @@ def _cov_full_batch_full_spatial( x1: jnp.ndarray, x2: jnp.ndarray, batch_axis: int, - channel_axis: int + channel_axis: int, ) -> jnp.ndarray: ret = jnp.tensordot(x1, x2, (channel_axis, channel_axis)) new_batch_axis = batch_axis - (1 if batch_axis > channel_axis else 0) @@ -557,7 +566,7 @@ def _cov_full_batch_diag_spatial( x1: jnp.ndarray, x2: jnp.ndarray, batch_axis: int, - channel_axis: int + channel_axis: int, ) -> jnp.ndarray: diag_axes = tuple(i for i in range(x1.ndim) if i != batch_axis and i != channel_axis) @@ -573,7 +582,7 @@ def _cov_diag_batch( x: jnp.ndarray, diagonal_spatial: bool, batch_axis: int, - channel_axis: int + channel_axis: int, ) -> jnp.ndarray: if diagonal_spatial: ret = _cov_diag_batch_diag_spatial(x, batch_axis, channel_axis) @@ -584,11 +593,11 @@ def _cov_diag_batch( def _cov( x1: jnp.ndarray, - x2: Optional[jnp.ndarray], + x2: jnp.ndarray | None, diagonal_spatial: bool, batch_axis: int, - channel_axis: int -) -> Optional[jnp.ndarray]: + channel_axis: int, +) -> jnp.ndarray: """Computes uncentered covariance (nngp) between two batches of inputs. Args: @@ -640,16 +649,16 @@ def _cov( def _inputs_to_kernel( x1: jnp.ndarray, - x2: Optional[jnp.ndarray], + x2: jnp.ndarray | None, *, diagonal_batch: bool, - diagonal_spatial: Union[bool, Diagonal], + diagonal_spatial: bool | Diagonal, compute_ntk: bool, batch_axis: int, - channel_axis: Optional[int], - mask_constant: Optional[float], + channel_axis: int | None, + mask_constant: float | None, eps: float = 1e-12, - **kwargs + **kwargs, ) -> Kernel: """Transforms (batches of) inputs to a `Kernel`. @@ -761,7 +770,7 @@ def _inputs_to_kernel( diagonal_spatial = bool(diagonal_spatial) if batch_axis != 0: - # TODO(romann): add support or clear error for batching. + # TODO: add support or clear error for batching. warnings.warn(f'!!! Non-leading (!= 0) batch dimension in the ' f'input layer is not supported for batching ' f'kernels, got batch_axis = {batch_axis}. !!!') @@ -790,7 +799,7 @@ def get_x_cov_mask(x): x = _get_masked_array(x, mask_constant) x, mask = x.masked_value, x.mask - # TODO(schsam): Think more about dtype automatic vs manual dtype promotion. + # TODO: think more about dtype automatic vs manual dtype promotion. 
x = x.astype(jax.dtypes.canonicalize_dtype(jnp.float64)) if diagonal_batch: @@ -804,11 +813,8 @@ def get_x_cov_mask(x): x2, cov2, mask2 = get_x_cov_mask(x2) nngp = _cov(x1, x2, diagonal_spatial, batch_axis, channel_axis) - ntk = jnp.zeros((), nngp.dtype) if compute_ntk else None # pytype: disable=attribute-error # always-use-return-annotations - is_gaussian = False - is_reversed = False + ntk = jnp.zeros((), nngp.dtype) if compute_ntk else None x1_is_x2 = utils.x1_is_x2(x1, x2, eps=eps) - is_input = False return Kernel( cov1=cov1, @@ -816,9 +822,9 @@ def get_x_cov_mask(x): nngp=nngp, ntk=ntk, x1_is_x2=x1_is_x2, - is_gaussian=is_gaussian, - is_reversed=is_reversed, - is_input=is_input, + is_gaussian=False, + is_reversed=False, + is_input=False, diagonal_batch=diagonal_batch, diagonal_spatial=diagonal_spatial, shape1=x1.shape, @@ -827,18 +833,18 @@ def get_x_cov_mask(x): channel_axis=channel_axis, mask1=mask1, mask2=mask2, - ) # pytype:disable=wrong-keyword-args + ) def _propagate_shape( init_fn: InitFn, apply_fn: ApplyFn, shaped: ShapedArray, - **kwargs + **kwargs, ) -> ShapedArray: """Statically, abstractly, evaluate the init_fn to get shape information.""" def init_and_apply(rng, x): - _, params = init_fn(rng, tree_map(lambda x: x.shape, x)) + _, params = init_fn(rng, jax.tree.map(lambda x: x.shape, x)) return apply_fn(params, x, rng=rng, **kwargs) akey = jax.eval_shape(jax.random.PRNGKey, 0) try: @@ -859,15 +865,21 @@ def _set_shapes( apply_fn: ApplyFn, in_kernel: NTTree[Kernel], out_kernel: NTTree[Kernel], - **kwargs + **kwargs, ) -> NTTree[Kernel]: """Apply a kernel_fn to a Kernel propagating side information.""" is_leaf = lambda k: isinstance(k, Kernel) - shape1 = tree_map(lambda k: ShapedArray(k.shape1, k.nngp.dtype), - in_kernel, is_leaf=is_leaf) - shape2 = tree_map(lambda k: ShapedArray(k.shape2, k.nngp.dtype), - in_kernel, is_leaf=is_leaf) + shape1 = jax.tree.map( + lambda k: ShapedArray(k.shape1, k.nngp.dtype), + in_kernel, + is_leaf=is_leaf, + ) + shape2 = jax.tree.map( + lambda k: ShapedArray(k.shape2, k.nngp.dtype), + in_kernel, + is_leaf=is_leaf, + ) kwargs1, kwargs2 = utils.split_kwargs(kwargs) @@ -876,13 +888,13 @@ def _set_shapes( set_shape_fn = lambda k, s1, s2: k.replace(shape1=s1.shape, shape2=s2.shape) - return tree_map(set_shape_fn, out_kernel, shape1, shape2, is_leaf=is_leaf) + return jax.tree.map(set_shape_fn, out_kernel, shape1, shape2, is_leaf=is_leaf) def _fuse_requirements( kernel_fn_reqs, default_reqs, - **user_reqs + **user_reqs, ) -> frozendict.frozendict: # Override static requirements with explicit user-specified requirements, # but only if they are less demanding, raise an error otherwise. @@ -912,7 +924,7 @@ def _fuse_requirements( def _preprocess_kernel_fn( init_fn: InitFn, apply_fn: ApplyFn, - kernel_fn: LayerKernelFn + kernel_fn: LayerKernelFn, ) -> AnalyticKernelFn: """Returns a `kernel_fn` with additional arguments. @@ -926,7 +938,7 @@ def _preprocess_kernel_fn( Returns: A new `kernel_fn` that does the same computation but accepts additional arguments to flexibly specify the required computation, and can be applied - to either a `Kernel' or a pair of `jnp.ndarrray`s. + to either a `Kernel` or a pair of `jnp.ndarrray`s. """ # Set empty requirements if none specified. 
if not _has_req(kernel_fn): @@ -943,26 +955,27 @@ def kernel_fn_x1(x1, x2, get, **kwargs): compute_ntk = (get is None) or ('ntk' in get) if x2 is None: - x2 = tree_map(lambda x: None, x1) + x2 = jax.tree.map(lambda x: None, x1) def input_fn(x1, x2): return _inputs_to_kernel(x1, x2, compute_ntk=compute_ntk, **reqs) - kernel = tree_map(input_fn, x1, x2) + kernel = jax.tree.map(input_fn, x1, x2) out_kernel = kernel_fn(kernel, **kwargs) return _set_shapes(init_fn, apply_fn, kernel, out_kernel, **kwargs) @utils.get_namedtuple('AnalyticKernel') - def kernel_fn_any(x1_or_kernel: Union[NTTree[jnp.ndarray], NTTree[Kernel]], - x2: Optional[NTTree[jnp.ndarray]] = None, - get: Optional[Get] = None, - *, - pattern: Optional[tuple[Optional[jnp.ndarray], - Optional[jnp.ndarray]]] = None, - mask_constant: Optional[float] = None, - diagonal_batch: Optional[bool] = None, - diagonal_spatial: Optional[bool] = None, - **kwargs): + def kernel_fn_any( + x1_or_kernel: NTTree[jnp.ndarray] | NTTree[Kernel], + x2: NTTree[jnp.ndarray] | None = None, + get: Get = None, + *, + pattern: tuple[jnp.ndarray | None, jnp.ndarray | None] | None = None, + mask_constant: float | None = None, + diagonal_batch: bool | None = None, + diagonal_spatial: bool | None = None, + **kwargs, + ): """Returns the `Kernel` resulting from applying `kernel_fn` to given inputs. Args: @@ -1030,7 +1043,7 @@ def is_leaf(x) -> bool: return isinstance(x, (Kernel, jnp.ndarray, np.ndarray)) return tree_all( - tree_map( + jax.tree.map( lambda x: isinstance(x, cls), x, is_leaf=is_leaf) @@ -1055,10 +1068,10 @@ def is_leaf(x) -> bool: def get_diagonal( - cov: Optional[jnp.ndarray], + cov: jnp.ndarray | None, diagonal_batch: bool, - diagonal_spatial: bool -) -> Optional[jnp.ndarray]: + diagonal_spatial: bool, +) -> jnp.ndarray | None: """Extracts the diagonal of `cov` over all (sample, spatial) dimensions. Adapts computation if `cov` already stores only the diagonal along some @@ -1076,13 +1089,13 @@ def get_diagonal( def get_diagonal_outer_prods( cov1: jnp.ndarray, - cov2: Optional[jnp.ndarray], + cov2: jnp.ndarray | None, diagonal_batch: bool, diagonal_spatial: bool, operation: Callable[[float, float], float], axis: Sequence[int] = (), - mask1: Optional[jnp.ndarray] = None, - mask2: Optional[jnp.ndarray] = None + mask1: jnp.ndarray | None = None, + mask2: jnp.ndarray | None = None, ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Gets outer products of diagonals `cov1, cov1`, `cov1, cov2`, `cov2, cov2`. 
@@ -1118,15 +1131,15 @@ def get_diagonal_outer_prods( def mean_and_var( - x: Optional[jnp.ndarray], - axis: Optional[Axes] = None, - dtype: Optional[jnp.dtype] = None, - out: Optional[None] = None, + x: jnp.ndarray | None, + axis: Axes | None = None, + dtype: jnp.dtype | None = None, + out: None = None, ddof: int = 0, keepdims: bool = False, - mask: Optional[jnp.ndarray] = None, - get_var: bool = False -) -> tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]]: + mask: jnp.ndarray | None = None, + get_var: bool = False, +) -> tuple[jnp.ndarray | None, jnp.ndarray | None]: """`jnp.mean` and `jnp.var` taking the `mask` information into account.""" var = None if x is None: diff --git a/neural_tangents/_src/utils/kernel.py b/neural_tangents/_src/utils/kernel.py index e541b2b2..afbcd1f2 100644 --- a/neural_tangents/_src/utils/kernel.py +++ b/neural_tangents/_src/utils/kernel.py @@ -15,7 +15,7 @@ """Class with infinite-width NTK and NNGP :class:`jax.numpy.ndarray` fields.""" import operator as op -from typing import Any, Callable, Optional, Sequence +from typing import Any, Callable, Sequence from jax import lax import jax.numpy as jnp @@ -32,7 +32,7 @@ class Kernel: nngp: covariance between the first and second batches (NNGP). A `jnp.ndarray` of shape - `(batch_size_1, batch_size_2, height, [height,], width, [width,], ...))`, + `(batch_size_1, batch_size_2, height, [height,], width, [width,], ...)`, where exact shape depends on `diagonal_spatial`. ntk: @@ -122,10 +122,10 @@ class Kernel: """ nngp: jnp.ndarray - ntk: Optional[jnp.ndarray] + ntk: jnp.ndarray | None cov1: jnp.ndarray - cov2: Optional[jnp.ndarray] + cov2: jnp.ndarray | None x1_is_x2: jnp.ndarray is_gaussian: bool = dataclasses.field(pytree_node=False) @@ -135,14 +135,14 @@ class Kernel: diagonal_batch: bool = dataclasses.field(pytree_node=False) diagonal_spatial: bool = dataclasses.field(pytree_node=False) - shape1: Optional[tuple[int, ...]] = dataclasses.field(pytree_node=False) - shape2: Optional[tuple[int, ...]] = dataclasses.field(pytree_node=False) + shape1: tuple[int, ...] | None = dataclasses.field(pytree_node=False) + shape2: tuple[int, ...] | None = dataclasses.field(pytree_node=False) batch_axis: int = dataclasses.field(pytree_node=False) channel_axis: int = dataclasses.field(pytree_node=False) - mask1: Optional[jnp.ndarray] = None - mask2: Optional[jnp.ndarray] = None + mask1: jnp.ndarray | None = None + mask2: jnp.ndarray | None = None replace = ... # type: Callable[..., 'Kernel'] asdict = ... # type: Callable[[], dict[str, Any]] @@ -188,7 +188,7 @@ def reverse(self) -> 'Kernel': ntk=ntk, is_reversed=not self.is_reversed) - def transpose(self, axes: Optional[Sequence[int]] = None) -> 'Kernel': + def transpose(self, axes: Sequence[int] | None = None) -> 'Kernel': """Permute spatial dimensions of the `Kernel` according to `axes`. 
Follows @@ -203,8 +203,8 @@ def transpose(self, axes: Optional[Sequence[int]] = None) -> 'Kernel': if axes is None: axes = tuple(range(len(self.shape1) - 2)) - def permute(mat: Optional[jnp.ndarray], - batch_ndim: int) -> Optional[jnp.ndarray]: + def permute(mat: jnp.ndarray | None, + batch_ndim: int) -> jnp.ndarray | None: if mat is not None: _axes = tuple(batch_ndim + a for a in axes) if not self.diagonal_spatial: @@ -223,8 +223,8 @@ def permute(mat: Optional[jnp.ndarray], def mask( self, - mask1: Optional[jnp.ndarray], - mask2: Optional[jnp.ndarray] + mask1: jnp.ndarray | None, + mask2: jnp.ndarray | None ) -> 'Kernel': """Mask all covariance matrices according to `mask1`, `mask2`.""" mask11, mask12, mask22 = self._get_mask_prods(mask1, mask2) @@ -245,11 +245,9 @@ def mask( def _get_mask_prods( self, - mask1: Optional[jnp.ndarray], - mask2: Optional[jnp.ndarray] - ) -> tuple[Optional[jnp.ndarray], - Optional[jnp.ndarray], - Optional[jnp.ndarray]]: + mask1: jnp.ndarray | None, + mask2: jnp.ndarray | None, + ) -> tuple[jnp.ndarray | None, jnp.ndarray | None, jnp.ndarray | None]: """Gets outer products of `mask1, mask1`, `mask1, mask2`, `mask2, mask2`.""" def get_mask_prod(m1, m2, batch_ndim): if m1 is None and m2 is None: @@ -286,8 +284,8 @@ def reshape(m): def dot_general( self, - other1: Optional[jnp.ndarray], - other2: Optional[jnp.ndarray], + other1: jnp.ndarray | None, + other2: jnp.ndarray | None, is_lhs: bool, dimension_numbers: lax.DotDimensionNumbers ) -> 'Kernel': @@ -380,11 +378,11 @@ def get_out_dims(batch_ndim: int) -> list[int]: return mat_non_c_dims[:n_b] + other_non_c_dims + mat_non_c_dims[n_b:] def dot( - mat: Optional[jnp.ndarray], + mat: jnp.ndarray | None, batch_ndim: int, - other1: Optional[jnp.ndarray] = None, - other2: Optional[jnp.ndarray] = None, - ) -> Optional[jnp.ndarray]: + other1: jnp.ndarray | None = None, + other2: jnp.ndarray | None = None, + ) -> jnp.ndarray | None: if mat is None or mat.ndim == 0 or other1 is None and other2 is None: return mat @@ -403,7 +401,7 @@ def dot( other2_dims = get_other_dims(batch_ndim, False) operands += (other2, other2_dims) - return jnp.einsum(*operands, get_out_dims(batch_ndim), optimize=True) # pytype: disable=wrong-arg-types # jnp-type + return jnp.einsum(*operands, get_out_dims(batch_ndim)) cov1 = dot(self.cov1, 1 if self.diagonal_batch else 2, other1, other1) cov2 = dot(self.cov2, 1 if self.diagonal_batch else 2, other2, other2) diff --git a/neural_tangents/_src/utils/rules.py b/neural_tangents/_src/utils/rules.py index b840a1d7..dff247d4 100644 --- a/neural_tangents/_src/utils/rules.py +++ b/neural_tangents/_src/utils/rules.py @@ -15,7 +15,7 @@ """Structured derivatives rules.""" import functools -from typing import Any, Callable, Optional, Union +from typing import Any, Callable import jax from jax import lax @@ -42,7 +42,7 @@ class Structure: """Describes structure present in a primitive derivative dy/dw. - # TODO(romann): make this a python dataclass. + # TODO: make this a python dataclass. Attributes: out_trace: @@ -96,8 +96,8 @@ class Structure: in_trace_idxs: tuple[int, ...] = field(False, default_factory=tuple) out_diagonal: tuple[int, ...] = field(False, default_factory=tuple) - in_diagonal: tuple[tuple[Optional[int], ...], ...] = field( - False, default_factory=tuple) + in_diagonal: tuple[tuple[int | None, ...], ...] = field( +False, default_factory=tuple) out_broadcast: tuple[int, ...] = field(False, default_factory=tuple) out_broadcast_idxs: tuple[int, ...] 
= field(False, default_factory=tuple) @@ -143,16 +143,16 @@ def __and__(self, other): ) -STRUCTURE_RULES: dict[Optional[Primitive], Callable[..., Structure]] = {} -JACOBIAN_RULES: dict[Optional[Primitive], Callable[..., jnp.ndarray]] = {} -EQN_PARAMS_RULES: dict[Optional[Primitive], Callable[..., dict[str, Any]]] = {} +STRUCTURE_RULES: dict[Primitive | None, Callable[..., Structure]] = {} +JACOBIAN_RULES: dict[Primitive | None, Callable[..., jnp.ndarray]] = {} +EQN_PARAMS_RULES: dict[Primitive | None, Callable[..., dict[str, Any]]] = {} def get_structure( - eqn: Optional[JaxprEqn], - invals: list[Union[ShapedArray, AbstractValue]], + eqn: JaxprEqn | None, + invals: list[ShapedArray | AbstractValue], idx: int, - _s_rules: bool + _s_rules: bool, ) -> Structure: if any(i is AbstractValue for i in invals): raise TypeError(invals) @@ -182,7 +182,7 @@ def get_structure( # No simplification rule found. structure = Structure() - # TODO(romann): can we avoid special-casing `reshape`s? + # TODO: can we avoid special-casing `reshape`s? if primitive == lax.reshape_p: cts_in = ShapedArray(invals[idx].shape, invals[idx].dtype) @@ -211,7 +211,7 @@ def get_structure( def get_structure_cache( jaxpr: Jaxpr, - _s_rules: bool + _s_rules: bool, ) -> dict[Var, Structure]: """Associates a least common structure to each input variable of the `jaxpr`. @@ -261,7 +261,7 @@ def get_structure_cache( def get_id_structure( inval: AbstractValue, - _s_rules: bool + _s_rules: bool, ) -> Structure: if not isinstance(inval, ShapedArray): raise TypeError(inval) @@ -278,7 +278,7 @@ def get_id_structure( def _eye_like(out_shaped: ShapedArray, in_shaped: ShapedArray) -> jnp.ndarray: assert out_shaped.size == in_shaped.size, (out_shaped, in_shaped) eye = jnp.eye(out_shaped.size, dtype=out_shaped.dtype) - eye = eye.reshape(out_shaped.shape + in_shaped.shape) # pytype: disable=unsupported-operands # always-use-return-annotations + eye = eye.reshape(out_shaped.shape + in_shaped.shape) return eye @@ -289,7 +289,7 @@ def _dot_general_s( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> Structure: contracting_dims, batch_dims = eqn.params['dimension_numbers'] self, other = invals[idx], invals[1 if idx == 0 else 0] @@ -318,7 +318,7 @@ def _dot_general_j( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> jnp.ndarray: contracting_dims, batch_dims = eqn.params['dimension_numbers'] @@ -346,7 +346,7 @@ def _dot_general_j( self_nc_dims = tuple(i for i in range(self.ndim) if i not in self_c_dims) - j = jnp.moveaxis( # pytype: disable=wrong-arg-types # jnp-type + j = jnp.moveaxis( other, other_b_dims + tuple(d[1] for d in sorted(zip(self_c_dims, other_c_dims))), @@ -391,7 +391,7 @@ def _conv_general_dilated_s( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> Structure: if idx != 1: raise NotImplementedError(eqn, idx) @@ -437,7 +437,7 @@ def _conv_general_dilated_j( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> jnp.ndarray: if idx != 1: raise NotImplementedError(eqn, idx) @@ -522,7 +522,7 @@ def _conv_general_dilated_e( params: dict[str, Any], idx: int, trimmed_invals: list[ShapedArray], - trimmed_cts_in: ShapedArray + trimmed_cts_in: ShapedArray, ) -> dict[str, Any]: lhs, rhs = trimmed_invals dn = params['dimension_numbers'] @@ -546,7 +546,7 @@ def _add_s( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + 
cts_in: ShapedArray, ) -> Structure: inval = invals[idx] @@ -597,11 +597,11 @@ def _add_j( idx: int, invals: list[ShapedArray], cts_in: ShapedArray, - is_sub: bool + is_sub: bool, ) -> jnp.ndarray: j = jnp.eye(utils.size_at(invals[idx]), dtype=invals[idx].dtype) - j = j.reshape(invals[idx].shape * 2) # pytype: disable=unsupported-operands # always-use-return-annotations - j = jnp.broadcast_to(j, cts_in.shape + invals[idx].shape) # pytype: disable=unsupported-operands # always-use-return-annotations + j = j.reshape(invals[idx].shape * 2) + j = jnp.broadcast_to(j, cts_in.shape + invals[idx].shape) if is_sub and idx == 1: j = -j return j @@ -620,7 +620,7 @@ def _mul_s( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> Structure: inval = invals[idx] ndim = inval.ndim @@ -666,23 +666,23 @@ def _mul_s( def _mul_j( eqn: JaxprEqn, idx: int, - invals: list[Union[ShapedArray, jnp.ndarray]], + invals: list[ShapedArray | jnp.ndarray], cts_in: ShapedArray, - is_div: bool + is_div: bool, ) -> jnp.ndarray: if is_div and idx != 0: raise ValueError(eqn, idx) inval = invals[idx] if inval.size == 0: - return jnp.zeros(cts_in.shape + inval.shape, inval.dtype) # pytype: disable=unsupported-operands # always-use-return-annotations + return jnp.zeros(cts_in.shape + inval.shape, inval.dtype) other = invals[1 if idx == 0 else 0] if is_div: other = jnp.ones((), other.dtype) / other if inval.ndim == 0: - return other # pytype: disable=bad-return-type # jax-ndarray + return other if other.ndim == 0: other = jnp.broadcast_to(other, inval.shape) @@ -691,7 +691,7 @@ def _mul_j( j = jnp.broadcast_to(other, cts_in.shape).reshape((-1,)) j = jnp.diag(j) - j = j.reshape(cts_in.shape * 2) # pytype: disable=unsupported-operands # always-use-return-annotations + j = j.reshape(cts_in.shape * 2) sum_axes = () for i in range(inval.ndim): @@ -715,7 +715,7 @@ def _concatenate_s( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> Structure: dimension = eqn.params['dimension'] @@ -732,7 +732,7 @@ def _concatenate_j( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> jnp.ndarray: dimension = eqn.params['dimension'] @@ -749,11 +749,11 @@ def _concatenate_j( inval_i_size = np.prod(inval_i_shape) j = jnp.zeros((inval_i_size, inval.size), inval.dtype) - j = j.reshape(inval_i_shape + inval.shape) # pytype: disable=unsupported-operands # always-use-return-annotations + j = j.reshape(inval_i_shape + inval.shape) js.append(j) j = lax.concatenate(js, dimension) - j = j.reshape(cts_in.shape + inval.shape) # pytype: disable=unsupported-operands # always-use-return-annotations + j = j.reshape(cts_in.shape + inval.shape) return j STRUCTURE_RULES[lax.concatenate_p] = _concatenate_s @@ -767,7 +767,7 @@ def _rev_s( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> Structure: dimensions = eqn.params['dimensions'] in_trace = out_trace = tuple(i for i in range(invals[idx].ndim) @@ -785,7 +785,7 @@ def _rev_j( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> jnp.ndarray: inval = invals[idx] j = _eye_like(cts_in, inval) @@ -800,7 +800,7 @@ def _broadcast_in_dim_s( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> Structure: broadcast_dimensions = eqn.params['broadcast_dimensions'] @@ -823,14 +823,14 @@ def _broadcast_in_dim_j( eqn: JaxprEqn, idx: int, 
invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> jnp.ndarray: inval = invals[idx] j = jnp.eye(inval.size, dtype=inval.dtype) - j = j.reshape(inval.shape * 2) # pytype: disable=unsupported-operands # always-use-return-annotations + j = j.reshape(inval.shape * 2) j = lax.broadcast_in_dim( j, - cts_in.shape + inval.shape, # pytype: disable=unsupported-operands # always-use-return-annotations + cts_in.shape + inval.shape, broadcast_dimensions=eqn.params['broadcast_dimensions'] + tuple(range(cts_in.ndim, cts_in.ndim + inval.ndim))) return j @@ -839,7 +839,7 @@ def _broadcast_in_dim_e( params: dict[str, Any], idx: int, trimmed_invals: list[ShapedArray], - trimmed_cts_in: ShapedArray + trimmed_cts_in: ShapedArray, ) -> dict[str, Any]: # `broadcast_in_dim` is the only primitive JVP where we need to change # equation parameters in response to tweaking the inputs/cotangents @@ -856,7 +856,7 @@ def _reduce_sum_s( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> Structure: axes = eqn.params['axes'] @@ -875,13 +875,13 @@ def _reduce_sum_j( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> jnp.ndarray: inval = invals[idx] j = jnp.eye(cts_in.size, dtype=inval.dtype) - j = j.reshape(cts_in.shape * 2) # pytype: disable=unsupported-operands # always-use-return-annotations + j = j.reshape(cts_in.shape * 2) j = jnp.expand_dims(j, tuple(a + cts_in.ndim for a in eqn.params['axes'])) - j = jnp.broadcast_to(j, cts_in.shape + inval.shape) # pytype: disable=unsupported-operands # always-use-return-annotations + j = jnp.broadcast_to(j, cts_in.shape + inval.shape) return j STRUCTURE_RULES[lax.reduce_sum_p] = _reduce_sum_s @@ -892,7 +892,7 @@ def _reduce_window_sum_s( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> Structure: out_trace = () for i in range(cts_in.ndim): @@ -917,7 +917,7 @@ def _pad_s( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> Structure: padding_config = eqn.params['padding_config'] @@ -937,13 +937,13 @@ def _pad_j( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> jnp.ndarray: padding_config = eqn.params['padding_config'] inval = invals[idx] j = jnp.eye(inval.size, dtype=inval.dtype) - j = j.reshape(inval.shape * 2) # pytype: disable=unsupported-operands # always-use-return-annotations + j = j.reshape(inval.shape * 2) for _ in range(inval.ndim): padding_config += ((0, 0, 0),) @@ -958,7 +958,7 @@ def _reshape_s( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> Structure: out_trace = tuple(range(invals[idx].ndim)) if eqn.params['dimensions'] is None: @@ -978,23 +978,23 @@ def _reshape_j( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> jnp.ndarray: inval = invals[idx] j = _eye_like(inval, inval) - j = j.reshape(inval.shape * 2) # pytype: disable=unsupported-operands # always-use-return-annotations + j = j.reshape(inval.shape * 2) inval_dims = tuple(i + inval.ndim for i in range(inval.ndim)) if eqn.params['dimensions'] is not None: j = lax.transpose(j, eqn.params['dimensions'] + inval_dims) - j = j.reshape(inval.shape + inval.shape) # pytype: disable=unsupported-operands # always-use-return-annotations + j = j.reshape(inval.shape + inval.shape) return j def _reshape_e( params: dict[str, Any], idx: int, 
trimmed_invals: list[ShapedArray], - trimmed_cts_in: ShapedArray + trimmed_cts_in: ShapedArray, ) -> dict[str, Any]: # Hack for more efficient `reshape` structure rule. params['new_sizes'] = trimmed_invals[idx].shape @@ -1006,10 +1006,10 @@ def _reshape_e( def _eye_s( - eqn: Optional[JaxprEqn], + eqn: JaxprEqn | None, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> Structure: """Use this for elementwise-linear in `p` primitives `y(p, x)`. @@ -1032,10 +1032,10 @@ def _eye_s( ) def _eye_j( - eqn: Optional[JaxprEqn], + eqn: JaxprEqn | None, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> jnp.ndarray: j = _eye_like(cts_in, invals[idx]) return j @@ -1050,7 +1050,7 @@ def _neg_j( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> jnp.ndarray: j = _eye_like(cts_in, invals[idx]) return -j @@ -1063,9 +1063,9 @@ def _zeros_like_j( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> jnp.ndarray: - return jnp.zeros(cts_in.shape + invals[idx].shape, cts_in.dtype) # pytype: disable=unsupported-operands # always-use-return-annotations + return jnp.zeros(cts_in.shape + invals[idx].shape, cts_in.dtype) STRUCTURE_RULES[jax.interpreters.ad.zeros_like_p] = _eye_s JACOBIAN_RULES[jax.interpreters.ad.zeros_like_p] = _zeros_like_j @@ -1075,7 +1075,7 @@ def _transpose_s( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> Structure: in_trace = tuple(range(cts_in.ndim)) out_trace = tuple(eqn.params['permutation'].index(i) for i in in_trace) @@ -1092,15 +1092,15 @@ def _transpose_j( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> jnp.ndarray: j = _eye_like(cts_in, invals[idx]) inval = invals[idx] - j = j.reshape(inval.shape * 2) # pytype: disable=unsupported-operands # always-use-return-annotations + j = j.reshape(inval.shape * 2) inval_dims = tuple(i + cts_in.ndim for i in range(cts_in.ndim)) j = lax.transpose(j, eqn.params['permutation'] + inval_dims) - j = j.reshape(cts_in.shape + invals[idx].shape) # pytype: disable=unsupported-operands # always-use-return-annotations + j = j.reshape(cts_in.shape + invals[idx].shape) return j STRUCTURE_RULES[lax.transpose_p] = _transpose_s @@ -1111,7 +1111,7 @@ def _squeeze_s( eqn: JaxprEqn, idx: int, invals: list[ShapedArray], - cts_in: ShapedArray + cts_in: ShapedArray, ) -> Structure: out_trace = tuple(range(cts_in.ndim)) in_trace = tuple(i for i in range(invals[idx].ndim) diff --git a/neural_tangents/_src/utils/typing.py b/neural_tangents/_src/utils/typing.py index 3d596cce..74d1fb6d 100644 --- a/neural_tangents/_src/utils/typing.py +++ b/neural_tangents/_src/utils/typing.py @@ -14,7 +14,7 @@ """Common Type Definitions.""" -from typing import Any, Generator, Optional, Protocol, Sequence, TYPE_CHECKING, TypeVar, Union +from typing import Any, Generator, Protocol, Sequence, TYPE_CHECKING, TypeVar import jax import jax.numpy as jnp @@ -29,7 +29,7 @@ """ -Axes = Union[int, Sequence[int]] +Axes = int | Sequence[int] """Axes specification, can be integers (`axis=-1`) or sequences (`axis=(1, 3)`). """ @@ -37,11 +37,11 @@ T = TypeVar('T') if TYPE_CHECKING: - NTTree = Union[T, list['NTTree[T]'], tuple['NTTree[T]', ...], T] - NTTrees = Union[list['NTTree[T]'], tuple['NTTree[T]', ...]] + NTTree = T | list['NTTree[T]'] | tuple['NTTree[T]', ...] | T + NTTrees = list['NTTree[T]'] | tuple['NTTree[T]', ...] 
else: # Can't use recursive types with `sphinx-autodoc-typehints`. - NTTree = Union[list[T], tuple[T, ...], T] + NTTree = list[T] | tuple[T, ...] | T """Neural Tangents Tree. Trees of kernels and arrays naturally emerge in certain neural @@ -54,7 +54,7 @@ :class:`~neural_tangents.Kernel` objects. """ - NTTrees = Union[list[T], tuple[T, ...]] + NTTrees = list[T] | tuple[T, ...] """A list or tuple of :class:`NTTree` s. """ @@ -79,7 +79,7 @@ def __call__( self, rng: jax.Array, input_shape: Shapes, - **kwargs + **kwargs, ) -> tuple[Shapes, PyTree]: ... @@ -97,7 +97,7 @@ def __call__( params: PyTree, inputs: NTTree[jnp.ndarray], *args, - **kwargs + **kwargs, ) -> NTTree[jnp.ndarray]: ... @@ -110,16 +110,16 @@ class MaskFn(Protocol): def __call__( self, - mask: Union[jnp.ndarray, Sequence[jnp.ndarray]], + mask: jnp.ndarray | Sequence[jnp.ndarray], input_shape: Shapes, - ) -> Union[jnp.ndarray, Sequence[jnp.ndarray]]: + ) -> jnp.ndarray | Sequence[jnp.ndarray]: ... -KernelOrInput = Union[NTTree[Kernel], NTTree[jnp.ndarray]] +KernelOrInput = NTTree[Kernel] | NTTree[jnp.ndarray] -Get = Union[None, str, tuple[str, ...]] +Get = None | str | tuple[str, ...] class LayerKernelFn(Protocol): @@ -130,10 +130,7 @@ class LayerKernelFn(Protocol): types. """ - def __call__( - self, - k: NTTree[Kernel] - ) -> NTTree[Kernel]: + def __call__(self, k: NTTree[Kernel], **kwargs) -> NTTree[Kernel]: ... @@ -150,10 +147,10 @@ class AnalyticKernelFn(Protocol): def __call__( self, x1: KernelOrInput, - x2: Optional[NTTree[jnp.ndarray]] = None, + x2: NTTree[jnp.ndarray] | None = None, get: Get = None, - **kwargs - ) -> Union[NTTree[Kernel], NTTree[jnp.ndarray]]: + **kwargs, + ) -> NTTree[Kernel] | NTTree[jnp.ndarray]: ... @@ -170,10 +167,10 @@ class EmpiricalGetKernelFn(Protocol): def __call__( self, x1: NTTree[jnp.ndarray], - x2: Optional[NTTree[jnp.ndarray]], + x2: NTTree[jnp.ndarray] | None, get: Get, params: PyTree, - **kwargs + **kwargs, ) -> NTTree[jnp.ndarray]: ... @@ -190,9 +187,9 @@ class EmpiricalKernelFn(Protocol): def __call__( self, x1: NTTree[jnp.ndarray], - x2: Optional[NTTree[jnp.ndarray]], + x2: NTTree[jnp.ndarray] | None, params: PyTree, - **kwargs + **kwargs, ) -> NTTree[jnp.ndarray]: ... @@ -207,19 +204,19 @@ class MonteCarloKernelFn(Protocol): def __call__( self, x1: NTTree[jnp.ndarray], - x2: Optional[NTTree[jnp.ndarray]], + x2: NTTree[jnp.ndarray] | None, get: Get = None, - **kwargs - ) -> Union[NTTree[jnp.ndarray], Generator[NTTree[jnp.ndarray], None, None]]: + **kwargs, + ) -> NTTree[jnp.ndarray] | Generator[NTTree[jnp.ndarray], None, None]: ... -KernelFn = Union[ - AnalyticKernelFn, - EmpiricalKernelFn, - EmpiricalGetKernelFn, - MonteCarloKernelFn, -] +KernelFn = ( + AnalyticKernelFn | + EmpiricalKernelFn | + EmpiricalGetKernelFn | + MonteCarloKernelFn +) InternalLayer = tuple[InitFn, ApplyFn, LayerKernelFn] @@ -229,16 +226,16 @@ def __call__( Layer = tuple[InitFn, ApplyFn, AnalyticKernelFn] -Kernels = Union[list[Kernel], tuple[Kernel, ...]] +Kernels = list[Kernel] | tuple[Kernel, ...] """Kernel inputs/outputs of `FanOut`, `FanInSum`, etc. """ -_VMapAxis = Optional[PyTree] +_VMapAxis = PyTree | None """A `PyTree` of integers. """ VMapAxisTriple = tuple[_VMapAxis, _VMapAxis, dict[str, _VMapAxis]] -VMapAxes = Union[_VMapAxis, VMapAxisTriple] +VMapAxes = _VMapAxis | VMapAxisTriple """Specifies `(input, output, kwargs)` axes for `vmap` in empirical NTK. 
""" diff --git a/neural_tangents/_src/utils/utils.py b/neural_tangents/_src/utils/utils.py index 0284beb6..6dd4614e 100644 --- a/neural_tangents/_src/utils/utils.py +++ b/neural_tangents/_src/utils/utils.py @@ -22,7 +22,7 @@ import inspect import operator import types -from typing import Any, Callable, Iterable, Optional, Sequence, Sized, TypeVar, Union +from typing import Any, Callable, Iterable, Sequence, Sized, TypeVar import warnings import jax @@ -30,14 +30,13 @@ from jax import random import jax.numpy as jnp from jax.tree_util import tree_all -from jax.tree_util import tree_map import numpy as np PyTree = Any -Axes = Union[int, Sequence[int]] +Axes = int | Sequence[int] def is_list_or_tuple(x) -> bool: @@ -46,7 +45,7 @@ def is_list_or_tuple(x) -> bool: return type(x) == list or type(x) == tuple -def is_nt_tree_of(x, dtype: Union[type, tuple[type, ...]]) -> bool: +def is_nt_tree_of(x, dtype: type | tuple[type, ...]) -> bool: if isinstance(x, dtype): return True if not is_list_or_tuple(x): @@ -55,9 +54,9 @@ def is_nt_tree_of(x, dtype: Union[type, tuple[type, ...]]) -> bool: def nt_tree_fn( - nargs: Optional[int] = None, - tree_structure_argnum: Optional[int] = None, - reduce: Callable = lambda x: x + nargs: int | None = None, + tree_structure_argnum: int | None = None, + reduce: Callable = lambda x: x, ): """Convert a function that acts on single inputs to one that acts on trees. @@ -131,9 +130,9 @@ def wrapped_fn(*args, **kwargs): return tree_fn -def all_none(x, attr: Optional[str] = None) -> bool: +def all_none(x, attr: str | None = None) -> bool: get_fn = (lambda x: x) if attr is None else lambda x: getattr(x, attr) - return tree_all(tree_map(lambda x: get_fn(x) is None, x)) + return tree_all(jax.tree.map(lambda x: get_fn(x) is None, x)) def canonicalize_get(get): @@ -141,7 +140,7 @@ def canonicalize_get(get): return True, get if not get: - # NOTE(schsam): It seems slightly nicer to not support the empty-tuple + # NOTE: It seems slightly nicer to not support the empty-tuple # case. Happy to add support later, if there's a use-case. raise ValueError('"get" must be non-empty.') @@ -255,9 +254,9 @@ def canonicalize_output(out): @nt_tree_fn(nargs=2, reduce=lambda x: jnp.all(jnp.array(x))) def x1_is_x2( x1: jnp.ndarray, - x2: Optional[jnp.ndarray] = None, - eps: float = 1e-12 -) -> Union[bool, jnp.ndarray]: + x2: jnp.ndarray | None = None, + eps: float = 1e-12, +) -> bool | jnp.ndarray: if not isinstance(x1, (np.ndarray, jnp.ndarray)): raise TypeError('`x1` must be an ndarray. A {} is found.'.format(type(x1))) @@ -282,7 +281,7 @@ def x1_is_x2( return jnp.all(jnp.abs(diff) < eps) -def _get_ndim(x: Union[int, Sized, jnp.ndarray]) -> int: +def _get_ndim(x: int | Sized | jnp.ndarray) -> int: """Get number of dimensions given number of dimensions / shape / array.""" if hasattr(x, 'ndim'): n = x.ndim @@ -295,7 +294,7 @@ def _get_ndim(x: Union[int, Sized, jnp.ndarray]) -> int: return n -def mod(axis: Axes, x: Union[int, Sized, jnp.ndarray]) -> list[int]: +def mod(axis: Axes, x: int | Sized | jnp.ndarray) -> list[int]: """Makes `axis` non-negative given number of dimensions / shape / array.""" n = _get_ndim(x) if isinstance(axis, int): @@ -305,7 +304,7 @@ def mod(axis: Axes, x: Union[int, Sized, jnp.ndarray]) -> list[int]: def canonicalize_axis( axis: Axes, - x: Union[int, Sized, jnp.ndarray] + x: int | Sized | jnp.ndarray, ) -> list[int]: """Converts axis into a sorted non-negative list. 
@@ -324,7 +323,7 @@ def canonicalize_axis( def zip_axes( x: jnp.ndarray, start_axis: int = 0, - end_axis: Optional[int] = None + end_axis: int | None = None, ) -> jnp.ndarray: """Zip (interleave) axes starting from `start_axis`. @@ -342,9 +341,11 @@ def zip_axes( return _zip_axes(x, start_axis, end_axis, unzip=False) -def unzip_axes(x: jnp.ndarray, - start_axis: int = 0, - end_axis: Optional[int] = None) -> jnp.ndarray: +def unzip_axes( + x: jnp.ndarray, + start_axis: int = 0, + end_axis: int | None = None, +) -> jnp.ndarray: """Unzip (de-interleave) axes starting from `start_axis`. Changes the shape as follows: @@ -364,8 +365,8 @@ def unzip_axes(x: jnp.ndarray, def _zip_axes( x: jnp.ndarray, start_axis: int = 0, - end_axis: Optional[int] = None, - unzip: bool = False + end_axis: int | None = None, + unzip: bool = False, ) -> jnp.ndarray: """Zip/unzip (interleave/de-interleave) axes starting from `start_axis`. @@ -405,7 +406,7 @@ def _zip_axes( def diagonal_between( x: jnp.ndarray, start_axis: int = 0, - end_axis: Optional[int] = None + end_axis: int | None = None, ) -> jnp.ndarray: """Returns the diagonal along all dimensions between start and end axes.""" if end_axis is None: @@ -464,7 +465,7 @@ def outer_prod(x, y, start_axis, end_axis, prod_op): def reverse_zipped( x: _ArrayOrShape, - start_axis: int = 0 + start_axis: int = 0, ) -> _ArrayOrShape: if x is not None: ndim = _get_ndim(x) @@ -481,17 +482,17 @@ def reverse_zipped( def mask( - x: Optional[jnp.ndarray], - mask_mat: Optional[jnp.ndarray] -) -> Optional[jnp.ndarray]: + x: jnp.ndarray | None, + mask_mat: jnp.ndarray | None, +) -> jnp.ndarray | None: if x is None or mask_mat is None: return x return jnp.where(mask_mat, jnp.zeros((), x.dtype), x) def size_at( - x: Union[_ArrayOrShape, core.ShapedArray], - axes: Optional[Iterable[int]] = None + x: _ArrayOrShape | core.ShapedArray, + axes: Iterable[int] | None = None, ) -> int: if hasattr(x, 'shape'): x = x.shape @@ -506,7 +507,7 @@ def axis_after_dot( axis: int, contracting_dims: Sequence[int], batch_dims: Sequence[int], - lhs_ndim: Optional[int] = None + lhs_ndim: int | None = None, ) -> int: if axis in batch_dims: return batch_dims.index(axis) @@ -521,10 +522,10 @@ def axis_after_dot( def make_2d( - x: Optional[jnp.ndarray], + x: jnp.ndarray | None, start_axis: int = 0, - end_axis: Optional[int] = None -) -> Optional[jnp.ndarray]: + end_axis: int | None = None, +) -> jnp.ndarray | None: """Makes `x` 2D from `start_axis` to `end_axis`, preserving other axes. `x` is assumed to follow the (`X, X, Y, Y, Z, Z`) axes layout. @@ -605,10 +606,10 @@ def split_kwargs(kwargs, x1=None, x2=None): return kwargs1, kwargs2 -_SingleSlice = Union[int, slice, type(Ellipsis)] +_SingleSlice = int | slice | type(Ellipsis) -SliceType = Union[_SingleSlice, tuple[_SingleSlice, ...]] +SliceType = _SingleSlice | tuple[_SingleSlice, ...] """A type to specify a slice of an array. 
For instance, when indexing `x[1, :, 2:8:3]` a slice tuple @@ -621,8 +622,8 @@ def split_kwargs(kwargs, x1=None, x2=None): def canonicalize_idx( idx: SliceType, - ndim: int -) -> tuple[Union[int, slice], ...]: + ndim: int, +) -> tuple[int | slice, ...]: if idx is Ellipsis or isinstance(idx, (int, slice)): idx = (idx,) + (slice(None),) * (ndim - 1) diff --git a/neural_tangents/experimental/empirical_tf/empirical.py b/neural_tangents/experimental/empirical_tf/empirical.py index 0bd4f8d7..1e2eee03 100644 --- a/neural_tangents/experimental/empirical_tf/empirical.py +++ b/neural_tangents/experimental/empirical_tf/empirical.py @@ -14,7 +14,7 @@ """Experimental prototype of empirical NTK computation in Tensorflow. -This module is applicable to :class:`tf.Module`, :class:`tf.keras.Model`, or +This module is applicable to :class:`tf.Module`, :class:`keras.Model`, or :obj:`tf.function` functions, subject to some conditions (see docstring of :obj:`empirical_ntk_fn_tf`). @@ -22,7 +22,7 @@ Please read the respective docstring for more details. .. warning:: - This module currently appears to have long compile times (but OK runtime), + This module currently appears to have long compilation times (but OK runtime), is prone to triggering XLA errors, and does not distinguish between trainable and non-trainable parameters of the model. @@ -30,15 +30,16 @@ "`Fast Finite Width Neural Tangent Kernel `_". Example: + >>> import keras >>> import tensorflow as tf - >>> from tensorflow.keras import layers + >>> from keras import layers >>> import neural_tangents as nt >>> # >>> x_train = tf.random.normal((20, 32, 32, 3)) >>> x_test = tf.random.normal((5, 32, 32, 3)) >>> # >>> # A CNN. - >>> f = tf.keras.Sequential() + >>> f = keras.Sequential() >>> f.add(layers.Conv2D(32, (3, 3), activation='relu', >>> input_shape=x_train.shape[1:])) >>> f.add(layers.Conv2D(32, (3, 3), activation='relu')) @@ -49,7 +50,7 @@ >>> f.build((None, *x_train.shape[1:])) >>> _, params = nt.experimental.get_apply_fn_and_params(f) >>> # - >>> # Default setting: reducing over logits (default `trace_axes=(-1,)`; + >>> # Default setting: reducing over logits (default `trace_axes=(-1,)`); >>> # pass `vmap_axes=0` because the network is iid along the batch axis, no >>> # BatchNorm. 
>>> kernel_fn = nt.experimental.empirical_ntk_fn_tf(f, vmap_axes=0) @@ -66,7 +67,7 @@ >>> k_test_train = kernel_fn(x_test, x_train, params) >>> # >>> # An FCN - >>> f = tf.keras.Sequential() + >>> f = keras.Sequential() >>> f.add(layers.Flatten()) >>> f.add(layers.Dense(1024, activation='relu')) >>> f.add(layers.Dense(1024, activation='relu')) @@ -90,10 +91,11 @@ >>> ntk_train_train_diag = ntk_fn(x_train, None, params) """ -from typing import Callable, Optional, Union +from typing import Callable import warnings from jax.experimental import jax2tf +import keras from neural_tangents._src.empirical import _DEFAULT_NTK_FWD from neural_tangents._src.empirical import _DEFAULT_NTK_J_RULES from neural_tangents._src.empirical import _DEFAULT_NTK_S_RULES @@ -108,19 +110,19 @@ def empirical_ntk_fn_tf( - f: Union[tf.Module, tf.types.experimental.PolymorphicFunction], + f: tf.Module | tf.types.experimental.PolymorphicFunction | keras.Model, trace_axes: Axes = (-1,), diagonal_axes: Axes = (), vmap_axes: VMapAxes = None, - implementation: Union[NtkImplementation, int] = DEFAULT_NTK_IMPLEMENTATION, + implementation: NtkImplementation | int = DEFAULT_NTK_IMPLEMENTATION, _j_rules: bool = _DEFAULT_NTK_J_RULES, _s_rules: bool = _DEFAULT_NTK_S_RULES, - _fwd: Optional[bool] = _DEFAULT_NTK_FWD, + _fwd: bool | None = _DEFAULT_NTK_FWD, ) -> Callable[..., PyTree]: r"""Returns a function to draw a single sample the NTK of a given network `f`. This function follows the API of :obj:`neural_tangents.empirical_ntk_fn`, but - is applicable to Tensorflow :class:`tf.Module`, :class:`tf.keras.Model`, or + is applicable to Tensorflow :class:`tf.Module`, :class:`keras.Model`, or :obj:`tf.function`, via a TF->JAX->TF roundtrip using `tf2jax` and `jax2tf`. Docstring below adapted from :obj:`neural_tangents.empirical_ntk_fn`. @@ -132,23 +134,22 @@ def empirical_ntk_fn_tf( compile times (but OK runtime), is prone to triggering XLA errors, and does not distinguish between trainable and non-trainable parameters of the model. - TODO(romann): support division between trainable and non-trainable variables. + TODO: support division between trainable and non-trainable variables. - TODO(romann): investigate slow compile times. + TODO: investigate slow compile times. Args: f: - :class:`tf.Module` or :obj:`tf.function` whose NTK we are computing. Must - satisfy the following: + :class:`tf.Module`, :class:`keras.Model`, or :obj:`tf.function` whose NTK + we are computing. Must satisfy the following: - if a :obj:`tf.function`, must have the signature of `f(params, x)`. - - if a :class:`tf.Module`, must be either a :class:`tf.keras.Model`, or - be callable. + - if a :class:`tf.Module`, must be callable. - input signature (`f.input_shape` for :class:`tf.Module` or - :class:`tf.keras.Model`, or `f.input_signature` for `tf.function`) - must be known. + :class:`keras.Model`, or `f.input_signature` for `tf.function`) must + be known. trace_axes: output axes to trace the output kernel over, i.e. compute only the trace @@ -247,7 +248,7 @@ def empirical_ntk_fn_tf( _s_rules=_s_rules, _fwd=_fwd, ) - if isinstance(f, tf.Module): + if isinstance(f, (tf.Module, keras.Model)): apply_fn, _ = get_apply_fn_and_params(f) elif isinstance(f, tf.types.experimental.PolymorphicFunction): @@ -264,7 +265,7 @@ def empirical_ntk_fn_tf( return ntk_fn -def get_apply_fn_and_params(f: tf.Module): +def get_apply_fn_and_params(f: tf.Module | keras.Model): """Converts a :class:`tf.Module` into a forward-pass `apply_fn` and `params`.
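A condensed, illustrative usage sketch of the Keras 3 path wired up above (`get_apply_fn_and_params` plus `empirical_ntk_fn_tf`), adapted from the module example earlier in this file; layer sizes and input shapes are arbitrary:

import keras
import tensorflow as tf
import neural_tangents as nt

# A small Keras 3 model; `input_shape` must be known, hence the explicit build.
f = keras.Sequential([keras.layers.Dense(32, activation='relu'),
                      keras.layers.Dense(1)])
f.build((None, 8))

# Extract `(apply_fn, params)` and draw one empirical NTK sample.
_, params = nt.experimental.get_apply_fn_and_params(f)
kernel_fn = nt.experimental.empirical_ntk_fn_tf(f, vmap_axes=0)

x1 = tf.random.normal((4, 8))
x2 = tf.random.normal((3, 8))
ntk = kernel_fn(x1, x2, params)  # shape (4, 3) with the default trace_axes=(-1,)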
Use this function to extract `params` to pass to the Tensorflow empirical NTK @@ -276,16 +277,16 @@ def get_apply_fn_and_params(f: tf.Module): Args: f: - a :class:`tf.Module` to convert to a `apply_fn(params, x)` function. Must - have an `input_shape` attribute set (specifying shape of `x`), and be - callable or be a :class:`tf.keras.Model`. + a callable :class:`tf.Module` or a :class:`keras.Model` to convert to an + `apply_fn(params, x)` function. Must have an `input_shape` attribute set + (specifying shape of `x`). Returns: A tuple fo `(apply_fn, params)`, where `params` is a `PyTree[tf.Tensor]`. """ @tf.function def forward_tf(x: PyTree) -> PyTree: - if isinstance(f, tf.keras.Model): + if isinstance(f, keras.Model): return f.call(x, training=False) if not hasattr(f, '__call__'): diff --git a/neural_tangents/predict.py b/neural_tangents/predict.py index 83f68a62..ab1030c5 100644 --- a/neural_tangents/predict.py +++ b/neural_tangents/predict.py @@ -23,5 +23,5 @@ max_learning_rate, ODEState, - Gaussian + Gaussian, ) diff --git a/neural_tangents/stax.py b/neural_tangents/stax.py index ec1e016b..199aad14 100644 --- a/neural_tangents/stax.py +++ b/neural_tangents/stax.py @@ -77,7 +77,7 @@ from ._src.stax.combinators import ( parallel, serial, - repeat + repeat, ) @@ -131,7 +131,7 @@ # Helper object for the `Index` layer. from ._src.stax.linear import ( - Slice + Slice, ) diff --git a/notebooks/empirical_ntk_resnet.ipynb b/notebooks/empirical_ntk_resnet.ipynb index fdf6a036..8a4ac8ed 100644 --- a/notebooks/empirical_ntk_resnet.ipynb +++ b/notebooks/empirical_ntk_resnet.ipynb @@ -85,7 +85,7 @@ "outputs": [], "source": [ "from functools import partial\n", - "from typing import Any, Callable, Sequence, Optional\n", + "from typing import Any, Callable, Sequence\n", "from flax import linen as nn\n", "from jax import jit\n", "from jax import numpy as jnp\n", diff --git a/setup.py b/setup.py index 76d1a13a..08a3fcca 100644 --- a/setup.py +++ b/setup.py @@ -26,17 +26,18 @@ INSTALL_REQUIRES = [ - 'jax>=0.4.16', - 'frozendict>=2.3.8', - 'tensorflow>=2.15.0', - 'tf2jax>=0.3.5', + 'jax>=0.4.34', + 'frozendict>=2.4.6', + 'tensorflow>=2.18.0', + 'keras>=3.6.0', + 'tf2jax>=0.3.6', ] TESTS_REQUIRES = [ - 'more-itertools', - 'tensorflow-datasets', - 'flax>=0.7.2', + 'more-itertools>=10.5.0', + 'tensorflow-datasets>=4.9.6', + 'flax>=0.10.0', ] @@ -116,11 +117,11 @@ def _get_version() -> str: long_description=long_description, long_description_content_type='text/markdown', description='Fast and Easy Infinite Neural Networks in Python', - python_requires='>=3.9', + python_requires='>=3.10', classifiers=[ - 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'License :: OSI Approved :: Apache Software License', 'Operating System :: MacOS', 'Operating System :: POSIX :: Linux', diff --git a/tests/batching_test.py b/tests/batching_test.py index 52ed82ea..bf3fdf39 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -17,11 +17,11 @@ from functools import partial from absl.testing import absltest +import jax from jax import jit from jax import config from jax import random import jax.numpy as jnp -from jax.tree_util import tree_map import neural_tangents as nt from neural_tangents import stax from neural_tangents._src import batching @@ -37,7 +37,7 @@ POOLING = 'POOLING' INTERMEDIATE_CONV = 'INTERMEDIATE_CONV' -# TODO(schsam): Add a pooling test when multiple inputs are supported in +# 
TODO: Add a pooling test when multiple inputs are supported in # Conv + Pooling. TRAIN_SIZES = [2, 4, 8] TEST_SIZES = [2, 16] @@ -89,8 +89,13 @@ def _empirical_kernel(key, input_shape, network, out_logits, use_dropout): return partial(kernel_fn, params=params, keys=split) -def _theoretical_kernel(unused_key, input_shape, network, just_theta, - use_dropout): +def _theoretical_kernel( + unused_key, + input_shape, + network, + just_theta: bool, + use_dropout: bool, +): _, _, _kernel_fn = _build_network(input_shape, network, 1, use_dropout) @jit @@ -118,7 +123,7 @@ def _test_kernel_against_batched( batched_kernel_fn, train, test, - is_parallel_only=False + is_parallel_only=False, ): g = kernel_fn(train, None) g_b = batched_kernel_fn(train, None) @@ -146,7 +151,7 @@ def _get_data_and_kernel_fn( network, test_size, train_size, - **kwargs + **kwargs, ): test_utils.stub_out_pmap(batching, 2) key = random.PRNGKey(0) @@ -163,7 +168,7 @@ def _get_data_and_kernel_fn( input_shape=INPUT_SHAPES, network=NETWORK, kernel_type=list(KERNELS.keys()), - batch_size=[2, 8] + batch_size=[2, 8], ) def testSerial( self, @@ -172,14 +177,14 @@ def testSerial( input_shape, network, kernel_type, - batch_size + batch_size, ): data_other, data_self, kernel_fn = self._get_data_and_kernel_fn( input_shape, kernel_type, network, test_size, - train_size + train_size, ) kernel_batched = batching._serial(kernel_fn, batch_size=batch_size) @@ -222,7 +227,7 @@ def testParallel( input_shape=INPUT_SHAPES, network=NETWORK, kernel_type=list(KERNELS.keys()), - batch_size=[2, 8] + batch_size=[2, 8], ) def testComposition( self, @@ -231,7 +236,7 @@ def testComposition( input_shape, network, kernel_type, - batch_size + batch_size, ): data_other, data_self, kernel_fn = self._get_data_and_kernel_fn(input_shape, kernel_type, @@ -255,7 +260,7 @@ def testComposition( input_shape=INPUT_SHAPES, network=NETWORK, kernel_type=list(KERNELS.keys()), - batch_size=[2, 8] + batch_size=[2, 8], ) def testAutomatic( self, @@ -264,14 +269,14 @@ def testAutomatic( input_shape, network, kernel_type, - batch_size + batch_size, ): data_other, data_self, kernel_fn = self._get_data_and_kernel_fn( input_shape, kernel_type, network, test_size, - train_size + train_size, ) kernel_batched = batching.batch(kernel_fn, batch_size=batch_size) @@ -361,14 +366,16 @@ def testAnalyticKernelComposeAutomatic(self, store_on_device, batch_size): def test_jit_or_pmap_broadcast(self): - def kernel_fn(x1, - x2, - do_flip, - keys, - do_square, - params, - _unused=None, - p=0.65): + def kernel_fn( + x1, + x2, + do_flip, + keys, + do_square, + params, + _unused=None, + p=0.65, + ): res = jnp.abs(jnp.matmul(x1, x2)) if do_square: res *= res @@ -407,7 +414,9 @@ def kernel_fn(x1, x1, x2, do_flip, keys, do_square, params, _unused=None) self.assertAllClose(res_1[0], res_2[0]) self.assertAllClose( - tree_map(partial(jnp.expand_dims, axis=0), res_1[1]), res_2[1]) + jax.tree.map(partial(jnp.expand_dims, axis=0), res_1[1]), + res_2[1], + ) kernel_fn_pmapped = batching._jit_or_pmap_broadcast(kernel_fn, device_count=2) @@ -425,7 +434,7 @@ def broadcast(arg): x1, x2, do_flip, keys, do_square, params, _unused=None, p=0.2) self.assertAllClose(res_1[0][0], res_2[0][0]) self.assertAllClose(res_1[0][1], res_2[0][1]) - self.assertAllClose(tree_map(broadcast, res_1[1]), res_2[1]) + self.assertAllClose(jax.tree.map(broadcast, res_1[1]), res_2[1]) @test_utils.product( same_inputs=[True, False] diff --git a/tests/empirical_test.py b/tests/empirical_test.py index 866282f1..10420b13 100644 --- 
a/tests/empirical_test.py +++ b/tests/empirical_test.py @@ -17,7 +17,7 @@ from functools import partial import logging import operator -from typing import Any, Callable, Optional, Sequence +from typing import Any, Callable, Sequence from absl.testing import absltest from flax import linen as nn @@ -28,7 +28,6 @@ from jax import lax from jax import random from jax import remat -from jax import tree_map import jax.numpy as jnp from jax.tree_util import tree_reduce import neural_tangents as nt @@ -155,7 +154,7 @@ def _f_lin_exact(cls, x0, x, params, do_alter, do_shift_x=True): b *= 2. w1 += 5. w2 /= 0.9 - return tree_map( + return jax.tree.map( operator.add, f0, ({'list': [ @@ -214,7 +213,7 @@ def f_2_exact(x0, x, params, do_alter, do_shift_x=True): w1 += 5. w2 /= 0.9 dx = x - x0 - return tree_map( + return jax.tree.map( operator.add, f_lin, ({'list': [ @@ -252,7 +251,7 @@ def _compare_kernels(self, x1, x2, ntk_fns, ntk_fns_vmapped, nngp_fn): ntk_ref = ntks[nt.NtkImplementation.JACOBIAN_CONTRACTION] - tree_map(lambda x, y: self.assertEqual(x.shape, y.shape), nngp, ntk_ref) + jax.tree.map(lambda x, y: self.assertEqual(x.shape, y.shape), nngp, ntk_ref) for i, ntk in ntks.items(): self.assertAllClose(ntk_ref, ntk, err_msg=f'{i} impl. fails.') @@ -277,7 +276,7 @@ def testNTKAgainstDirect(self, train_test_network, kernel_type): train_shape[1:], network, diagonal_axes=(), - trace_axes=() + trace_axes=(), ) _, ntk_fns_vmapped = kernel_fn( @@ -286,7 +285,7 @@ def testNTKAgainstDirect(self, train_test_network, kernel_type): network, diagonal_axes=(), trace_axes=(), - vmap_axes=0 + vmap_axes=0, ) self._compare_kernels(x1, None, ntk_fns, ntk_fns_vmapped, nngp_fn) @@ -304,7 +303,7 @@ def testNTKAgainstDirect(self, train_test_network, kernel_type): (0, -1), (1, -2), (2, 3), - (3, 0, 2) + (3, 0, 2), ], trace_axes=[ (), @@ -320,8 +319,8 @@ def testNTKAgainstDirect(self, train_test_network, kernel_type): (-3, -2), (-3, -1), (-2, -4), - (2, 0, -1) - ] + (2, 0, -1), + ], ) def testAxes(self, diagonal_axes, trace_axes): key = random.PRNGKey(0) @@ -422,7 +421,7 @@ def layer(N_out): init_fn, apply_fn, _ = stax.serial(layer(1024), layer(1)) - _, params = init_fn(net_key, tree_map(jnp.shape, x1)) + _, params = init_fn(net_key, jax.tree.map(jnp.shape, x1)) ntk_fns = { i: jit(nt.empirical_ntk_fn(apply_fn, implementation=i)) @@ -520,7 +519,7 @@ def get_x(n, k): ) ) - _, params = init_fn(random.PRNGKey(3), tree_map(jnp.shape, x1)) + _, params = init_fn(random.PRNGKey(3), jax.tree.map(jnp.shape, x1)) in_axes = [(0, 1), 2] out_axes = [-2, -3] @@ -791,13 +790,11 @@ def apply_fn(params, x, **kwargs): 'transpose_5': lambda p, x: jnp.transpose(jnp.expand_dims(jnp.stack(p, 2), 2), (0, 1, 2, 3)), 'transpose_6': lambda p, x: jnp.transpose(jnp.expand_dims(jnp.stack(p, 2), 0), (1, 0, 3, 2)), - # pytype: disable=module-attr 'lax._reduce_window_sum_1': lambda p, x: lax._reduce_window_sum(p[0], (1, 2), (1, 1), [(0, 0), (0, 1)]), 'lax._reduce_window_sum_2': lambda p, x: lax._reduce_window_sum(p[0], (1, 1), (1, 1), [(0, 0), (0, 0)]), 'lax._reduce_window_sum_3': lambda p, x: lax._reduce_window_sum(p[0], (2, 1), (1, 2), [(0, 0), (0, 2)]), 'lax._reduce_window_sum_4': lambda p, x: lax._reduce_window_sum(p[0], (2, 2), (1, 1), [(2, 3), (0, 0)]), 'lax._reduce_window_sum_5': lambda p, x: lax._reduce_window_sum(p[0], (1, 1), (2, 1), [(0, 0), (1, 0)]), - # pytype: enable=module-attr 'dg1-l': lambda p, x: lax.dot_general(p[0], x, (((), ()), ((), ()))), 'dg2-l': lambda p, x: lax.dot_general(p[0], x, (((1,), (0,)), ((), ()))), @@ -829,7 +826,7 @@ 
def apply_fn(params, x, **kwargs): 'p[1] * p[0][1, 0]': lambda p, x: p[1] * p[0][1, 0], 'p[1] / p[0][0, -1]': lambda p, x: p[1] / p[0][1, -1], - # TODO(romann): investigate full support for compiled loops. + # TODO: investigate full support for compiled loops. 'lax.map_1': lambda p, x: lax.map(lambda s: 2 * s, p[0]) * jnp.sum(p[1]), 'lax.map_2': lambda p, x: lax.map(lambda s: 2 * s + 1, p[0]) * jnp.sum(p[0]), 'lax.map_3': lambda p, x: jnp.sum(lax.map(lambda s: -s / 2., p[0])) * p[0], @@ -837,7 +834,7 @@ def apply_fn(params, x, **kwargs): 'lax.map_5': lambda p, x: (lax.map(lambda s: lax.map(lambda p: 2 * p, s) + 1., p[0]), p[1]), 'lax.map_6': lambda p, x: [lax.map(lambda s: lax.map(lambda p: 2 * p, s) + 1., p[0]), p[0]], - # TODO(romann): revisit if JAX figures out AD for out-of-bounds indexing. + # TODO: revisit if JAX figures out AD for out-of-bounds indexing. # 'p[0][1, 0] * p[2].T': lambda p, x: p[0][1, 0] * p[2].T, } @@ -908,7 +905,7 @@ def _compare_ntks( k_2, rtol=rtol, atol=atol, - check_dtypes=False, # TODO(romann): revisit. + check_dtypes=False, # TODO: revisit. check_finite=False, err_msg=msg) @@ -1003,7 +1000,7 @@ class StructuredDerivativesTest(test_utils.NeuralTangentsTestCase): # False ], do_remat=[ - # TODO(romann): support remat + # TODO: support remat # True, False ], @@ -1024,11 +1021,11 @@ def test_function( _fwd ): if f_name == 'lax_reshape_all': - # TODO(romann): investigate slow CPU execution. + # TODO: investigate slow CPU execution. test_utils.skip_test('Skipping large non-structured reshapes on CPU.') if 'lax.map' in f_name and shapes[0][0] and shapes[0][0][0] == 0: - # TODO(romann): fix. + # TODO: fix. raise absltest.SkipTest('Zero-length scans not supported without JIT.') p = [random.normal(random.PRNGKey(i), s, dtype) for i, s in @@ -1311,7 +1308,7 @@ class _MlpMixer(nn.Module): hidden_dim: int tokens_mlp_dim: int channels_mlp_dim: int - model_name: Optional[str] = None + model_name: str | None = None @nn.compact def __call__(self, inputs, *, train): @@ -1369,7 +1366,7 @@ def _get_mixer_b16_config() -> dict[str, Any]: ], dtype=[ jax.dtypes.canonicalize_dtype(jnp.float64), - ] + ], ) class FlaxOtherTest(test_utils.NeuralTangentsTestCase): @@ -1584,7 +1581,7 @@ def test_flax_cnn(self, same_inputs, do_jit, do_remat, dtype, j_rules, vmap_axes=[ 0, None - ] + ], ) class ConvTest(test_utils.NeuralTangentsTestCase): diff --git a/tests/experimental/empirical_ntk_tf_test.py b/tests/experimental/empirical_ntk_tf_test.py index a221f4fc..a6cd4f25 100644 --- a/tests/experimental/empirical_ntk_tf_test.py +++ b/tests/experimental/empirical_ntk_tf_test.py @@ -15,6 +15,7 @@ """Tests for `examples/experimental/empirical_ntk_tf.py`.""" from absl.testing import absltest +import jax from examples.experimental import empirical_ntk_tf @@ -22,7 +23,8 @@ class EmpiricalNtkTfTest(absltest.TestCase): def test_empirical_ntk_tf_test(self): - empirical_ntk_tf.main(None) + with jax.numpy_rank_promotion('warn'): + empirical_ntk_tf.main(None) if __name__ == '__main__': diff --git a/tests/experimental/empirical_tf_test.py b/tests/experimental/empirical_tf_test.py index dbcfef55..abf005c4 100644 --- a/tests/experimental/empirical_tf_test.py +++ b/tests/experimental/empirical_tf_test.py @@ -14,10 +14,12 @@ """Tests for `experimental/empirical_tf/empirical.py`.""" +import platform from absl.testing import absltest from absl.testing import parameterized import jax from jax import numpy as jnp +import keras import neural_tangents as nt from neural_tangents import experimental import numpy as np @@ 
-104,13 +106,13 @@ def _f4_jax(params, x): # https://github.com/jimmyyhwu/resnet18-tf2/blob/master/resnet.py -_kaiming_normal = tf.keras.initializers.VarianceScaling( +_kaiming_normal = keras.initializers.VarianceScaling( scale=2.0, mode='fan_out', distribution='untruncated_normal') def _conv3x3(x, out_planes, stride=1, name=None): - x = tf.keras.layers.ZeroPadding2D(padding=1, name=f'{name}_pad')(x) - return tf.keras.layers.Conv2D( + x = keras.layers.ZeroPadding2D(padding=1, name=f'{name}_pad')(x) + return keras.layers.Conv2D( filters=out_planes, kernel_size=3, strides=stride, use_bias=False, kernel_initializer=_kaiming_normal, name=name)(x) @@ -119,20 +121,20 @@ def _basic_block(x, planes, stride=1, downsample=None, name=None): identity = x out = _conv3x3(x, planes, stride=stride, name=f'{name}.conv1') - out = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, - name=f'{name}.bn1')(out) - out = tf.keras.layers.ReLU(name=f'{name}.relu1')(out) + out = keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, + name=f'{name}.bn1')(out) + out = keras.layers.ReLU(name=f'{name}.relu1')(out) out = _conv3x3(out, planes, name=f'{name}.conv2') - out = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, - name=f'{name}.bn2')(out) + out = keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, + name=f'{name}.bn2')(out) if downsample is not None: for layer in downsample: identity = layer(identity) - out = tf.keras.layers.Add(name=f'{name}.add')([identity, out]) - out = tf.keras.layers.ReLU(name=f'{name}.relu2')(out) + out = keras.layers.Add(name=f'{name}.add')([identity, out]) + out = keras.layers.ReLU(name=f'{name}.relu2')(out) return out @@ -142,12 +144,12 @@ def _make_layer(x, planes, blocks, stride=1, name=None): inplanes = x.shape[3] if stride != 1 or inplanes != planes: downsample = [ - tf.keras.layers.Conv2D( + keras.layers.Conv2D( filters=planes, kernel_size=1, strides=stride, use_bias=False, kernel_initializer=_kaiming_normal, name=f'{name}.0.downsample.0'), - tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, - name=f'{name}.0.downsample.1'), + keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, + name=f'{name}.0.downsample.1'), ] x = _basic_block(x, planes, stride, downsample, name=f'{name}.0') @@ -158,31 +160,31 @@ def _make_layer(x, planes, blocks, stride=1, name=None): def _resnet(x, blocks_per_layer, classes, filters): - x = tf.keras.layers.ZeroPadding2D(padding=3, name='conv1_pad')(x) - x = tf.keras.layers.Conv2D( + x = keras.layers.ZeroPadding2D(padding=3, name='conv1_pad')(x) + x = keras.layers.Conv2D( filters=filters, kernel_size=7, strides=2, use_bias=False, kernel_initializer=_kaiming_normal, name='conv1')(x) - x = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, - name='bn1')(x) - x = tf.keras.layers.ReLU(name='relu1')(x) - x = tf.keras.layers.ZeroPadding2D(padding=1, name='maxpool_pad')(x) - x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, name='maxpool')(x) + x = keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, + name='bn1')(x) + x = keras.layers.ReLU(name='relu1')(x) + x = keras.layers.ZeroPadding2D(padding=1, name='maxpool_pad')(x) + x = keras.layers.MaxPool2D(pool_size=3, strides=2, name='maxpool')(x) x = _make_layer(x, filters, blocks_per_layer[0], name='layer1') - x = tf.keras.layers.GlobalAveragePooling2D(name='avgpool')(x) - initializer = tf.keras.initializers.RandomUniform(-1.0 / (2 * filters)**0.5, - 1.0 / (2 * filters)**0.5) - x = tf.keras.layers.Dense(units=classes, 
kernel_initializer=initializer, - bias_initializer=initializer, name='fc')(x) + x = keras.layers.GlobalAveragePooling2D(name='avgpool')(x) + initializer = keras.initializers.RandomUniform(-1.0 / (2 * filters)**0.5, + 1.0 / (2 * filters)**0.5) + x = keras.layers.Dense(units=classes, kernel_initializer=initializer, + bias_initializer=initializer, name='fc')(x) return x def _MiniResNet(classes, input_shape, weights): - inputs = tf.keras.Input(shape=input_shape) + inputs = keras.Input(shape=input_shape) outputs = _resnet(inputs, [1, 1, 1, 1], classes=classes, filters=2) - return tf.keras.Model(inputs=inputs, outputs=outputs) + return keras.Model(inputs=inputs, outputs=outputs) class EmpiricalTfTest(parameterized.TestCase): @@ -220,7 +222,7 @@ def _compare_ntks( for v in vmap_axes if v not in trace_axes + diagonal_axes ] - x_shape = (f.input_shape[1:] if isinstance(f, tf.Module) else + x_shape = (f.input_shape[1:] if isinstance(f, (tf.Module, keras.Model)) else f.input_signature[1].shape[1:]) x1 = tf.random.normal((2,) + x_shape, seed=2) / np.prod(x_shape) ** 0.5 @@ -228,7 +230,7 @@ def _compare_ntks( x1_jax = jnp.array(x1) x2_jax = jnp.array(x2) - params_jax = jax.tree_map(jnp.array, params) + params_jax = jax.tree.map(jnp.array, params) jax_ntks = [ntk_fn_i(x1_jax, x2_jax, params_jax) for ntk_fn_i in jax_ntk_fns] @@ -240,7 +242,7 @@ def _compare_ntks( atol = 0. rtol = 5e-3 atol_jax = 0.4 - rtol_jax = 0.15 # TODO(romann): revisit poor TPU agreement. + rtol_jax = 0.15 # TODO: revisit poor TPU agreement. else: atol = 1e-5 rtol = 1e-4 @@ -258,8 +260,8 @@ def _compare_ntks( @parameterized.product( f=[ _MiniResNet, - # # TODO(romann): MobileNet works, but takes too long to compile. - # tf.keras.applications.MobileNet, + # # TODO: MobileNet works, but takes too long to compile. 
+ # keras.applications.MobileNet, ], input_shape=[ (32, 32, 3) @@ -284,10 +286,13 @@ def test_keras_functional( diagonal_axes, vmap_axes, ): - f = f(classes=1, input_shape=input_shape, weights=None) - f.build((None, *input_shape)) - f_jax, params = experimental.get_apply_fn_and_params(f) - self._compare_ntks(f, f_jax, params, trace_axes, diagonal_axes, vmap_axes) + if platform.system() == 'Darwin': + self.skipTest('TF <-> JAX fails on MacOS.') + with jax.numpy_rank_promotion("warn"): + f = f(classes=1, input_shape=input_shape, weights=None) + f.build((None, *input_shape)) + f_jax, params = experimental.get_apply_fn_and_params(f) + self._compare_ntks(f, f_jax, params, trace_axes, diagonal_axes, vmap_axes) @parameterized.product( input_shape=[ @@ -312,15 +317,18 @@ def test_keras_sequential( diagonal_axes, vmap_axes, ): - f = tf.keras.Sequential() - f.add(tf.keras.layers.Conv2D(4, (3, 3), activation='relu')) - f.add(tf.keras.layers.Conv2D(2, (2, 2), activation='relu')) - f.add(tf.keras.layers.Flatten()) - f.add(tf.keras.layers.Dense(2)) - - f.build((None, *input_shape)) - f_jax, params = experimental.get_apply_fn_and_params(f) - self._compare_ntks(f, f_jax, params, trace_axes, diagonal_axes, vmap_axes) + if platform.system() == 'Darwin': + self.skipTest('TF <-> JAX fails on MacOS.') + with jax.numpy_rank_promotion("warn"): + f = keras.Sequential() + f.add(keras.layers.Conv2D(4, (3, 3), activation='relu')) + f.add(keras.layers.Conv2D(2, (2, 2), activation='relu')) + f.add(keras.layers.Flatten()) + f.add(keras.layers.Dense(2)) + + f.build((None, *input_shape)) + f_jax, params = experimental.get_apply_fn_and_params(f) + self._compare_ntks(f, f_jax, params, trace_axes, diagonal_axes, vmap_axes) @parameterized.product( f_f_jax=[ @@ -352,6 +360,8 @@ def test_tf_function( diagonal_axes, vmap_axes, ): + if platform.system() == 'Darwin': + self.skipTest('TF <-> JAX fails on MacOS.') f, f_jax = f_f_jax f = tf.function(f, input_signature=_input_signature) params = tf.random.normal(params_shape, seed=4) @@ -376,6 +386,8 @@ def test_tf_module( diagonal_axes, vmap_axes, ): + if platform.system() == 'Darwin': + self.skipTest('TF <-> JAX fails on MacOS.') f = _MLP(input_size=5, sizes=[4, 6, 3], name='MLP') f_jax, params = experimental.get_apply_fn_and_params(f) self._compare_ntks(f, f_jax, params, trace_axes, diagonal_axes, vmap_axes) diff --git a/tests/predict_test.py b/tests/predict_test.py index 84d229e7..80a75a1f 100644 --- a/tests/predict_test.py +++ b/tests/predict_test.py @@ -50,7 +50,7 @@ FLAT = 'FLAT' POOLING = 'POOLING' -# TODO(schsam): Add a pooling test when multiple inputs are supported in +# TODO: Add a pooling test when multiple inputs are supported in # Conv + Pooling. 
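As an aside on the `jax.numpy_rank_promotion('warn')` context manager that now wraps the TF <-> JAX comparisons in the tests above, a standalone sketch of what it toggles (shapes here are illustrative):

import jax
import jax.numpy as jnp

x = jnp.ones((3,))
y = jnp.ones((2, 3))

# Under the default 'allow' policy, `x` is silently promoted to rank 2 and
# broadcast. With 'warn' the operation still succeeds but emits a warning;
# with 'raise' implicit rank promotion becomes an error.
with jax.numpy_rank_promotion('warn'):
  z = x + y  # shape (2, 3), plus a rank-promotion warning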
TRAIN_SIZES = [4, 8] TEST_SIZES = [6, 2] @@ -114,12 +114,12 @@ def _test_zero_time(self, predictor, fx_train_0, fx_test_0, g_td, momentum): if momentum is not None: # Test state-based prediction - state_0 = predict.ODEState(fx_train_0, fx_test_0) # pytype:disable=wrong-arg-count + state_0 = predict.ODEState(fx_train_0, fx_test_0) # pytype: disable=wrong-arg-count state_t0 = predictor(0.0, state_0, None, g_td) self.assertAllClose(state_0.fx_train, state_t0.fx_train) self.assertAllClose(state_0.fx_test, state_t0.fx_test) - state_train_only_0 = predict.ODEState(fx_train_0) # pytype:disable=wrong-arg-count + state_train_only_0 = predict.ODEState(fx_train_0) # pytype: disable=wrong-arg-count state_train_only_t0 = predictor(0.0, state_0, None, g_td) self.assertAllClose(state_train_only_0.fx_train, state_train_only_t0.fx_train) @@ -153,7 +153,7 @@ def _test_multi_step(self, predictor, fx_train_0, fx_test_0, g_td, momentum): self.assertAllClose(fx_test_concat, fx_test_single) if momentum is not None: - state_0 = predict.ODEState(fx_train_0, fx_test_0) # pytype:disable=wrong-arg-count + state_0 = predict.ODEState(fx_train_0, fx_test_0) # pytype: disable=wrong-arg-count t_1 = (0, 0, 2) state_1 = predictor(ts[t_1], state_0, None, g_td) self.assertAllClose(fx_train_single[t_1], state_1.fx_train) @@ -263,8 +263,7 @@ def testNTKGDPrediction( @classmethod def _cov_empirical(cls, x): - return jnp.einsum('itjk,itlk->tjl', x, x, optimize=True) / (x.shape[0] * # pytype: disable=wrong-arg-types # jnp-type - x.shape[-1]) + return jnp.einsum('itjk,itlk->tjl', x, x) / (x.shape[0] * x.shape[-1]) @test_utils.product( train_size=TRAIN_SIZES[:1], @@ -301,7 +300,7 @@ def testNTKMeanCovPrediction( self.assertGreater(jnp.min(jnp.linalg.eigh(cov_train_inf)[0]), -1e-8) _kernel_fn = nt.empirical_kernel_fn(f) - # TODO(romann): figure out the slow compile on Ubuntu 22.04 CPU Python 3.9 + # TODO: figure out the slow compile on Ubuntu 22.04 CPU Python 3.9 kernel_fn = lambda x1, x2, params: _kernel_fn(x1, x2, 'ntk', params) def predict_empirical(key): @@ -676,9 +675,11 @@ def train_network(key): mean_emp = jnp.mean(ensemble_fx, axis=0, keepdims=True) mean_subtracted = ensemble_fx - mean_emp - cov_emp = jnp.einsum( # pytype: disable=wrong-arg-types # jnp-type - 'ijk,ilk->jl', mean_subtracted, mean_subtracted, optimize=True) / ( - mean_subtracted.shape[0] * mean_subtracted.shape[-1]) + cov_emp = jnp.einsum( + 'ijk,ilk->jl', + mean_subtracted, + mean_subtracted, + ) / (mean_subtracted.shape[0] * mean_subtracted.shape[-1]) ntk = predict_fn_mse_ens(training_steps, x, 'ntk', compute_cov=True) self.assertAllClose(ravel_pytree(mean_emp)[0], @@ -856,7 +857,7 @@ def testPredictOnCPU(self): def is_on_cpu(x): return jax.tree_util.tree_all( - jax.tree_map( + jax.tree.map( lambda x: 'cpu' in str(x.addressable_shards[0].device).lower(), x, @@ -922,7 +923,7 @@ def testPredictND(self): p_train_mse, p_test_mse = predict_fn_mse( ts, fx_train_0, fx_test_0, ntk_test_train) self.assertAllClose(y_test_shape, p_test_mse.shape) - self.assertAllClose(y_train_shape, p_train_mse.shape) # pytype: disable=attribute-error # jax-ndarray + self.assertAllClose(y_train_shape, p_train_mse.shape) p_nngp_mse_ens, p_ntk_mse_ens = predict_fn_mse_ensemble( ts, x, ('nngp', 'ntk'), compute_cov=True) diff --git a/tests/rules_test.py b/tests/rules_test.py index dddc14d0..2193ef0c 100644 --- a/tests/rules_test.py +++ b/tests/rules_test.py @@ -17,7 +17,7 @@ import itertools import logging import random -from typing import Optional, Sequence +from typing import Sequence 
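For intuition on the ensemble-covariance einsum simplified in `predict_test.py` above, a small self-contained check against an explicit sum of outer products; the array sizes below are arbitrary:

import jax.numpy as jnp
from jax import random

# Mimics mean-subtracted ensemble outputs of shape (ensemble, inputs, logits).
m = random.normal(random.PRNGKey(0), (5, 4, 3))

# Form used in the test: average outer product over ensemble and logit axes.
cov = jnp.einsum('ijk,ilk->jl', m, m) / (m.shape[0] * m.shape[-1])

# Equivalent explicit computation.
cov_loop = sum(
    jnp.outer(m[i, :, k], m[i, :, k])
    for i in range(m.shape[0])
    for k in range(m.shape[-1])
) / (m.shape[0] * m.shape[-1])

assert jnp.allclose(cov, cov_loop, atol=1e-5)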
import warnings from absl.testing import absltest @@ -204,11 +204,11 @@ def _get_f_and_eqn(params, primitive, *inputs): else: if primitive is lax.pad_p: - # TODO(romann): find a way to call primitive.bind directly. + # TODO: find a way to call primitive.bind directly. f = lambda *inputs: lax.pad(*inputs, **params) elif primitive is lax.conv_general_dilated_p: - # TODO(romann): find a way to call primitive.bind directly. + # TODO: find a way to call primitive.bind directly. f = lambda *inputs: lax.conv_general_dilated(*inputs, **params) else: @@ -486,7 +486,7 @@ def _compare_jacobians(self, j_fwd, j_rev, j_rule, primitive): def _test_primitive( self, - primitive: Optional[Primitive], + primitive: Primitive | None, shapes, dtype, params @@ -573,7 +573,7 @@ def _test_primitive( for primitive in _UNARY_PRIMITIVES.keys() for params in _UNARY_PRIMITIVES[primitive](shape, dtype) ) - def test_unary(self, primitive: Optional[Primitive], shape, dtype, params): + def test_unary(self, primitive: Primitive | None, shape, dtype, params): if primitive == lax.device_put_p: # Can't instantiate devices at test generation time; using subtests. devices = [None] + jax.devices() + jax.local_devices(backend='cpu') @@ -602,13 +602,13 @@ def test_unary(self, primitive: Optional[Primitive], shape, dtype, params): ) def test_binary( self, - primitive: Optional[Primitive], + primitive: Primitive | None, shape1, shape2, dtype, params ): - # TODO(romann): revisit when bugs below are fixed. + # TODO: revisit when bugs below are fixed. if primitive == lax.conv_general_dilated_p: if jax.default_backend() == 'tpu': raise absltest.SkipTest('http://b/235167364') @@ -633,7 +633,7 @@ def test_binary( for primitive in _N_ARY_PRIMITIVES.keys() for params in _N_ARY_PRIMITIVES[primitive](*shapes) ) - def test_n_ary(self, primitive: Optional[Primitive], shapes, dtype, params): + def test_n_ary(self, primitive: Primitive | None, shapes, dtype, params): self._test_primitive(primitive, shapes, dtype, params) diff --git a/tests/stax/elementwise_test.py b/tests/stax/elementwise_test.py index 41565be5..2916d42b 100644 --- a/tests/stax/elementwise_test.py +++ b/tests/stax/elementwise_test.py @@ -143,7 +143,7 @@ def kernel_fn(kernels, **kwargs): cov2 = jnp.reshape(cov2, (1, cov2.shape[0])) nngp = kernels.nngp - # TODO(schsam): Update cov1 and cov2 if we want to compose this kernel + # TODO: Update cov1 and cov2 if we want to compose this kernel # with other kernels. return kernels.replace( nngp=jnp.exp(-input_dim * gamma * (cov1 + cov2 - 2 * nngp))) @@ -697,7 +697,7 @@ def d2k(x1, x2): def assert_close(x, y, tol=3e-5): if default_backend() == 'tpu': - # TODO(romann): understand why TPUs have high errors. + # TODO: understand why TPUs have high errors. 
tol = 0.21 self.assertLess( jnp.max(jnp.abs(x - y)) / (jnp.mean(jnp.abs(x)) + jnp.mean(jnp.abs(y))), diff --git a/tests/stax/linear_test.py b/tests/stax/linear_test.py index eab5a046..017b4029 100644 --- a/tests/stax/linear_test.py +++ b/tests/stax/linear_test.py @@ -46,7 +46,7 @@ @test_utils.product( - same_inputs=[True, False] + same_inputs=[True, False], ) class FlattenTest(test_utils.NeuralTangentsTestCase): diff --git a/tests/stax/requirements_test.py b/tests/stax/requirements_test.py index d22db8ae..1afc22f3 100644 --- a/tests/stax/requirements_test.py +++ b/tests/stax/requirements_test.py @@ -50,7 +50,7 @@ stax.Flatten(), stax.GlobalAvgPool(), stax.Identity() - ] + ], ) class DiagonalTest(test_utils.NeuralTangentsTestCase): @@ -132,7 +132,7 @@ def test_diagonal_compose_is_associative(self): @test_utils.product( - same_inputs=[True, False] + same_inputs=[True, False], ) class InputReqTest(test_utils.NeuralTangentsTestCase): diff --git a/tests/stax/stax_test.py b/tests/stax/stax_test.py index fba80e0d..b71f8fd5 100644 --- a/tests/stax/stax_test.py +++ b/tests/stax/stax_test.py @@ -87,14 +87,14 @@ def _get_inputs( key, same_inputs, shape, - fn=jnp.cos -) -> tuple[jnp.ndarray, jnp.ndarray]: + fn=jnp.cos, +) -> tuple[jnp.ndarray, jnp.ndarray | None]: key, split = random.split(key) x1 = fn(random.normal(key, shape)) batch_axis = shape.index(BATCH_SIZE) shape = shape[:batch_axis] + (2 * BATCH_SIZE,) + shape[batch_axis + 1:] x2 = None if same_inputs else fn(random.normal(split, shape)) * 2 - return x1, x2 # pytype: disable=bad-return-type # jax-ndarray + return x1, x2 def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding, @@ -300,7 +300,7 @@ def _check_agreement_with_empirical( use_dropout, is_ntk, rtol=RTOL, - atol=ATOL + atol=ATOL, ): ((init_fn, apply_fn, kernel_fn), input_shape, device_count, channel_axis) = net @@ -595,7 +595,7 @@ def test_sparse_inputs(self, act, kernel, do_stabilize): input_size = 3 width = 1024 - # NOTE(schsam): It seems that convergence is slower when inputs are sparse. + # NOTE: It seems that convergence is slower when inputs are sparse. samples = N_SAMPLES if default_backend() == 'gpu': diff --git a/tests/test_utils.py b/tests/test_utils.py index 49e592c0..675682fa 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -19,7 +19,7 @@ import logging import os from types import ModuleType -from typing import Callable, Optional, Sequence +from typing import Callable, Sequence from absl import flags from absl.testing import parameterized @@ -36,13 +36,13 @@ 'nt_test_dut', '', help= - 'Describes the device under test in case special consideration is required.' 
+ 'Describes the device under test in case special consideration is required.', ) flags.DEFINE_integer( 'nt_num_generated_cases', int(os.getenv('NT_NUM_GENERATED_CASES', '4')), - help='Number of generated cases to test' + help='Number of generated cases to test', ) FLAGS = flags.FLAGS @@ -103,9 +103,9 @@ def _default_tolerance() -> dict[np.dtype, float]: def _assert_numpy_allclose( a: np.ndarray, b: np.ndarray, - atol: Optional[float] = None, - rtol: Optional[float] = None, - err_msg: str = '' + atol: float | None = None, + rtol: float | None = None, + err_msg: str = '', ): if a.dtype == b.dtype == _dtypes.float0: np.testing.assert_array_equal(a, b, err_msg=err_msg) @@ -121,7 +121,7 @@ def _assert_numpy_allclose( np.testing.assert_allclose(a, b, **kw, err_msg=err_msg) -def _tolerance(dtype: np.dtype, tol: Optional[float] = None) -> float: +def _tolerance(dtype: np.dtype, tol: float | None = None) -> float: tol = {} if tol is None else tol if not isinstance(tol, dict): return tol @@ -239,8 +239,8 @@ def _assertAllClose( y, *, check_dtypes: bool = True, - atol: Optional[float] = None, - rtol: Optional[float] = None, + atol: float | None = None, + rtol: float | None = None, canonicalize_dtypes: bool = True, err_msg: str = '' ): @@ -278,8 +278,8 @@ def assertArraysAllClose( y, *, check_dtypes: bool = True, - atol: Optional[float] = None, - rtol: Optional[float] = None, + atol: float | None = None, + rtol: float | None = None, err_msg: str = '' ): """Assert that x and y are close (up to numerical tolerances).""" @@ -304,8 +304,8 @@ def assertAllClose( y, *, check_dtypes: bool = True, - atol: Optional[float] = None, - rtol: Optional[float] = None, + atol: float | None = None, + rtol: float | None = None, canonicalize_dtypes: bool = True, check_finite: bool = True, err_msg=''): @@ -313,8 +313,8 @@ def assertAllClose( def is_finite(x): self.assertTrue(jnp.all(jnp.isfinite(x))) - jax.tree_map(is_finite, x) - jax.tree_map(is_finite, y) + jax.tree.map(is_finite, x) + jax.tree.map(is_finite, y) def assert_close(x, y): self._assertAllClose( @@ -353,7 +353,7 @@ def update_test_tolerance(f32_tol: float = 5e-3, f64_tol: float = 1e-5): def stub_out_pmap(batch: ModuleType, count: int): - # If we are using GPU or CPU stub out pmap with vmap to simulate multi-core. + # If we are using GPU or CPU stub out pmap with vmap to simulate multicore. if count > 0: class xla_bridge_stub: @@ -371,7 +371,7 @@ def _log( absolute_error: float, expected, actual, - did_pass: bool + did_pass: bool, ): msg = 'PASSED' if did_pass else 'FAILED' logging.info(f'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX\n' @@ -399,7 +399,7 @@ def assert_close(expected, actual): if (jnp.isnan(relative_error) or relative_error > rtol or absolute_error > atol): - _log(relative_error, absolute_error, expected, actual, False) # pytype: disable=wrong-arg-types # jnp-type + _log(relative_error, absolute_error, expected, actual, False) self.fail(self.failureException('Relative ERROR: ', float(relative_error), 'EXPECTED:' + ' ' * 50, @@ -410,15 +410,15 @@ def assert_close(expected, actual): 'Absolute ERROR: ', float(absolute_error))) else: - _log(relative_error, absolute_error, expected, actual, True) # pytype: disable=wrong-arg-types # jnp-type + _log(relative_error, absolute_error, expected, actual, True) - jax.tree_map(assert_close, expected, actual) + jax.tree.map(assert_close, expected, actual) def skip_test( self, msg: str = 'Skipping large tests for speed.', - platforms: tuple[str, ...] = ('cpu',) + platforms: tuple[str, ...] 
= ('cpu',), ): if jax.default_backend() in platforms: raise parameterized.TestCase.skipTest(self, msg) @@ -426,10 +426,10 @@ def skip_test( def mask( x: jnp.ndarray, - mask_constant: Optional[float], + mask_constant: float | None, mask_axis: Sequence[int], key: jax.Array, - p: float + p: float, ) -> jnp.ndarray: if mask_constant is not None: mask_shape = [1 if i in mask_axis else s