From 1372a5626d1c062d64e9a14ad009d7abee6e7c5c Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 17 Jan 2020 16:58:12 -0500 Subject: [PATCH] Minimal changes to make Jax pass a pytype check. --- jax/BUILD | 22 +++++++++++++++------- jax/ad_util.py | 3 --- jax/api.py | 21 +++++++++------------ jax/core.py | 2 +- jax/experimental/loops.py | 4 +++- jax/interpreters/ad.py | 4 +++- jax/interpreters/masking.py | 4 ++-- jax/interpreters/pxla.py | 25 ++++++++++++------------- jax/interpreters/xla.py | 9 ++++++--- jax/lax/lax.py | 9 +++++---- jax/lax/lax_control_flow.py | 12 +++++++----- jax/lax/lax_parallel.py | 12 +++++++----- jax/lazy.py | 5 ++++- jax/lib/__init__.py | 2 +- jax/linear_util.py | 4 ++-- jax/numpy/lax_numpy.py | 1 + jax/numpy/linalg.py | 5 +++-- tests/api_test.py | 6 +++--- tests/batching_test.py | 1 + tests/lax_numpy_test.py | 9 ++++++--- tests/lax_test.py | 11 ++++++++--- tests/masking_test.py | 2 +- tests/multibackend_test.py | 10 +++++----- tests/parallel_test.py | 2 +- tests/random_test.py | 1 + 25 files changed, 107 insertions(+), 79 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 4485599bd650..3e3c7f1e97b2 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -14,13 +14,15 @@ # JAX is Autograd and XLA +pytype_library = py_library + licenses(["notice"]) package(default_visibility = ["//visibility:public"]) # top-level EF placeholder -py_library( +pytype_library( name = "jax", srcs = glob( [ @@ -43,38 +45,44 @@ py_library( deps = ["@org_tensorflow//tensorflow/compiler/xla/python:xla_client"], ) -py_library( +pytype_library( name = "stax", srcs = ["experimental/stax.py"], + srcs_version = "PY3", deps = [":jax"], ) -py_library( +pytype_library( name = "optimizers", srcs = ["experimental/optimizers.py"], + srcs_version = "PY3", deps = [":jax"], ) -py_library( +pytype_library( name = "optix", srcs = ["experimental/optix.py"], + srcs_version = "PY3", deps = [":jax"], ) -py_library( +pytype_library( name = "ode", srcs = ["experimental/ode.py"], + srcs_version = "PY3", deps = [":jax"], ) -py_library( +pytype_library( name = "vectorize", srcs = ["experimental/vectorize.py"], + srcs_version = "PY3", deps = [":jax"], ) -py_library( +pytype_library( name = "loops", srcs = ["experimental/loops.py"], + srcs_version = "PY3", deps = [":jax"], ) diff --git a/jax/ad_util.py b/jax/ad_util.py index a98d102aa57c..00b9230aa3af 100644 --- a/jax/ad_util.py +++ b/jax/ad_util.py @@ -39,9 +39,6 @@ def add_impl(xs, ys): def add_abstract(xs, ys): return lattice_join(xs, ys) -def zeros_like_impl_jaxtuple(xs): - return JaxTuple(map(zeros_like_impl, xs)) - jaxval_zeros_likers = {} def zeros_like_aval(aval): diff --git a/jax/api.py b/jax/api.py index 7250119825ba..c52692618bcd 100644 --- a/jax/api.py +++ b/jax/api.py @@ -938,7 +938,7 @@ def f_pmapped(*args, **kwargs): msg = ("soft_pmap mapped axis size must be divisble by the number of " "XLA devices (or be less than or equal to that number), but got " "an axis size of {} with {} devices.") - raise ValueError(msg.format(axis_size, pxla.pxla.unmapped_device_count())) + raise ValueError(msg.format(axis_size, pxla.unmapped_device_count())) num_chunks = axis_size // chunk_size reshaped_args = [_reshape_split(num_chunks, x) for x in args_flat] @@ -1922,15 +1922,12 @@ def jaxpr_to_graphviz(jaxpr, consts): fragment.extend(map(constant_node, jaxpr.constvars, consts)) for eqn in jaxpr.eqns: - if eqn.destructure: - id_name = next(id_names) - fragment.append(function_node(id_name, eqn.primitive.name)) - fragment.extend(edge(invar, id_name) for invar in eqn.invars) - 
fragment.extend(edge(id_name, outvar) for outvar in eqn.outvars) - else: - fragment.append(function_node(eqn.outvars[0], eqn.primitive.name)) - fragment.extend(edge(invar, eqn.outvars[0]) for invar in eqn.invars) - fragment.append(outvar_node(jaxpr.outvar, "out")) + id_name = next(id_names) + fragment.append(function_node(id_name, eqn.primitive.name)) + fragment.extend(edge(invar, id_name) for invar in eqn.invars) + fragment.extend(edge(id_name, outvar) for outvar in eqn.outvars) + for ov in jaxpr.outvars: + fragment.append(outvar_node(ov, "out")) return graph(''.join(fragment)) edge = '{} -> {} [color=gray30];\n'.format @@ -1944,8 +1941,8 @@ def jaxpr_to_graphviz(jaxpr, consts): @wraps(fun) def graphviz_maker(*args, **kwargs): wrapped = lu.wrap_init(fun, kwargs) - jax_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args)) - jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(wrapped, in_trees) + jax_args, in_tree = tree_flatten((args, kwargs)) + jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree) pvals = map(pv_like, jax_args) jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals) return jaxpr_to_graphviz(jaxpr, consts) diff --git a/jax/core.py b/jax/core.py index 1f237d370c80..93d755222552 100644 --- a/jax/core.py +++ b/jax/core.py @@ -136,7 +136,7 @@ def __eq__(self, other): def __repr__(self): if self.hash is None: - return 'Literal(val={}, hashable={})'.format(self.val, self.hashable) + return 'Literal(val={})'.format(self.val) else: return '{}'.format(self.val) diff --git a/jax/experimental/loops.py b/jax/experimental/loops.py index 1f2df920523a..eefffbbf1807 100644 --- a/jax/experimental/loops.py +++ b/jax/experimental/loops.py @@ -112,6 +112,7 @@ def loop_body(i, acc_arr): import itertools import numpy as onp import traceback +from typing import Any, List, cast from jax import abstract_arrays from jax import lax, core @@ -287,7 +288,8 @@ def __init__(self, scope, loop_builder): self.loop_builder = loop_builder self.first_iteration = True # If we are tracing the first iteration # Stack trace, without this line and the s.range function - self.stack = traceback.StackSummary.from_list(traceback.extract_stack()[:-2]) + self.stack = traceback.StackSummary.from_list( + cast(List[Any], traceback.extract_stack()[:-2])) # Next are state kept from the start of the first iteration to the end of the iteration. self.carried_state_initial = {} diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 0c7016dba3e2..d83dd51aefa8 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -18,6 +18,7 @@ import functools import itertools as it +from typing import Any from . import partial_eval as pe from .. 
import core as core @@ -36,13 +37,14 @@ def identity(x): return x -def jvp(fun, has_aux=False, instantiate=True): +def jvp(fun, has_aux=False, instantiate=True) -> Any: if not has_aux: return jvpfun(jvp_subtrace(fun), instantiate) else: fun, aux = jvp_subtrace_aux(fun) return jvpfun(fun, instantiate), aux + @lu.transformation def jvpfun(instantiate, primals, tangents): with new_master(JVPTrace) as master: diff --git a/jax/interpreters/masking.py b/jax/interpreters/masking.py index 366ca73a9b1e..d8c6de5b64af 100644 --- a/jax/interpreters/masking.py +++ b/jax/interpreters/masking.py @@ -142,11 +142,11 @@ def __rsub__(self, other): return self + -other def __floordiv__(self, divisor): - q, _ = divmod(self, divisor) + q, _ = divmod(self, divisor) # pytype: disable=wrong-arg-types return q def __mod__(self, divisor): - _, r = divmod(self, divisor) + _, r = divmod(self, divisor) # pytype: disable=wrong-arg-types return r def __divmod__(self, divisor): diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index eba0341a99d8..a29ce07b7748 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -286,19 +286,7 @@ def _axis_index_translation_rule(c, nreps, sizes, soft_size, axis_name): ### lazy device-memory persistence and result handling -class ShardedDeviceValue(xla.DeviceValue): - def _check_if_deleted(self): - if self.device_buffers is None: - raise ValueError("ShardedDeviceValue has been deleted.") - - def block_until_ready(self): - self._check_if_deleted() - for buf in self.device_buffers: - buf.block_host_until_ready() - return self - - -class ShardedDeviceArray(ShardedDeviceValue, xla.DeviceArray): +class ShardedDeviceArray(xla.DeviceArray): """A ShardedDeviceArray is an ndarray sharded across devices. The purpose of a ShardedDeviceArray is to reduce the number of transfers when @@ -346,6 +334,16 @@ def delete(self): self.device_buffers = None self._npy_value = None + def _check_if_deleted(self): + if self.device_buffers is None: + raise ValueError("ShardedDeviceArray has been deleted.") + + def block_until_ready(self): + self._check_if_deleted() + for buf in self.device_buffers: + buf.block_host_until_ready() + return self + @property def _value(self): if self._npy_value is None: @@ -757,6 +755,7 @@ def aval(self): if self.axis_name is not_mapped: return aval else: + assert isinstance(aval, ShapedArray) return ShapedArray(aval.shape[1:], aval.dtype) def full_lower(self): diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index c5801c475a3e..fa212ada1f3e 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -825,9 +825,9 @@ def __array__(self, dtype=None, context=None): __str__ = partialmethod(_forward_to_value, str) __bool__ = __nonzero__ = partialmethod(_forward_to_value, bool) - __float__ = partialmethod(_forward_to_value, float) - __int__ = partialmethod(_forward_to_value, int) - __complex__ = partialmethod(_forward_to_value, complex) + def __float__(self): return self._value.__float__() + def __int__(self): return self._value.__int__() + def __complex__(self): return self._value.__complex__() __hex__ = partialmethod(_forward_to_value, hex) __oct__ = partialmethod(_forward_to_value, oct) __index__ = partialmethod(_forward_to_value, op.index) @@ -841,6 +841,9 @@ def __eq__(self, other): return self._value == other def __hash__(self): raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.") + # The following methods are dynamically overridden in lax_numpy.py. 
+ def __getitem__(self, i): raise NotImplementedError + class DeletedBuffer(object): pass deleted_buffer = DeletedBuffer() diff --git a/jax/lax/lax.py b/jax/lax/lax.py index 6b7e537fa065..61d8096ff0d2 100644 --- a/jax/lax/lax.py +++ b/jax/lax/lax.py @@ -23,6 +23,7 @@ import itertools import operator import string +from typing import Any import warnings import numpy as onp @@ -678,7 +679,7 @@ def select(pred, on_true, on_false): """ return select_p.bind(pred, on_true, on_false) -def slice(operand, start_indices, limit_indices, strides=None): +def slice(operand: Any, start_indices, limit_indices, strides=None): """Wraps XLA's `Slice `_ operator. @@ -4180,8 +4181,8 @@ def infeed(token, shape=None): flat_shapes, treedef = pytree.flatten(shape) for shape in flat_shapes: if not isinstance(shape, ShapedArray): - raise TypeError("shapes argument to infeed must be a pytree of " - "ShapedArray values, got {}".format(shapes)) + raise TypeError("shape argument to infeed must be a pytree of " + "ShapedArray values, got {}".format(shape)) xs_and_token = infeed_p.bind(token, shapes=tuple(flat_shapes)) return (treedef.unflatten(xs_and_token[:-1]), xs_and_token[-1]) @@ -4387,7 +4388,7 @@ def _dynamic_slice_indices(operand, start_indices): if len(start_indices) != operand.ndim: msg = ("Length of slice indices must match number of operand dimensions ({} " "vs {})") - raise ValueError(msg.format(len(start_indices, operand.shape))) + raise ValueError(msg.format(len(start_indices), operand.shape)) # map int over operand.shape to raise any dynamic-shape errors return [select(lt(i, _const(i, 0)), add(i, _const(i, int(d))), i) for i, d in zip(start_indices, operand.shape)] diff --git a/jax/lax/lax_control_flow.py b/jax/lax/lax_control_flow.py index 645367c83c68..656fde235c46 100644 --- a/jax/lax/lax_control_flow.py +++ b/jax/lax/lax_control_flow.py @@ -362,8 +362,9 @@ def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr=body_jvp_rearranged) out_carry, out_carry_dot = split_list(out, [num_carry]) - out_tangents = iter(out_carry_dot) - out_tangents = [next(out_tangents) if nz else ad_util.zero for nz in nonzeros_out] + out_tangents_iter = iter(out_carry_dot) + out_tangents = [next(out_tangents_iter) if nz else ad_util.zero + for nz in nonzeros_out] return out_carry, out_tangents while_p = lax.Primitive('while') @@ -701,8 +702,9 @@ def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts, num_carry, carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys]) primals_out = carry + ys - tangents_out = iter(carry_dot + ys_dot) - tangents_out = [next(tangents_out) if nz else ad_util.zero for nz in nonzeros_out] + tangents_out_iter = iter(carry_dot + ys_dot) + tangents_out = [next(tangents_out_iter) if nz else ad_util.zero + for nz in nonzeros_out] return primals_out, tangents_out def _prune_zeros(ts): @@ -919,7 +921,7 @@ def _scan_shape_rule(shapes, forward, length, jaxpr, num_consts, num_carry, linear): const_shexprs, init_shexprs, xs_shexprs = split_list(shapes, [num_consts, num_carry]) _, y_avals = split_list(jaxpr.out_avals, [num_carry]) - ys_shapes = [tuple(length, *y_aval.shape) for y_aval in y_avals] + ys_shapes = [(length,) + tuple(y_aval.shape) for y_aval in y_avals] return init_shexprs + ys_shapes def _scan_masking_rule(shape_envs, padded_vals, shape_exprs, forward, length, diff --git a/jax/lax/lax_parallel.py b/jax/lax/lax_parallel.py index 0aeb19524edf..19076e3a6243 100644 --- a/jax/lax/lax_parallel.py +++ b/jax/lax/lax_parallel.py @@ 
-619,12 +619,14 @@ def _add_jaxvals_papply_rule(name, size, vals, dims): xdim, ydim = dims if xdim == ydim: out_dim = xdim - elif ydim is None: - y = lax.psplit_like(y, x, name) - out_dim = xdim else: - x = lax.psplit_like(x, y, name) - out_dim = ydim + raise NotImplementedError + # elif ydim is None: + # y = lax.psplit_like(y, x, name) + # out_dim = xdim + # else: + # x = lax.psplit_like(x, y, name) + # out_dim = ydim return ad_util.add_jaxvals_p.bind(x, y), out_dim diff --git a/jax/lazy.py b/jax/lazy.py index 8115c7bc0d24..2dca052d5272 100644 --- a/jax/lazy.py +++ b/jax/lazy.py @@ -19,6 +19,7 @@ from collections import namedtuple import functools import operator as op +from typing import Any, Callable import numpy as onp @@ -32,7 +33,7 @@ ### util # TODO(mattjj): replace with dataclass when Python 2 support is removed -def taggedtuple(name, fields): +def taggedtuple(name, fields) -> Callable[..., Any]: """Lightweight version of namedtuple where equality depends on the type.""" def __new__(cls, *xs): return tuple.__new__(cls, (cls,) + xs) @@ -99,12 +100,14 @@ def __str__(self): # hash(A(1, 2)) == hash(B(1, 2)) # True # but we want hashes to be sensitive to the type tag (while still being fast). +# pytype: disable=wrong-arg-count LazyExpr = namedtuple('LazyExpr', ['input', 'shape', 'dims']) ArrayVar = taggedtuple('ArrayVar', []) Iota = taggedtuple('Iota', ['dtype', 'size']) # like np.arange(N) Eye = taggedtuple('Eye', ['dtype', 'shape', 'offset']) # like np.eye Tri = taggedtuple('Tri', ['dtype', 'shape', 'offset']) # like np.tri Delta = taggedtuple('Delta', ['dtype', 'shape']) # kronecker delta arrays +# pytype: enable=wrong-arg-count def array(shape): return LazyExpr(ArrayVar(), shape, tuple(range(len(shape)))) diff --git a/jax/lib/__init__.py b/jax/lib/__init__.py index 55f97888aa55..26939002fe8e 100644 --- a/jax/lib/__init__.py +++ b/jax/lib/__init__.py @@ -45,7 +45,7 @@ def _check_jaxlib_version(): try: - from jaxlib import tpu_client + from jaxlib import tpu_client # pytype: disable=import-error except: tpu_client = None from jaxlib import xla_client diff --git a/jax/linear_util.py b/jax/linear_util.py index 3be976214014..b7f191c9ec3a 100644 --- a/jax/linear_util.py +++ b/jax/linear_util.py @@ -175,7 +175,7 @@ def __eq__(self, other): self.params == other.params) @curry -def transformation(gen, fun, *gen_static_args): +def transformation(gen, fun: WrappedFun, *gen_static_args): """Adds one more transformation to a WrappedFun. Args: gen: the transformation generator function @@ -185,7 +185,7 @@ def transformation(gen, fun, *gen_static_args): return fun.wrap(gen, gen_static_args, None) @curry -def transformation_with_aux(gen, fun, *gen_static_args): +def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args): """Adds one more transformation with auxiliary output to a WrappedFun.""" out_store = Store() out_thunk = lambda: out_store.val diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 0b5ae04125c3..b51273982af9 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +# pytype: skip-file """ Implements the NumPy API, using the primitives in :mod:`jax.lax`. diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py index f4fc5070df9a..99f3c5df3c7c 100644 --- a/jax/numpy/linalg.py +++ b/jax/numpy/linalg.py @@ -21,6 +21,7 @@ import numpy as onp import warnings import textwrap +from typing import Tuple, Union, cast from jax import jit from .. 
import lax @@ -173,7 +174,7 @@ def inv(a): @partial(jit, static_argnums=(1, 2, 3)) -def _norm(x, ord, axis, keepdims): +def _norm(x, ord, axis: Union[None, Tuple[int, ...], int], keepdims): x = _promote_arg_dtypes(np.asarray(x)) x_shape = np.shape(x) ndim = len(x_shape) @@ -214,7 +215,7 @@ def _norm(x, ord, axis, keepdims): return np.power(out, 1. / ord) elif num_axes == 2: - row_axis, col_axis = axis + row_axis, col_axis = cast(Tuple[int, ...], axis) if ord is None or ord in ('f', 'fro'): return np.sqrt(np.sum(np.real(x * np.conj(x)), axis=axis, keepdims=keepdims)) diff --git a/tests/api_test.py b/tests/api_test.py index 38351853ba42..39f53f81a049 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -527,11 +527,11 @@ def f(x, y): return x + y self.assertRaisesRegex( TypeError, "primal and tangent arguments to jax.jvp must be tuples or lists; found float and tuple.", - lambda: partial(api.jvp(f, 0., (1.,)))) + lambda: api.jvp(f, 0., (1.,))) self.assertRaisesRegex( TypeError, "primal and tangent arguments to jax.jvp must be tuples or lists; found tuple and ndarray.", - lambda: partial(api.jvp(f, (0.,), onp.array([1., 2.])))) + lambda: api.jvp(f, (0.,), onp.array([1., 2.]))) def test_vjp_mismatched_arguments(self): _, pullback = api.vjp(lambda x, y: x * y, onp.float32(3), onp.float32(4)) @@ -1712,7 +1712,7 @@ def make_computation_builder_and_count(*args, **kwargs): @jtu.skip_on_devices("tpu") def test_lazy_jit_closed_over_values(self): if not core.skip_checks: - raise SkipTest("oom test skipped when core.skip_checks is False") + raise unittest.SkipTest("oom test skipped when core.skip_checks is False") y = np.arange(int(1e12)) # will likely oom if materialized ans = jit(lambda x: (x + y)[1])(1) diff --git a/tests/batching_test.py b/tests/batching_test.py index 9a8fd0d2e71e..0d73bd02c522 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -20,6 +20,7 @@ from absl.testing import absltest from absl.testing import parameterized +import jax import jax.numpy as np from jax import test_util as jtu from jax.abstract_arrays import ShapedArray diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 7fc1825b318a..9dbe3bdcc5db 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -21,6 +21,7 @@ from functools import partial import itertools import operator +from typing import cast, Optional import unittest from unittest import SkipTest import warnings @@ -493,7 +494,7 @@ def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol): def testRightOperatorOverload(self, name, rng_factory, shapes, dtypes, op_tolerance): if shapes[1] is jtu.PYTHON_SCALAR_SHAPE: - raise SkipTest() # TODO(mattjj): clean up + raise SkipTest("scalars not implemented") # TODO(mattjj): clean up rng = rng_factory() args_maker = self._GetArgsMaker(rng, shapes, dtypes, onp_arrays=False) fun = lambda fst, snd: getattr(snd, name)(fst) @@ -1815,7 +1816,8 @@ def testRoll(self, shape, dtype, shifts, axis, rng_factory): "index_dtype": index_dtype, "axis": axis, "mode": mode} for shape in [(3,), (3, 4), (3, 4, 5)] for index_shape in scalar_shapes + [(3,), (2, 1, 3)] - for axis in itertools.chain(range(-len(shape), len(shape)), [None]) + for axis in itertools.chain(range(-len(shape), len(shape)), + [cast(Optional[int], None)]) for dtype in all_dtypes for index_dtype in int_dtypes for mode in ['wrap', 'clip'] @@ -1844,7 +1846,8 @@ def args_maker(): _shapes_are_equal_length, filter(_shapes_are_broadcast_compatible, CombosWithReplacement(nonempty_nonscalar_array_shapes, 2))) - for 
axis in itertools.chain(range(len(x_shape)), [-1], [None]) + for axis in itertools.chain(range(len(x_shape)), [-1], + [cast(Optional[int], None)]) for dtype in default_dtypes for rng_factory in [jtu.rand_default])) def testTakeAlongAxis(self, x_shape, i_shape, dtype, axis, rng_factory): diff --git a/tests/lax_test.py b/tests/lax_test.py index aadba64d203a..25ff95f5eb55 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -20,6 +20,7 @@ import functools from functools import partial import itertools +from typing import Optional, cast from unittest import skip, SkipTest from absl.testing import absltest @@ -28,6 +29,7 @@ import numpy as onp import numpy.random as npr +import jax from jax import api from jax import core from jax import dtypes @@ -2524,7 +2526,8 @@ def testRemainder(self): def all_bdims(*shapes): - bdims = (itertools.chain([None], range(len(shape) + 1)) for shape in shapes) + bdims = (itertools.chain([cast(Optional[int], None)], + range(len(shape) + 1)) for shape in shapes) return (t for t in itertools.product(*bdims) if not all(e is None for e in t)) def add_bdim(bdim_size, bdim, shape): @@ -2603,8 +2606,10 @@ def testOp(self, op_name, rng_factory, shapes, dtype, bdims): (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])), (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])), (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))] - for lhs_bdim in itertools.chain([None], range(len(lhs_shape) + 1)) - for rhs_bdim in itertools.chain([None], range(len(rhs_shape) + 1)) + for lhs_bdim in itertools.chain([cast(Optional[int], None)], + range(len(lhs_shape) + 1)) + for rhs_bdim in itertools.chain([cast(Optional[int], None)], + range(len(rhs_shape) + 1)) if (lhs_bdim, rhs_bdim) != (None, None) for rng_factory in [jtu.rand_default] )) diff --git a/tests/masking_test.py b/tests/masking_test.py index dfa9a6ddd297..ee920d2908b6 100644 --- a/tests/masking_test.py +++ b/tests/masking_test.py @@ -437,7 +437,7 @@ def step(h, x): ans = grad(lambda W: vmap(rnn, ((None, 0, 0), 0))((W, seqs, ys), dict(t=ts)).sum())(W) def rnn_reference(W, seqs, targets): - total_loss = 0 + total_loss = np.array(0, np.float_) for xs, target in zip(seqs, targets): h = np.zeros(n) for x in xs: diff --git a/tests/multibackend_test.py b/tests/multibackend_test.py index 7fc2aacb2834..81b0c3328dad 100644 --- a/tests/multibackend_test.py +++ b/tests/multibackend_test.py @@ -46,7 +46,7 @@ class MultiBackendTest(jtu.JaxTestCase): )) def testMultiBackend(self, backend): if backend not in ('cpu', jtu.device_under_test(), None): - raise SkipTest() + raise SkipTest("Backend is not CPU or the device under test") @partial(api.jit, backend=backend) def fun(x, y): return np.matmul(x, y) @@ -65,7 +65,7 @@ def fun(x, y): def testMultiBackendNestedJit(self, ordering): outer, inner = ordering if outer not in ('cpu', jtu.device_under_test(), None): - raise SkipTest() + raise SkipTest("Backend is not CPU or the device under test") @partial(api.jit, backend=outer) def fun(x, y): @partial(api.jit, backend=inner) @@ -91,9 +91,9 @@ def infun(x, y): def testMultiBackendNestedJitConflict(self, ordering): outer, inner = ordering if outer not in ('cpu', jtu.device_under_test(), None): - raise SkipTest() + raise SkipTest("Backend is not CPU or the device under test") if inner not in ('cpu', jtu.device_under_test(), None): - raise SkipTest() + raise SkipTest("Backend is not CPU or the device under test") @partial(api.jit, backend=outer) def fun(x, y): @partial(api.jit, backend=inner) @@ -111,7 +111,7 @@ def infun(x, y): )) def 
testGpuMultiBackendOpByOpReturn(self, backend): if backend not in ('cpu', jtu.device_under_test()): - raise SkipTest() + raise SkipTest("Backend is not CPU or the device under test") @partial(api.jit, backend=backend) def fun(x, y): return np.matmul(x, y) diff --git a/tests/parallel_test.py b/tests/parallel_test.py index d8414f49755c..449bcb0bf65f 100644 --- a/tests/parallel_test.py +++ b/tests/parallel_test.py @@ -309,7 +309,7 @@ def f(x, y): pfun, axis_name = _papply(fun) ans = soft_pmap(pfun, axis_name)(x, y) except (NotImplementedError, TypeError) as e: - raise SkipTest(e) + raise SkipTest(str(e)) ans = self.dedup(ans, expected.ndim) self.assertAllClose(ans, expected, check_dtypes=False) diff --git a/tests/random_test.py b/tests/random_test.py index 4c3299b1170c..66b400e50fe4 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -23,6 +23,7 @@ from absl.testing import parameterized import numpy as onp +import scipy.linalg import scipy.special import scipy.stats
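
Note on the "# pytype: disable=wrong-arg-count" pragmas added in jax/lazy.py:
taggedtuple builds its result classes dynamically, so pytype evidently cannot
recover their constructor arity and reports spurious wrong-arg-count errors,
which the pragmas suppress. The sketch below is a minimal, self-contained
illustration of the pattern, not code from this commit: the type(...) return
line and the A/B usage are reconstructions, and the real taggedtuple also
installs __str__ and per-field property accessors.

    from typing import Any, Callable

    def taggedtuple(name, fields) -> Callable[..., Any]:
        """Lightweight version of namedtuple where equality depends on the type."""
        def __new__(cls, *xs):
            # Store the class itself as the first tuple element so that values
            # of different tagged types never compare (or hash) equal.
            return tuple.__new__(cls, (cls,) + xs)
        # Assumed construction; callers see only Callable[..., Any].
        return type(name, (tuple,), {'__new__': __new__})

    # pytype: disable=wrong-arg-count
    A = taggedtuple('A', ['x', 'y'])
    B = taggedtuple('B', ['x', 'y'])
    # pytype: enable=wrong-arg-count

    assert A(1, 2) == A(1, 2)   # same type tag: equal
    assert A(1, 2) != B(1, 2)   # different tag: unequal, unlike plain namedtuples
    # hash(A(1, 2)) != hash(B(1, 2)) as well: the tag participates in hashing,
    # which is exactly the property the lazy-expression comment asks for.

A similar pytype limitation appears to motivate the explicit __float__ /
__int__ / __complex__ definitions in jax/interpreters/xla.py: pytype checks
the signatures of those protocol methods and cannot see through
functools.partialmethod, whereas plain defs type-check directly.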