
Minimal changes to make Jax pass a pytype check. #2024


Merged: 1 commit, Jan 18, 2020
22 changes: 15 additions & 7 deletions jax/BUILD
@@ -14,13 +14,15 @@

# JAX is Autograd and XLA

+pytype_library = py_library

licenses(["notice"])

package(default_visibility = ["//visibility:public"])

# top-level EF placeholder

-py_library(
+pytype_library(
name = "jax",
srcs = glob(
[
@@ -43,38 +45,44 @@ py_library(
deps = ["@org_tensorflow//tensorflow/compiler/xla/python:xla_client"],
)

-py_library(
+pytype_library(
name = "stax",
srcs = ["experimental/stax.py"],
+srcs_version = "PY3",
deps = [":jax"],
)

-py_library(
+pytype_library(
name = "optimizers",
srcs = ["experimental/optimizers.py"],
+srcs_version = "PY3",
deps = [":jax"],
)

-py_library(
+pytype_library(
name = "optix",
srcs = ["experimental/optix.py"],
+srcs_version = "PY3",
deps = [":jax"],
)

-py_library(
+pytype_library(
name = "ode",
srcs = ["experimental/ode.py"],
+srcs_version = "PY3",
deps = [":jax"],
)

-py_library(
+pytype_library(
name = "vectorize",
srcs = ["experimental/vectorize.py"],
+srcs_version = "PY3",
deps = [":jax"],
)

-py_library(
+pytype_library(
name = "loops",
srcs = ["experimental/loops.py"],
+srcs_version = "PY3",
deps = [":jax"],
)
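Note for readers of the open-source tree: `pytype_library` is an internal Bazel macro that behaves like `py_library` but additionally runs pytype over its `srcs`; it has no open-source counterpart, so the alias added at the top of the file turns every `pytype_library()` call below into a plain `py_library()`. A hypothetical sketch of the internal macro's shape, assuming a `_pytype_check` helper rule that does not exist externally:

```python
# Hypothetical .bzl sketch (Starlark), not the real internal rule.
def pytype_library(name, srcs, deps = [], **kwargs):
    # The ordinary Python library target.
    native.py_library(name = name, srcs = srcs, deps = deps, **kwargs)
    # Plus a type-check action over the same sources.
    _pytype_check(name = name + ".pytype", srcs = srcs, deps = deps)
```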
3 changes: 0 additions & 3 deletions jax/ad_util.py
@@ -39,9 +39,6 @@ def add_impl(xs, ys):
def add_abstract(xs, ys):
return lattice_join(xs, ys)

-def zeros_like_impl_jaxtuple(xs):
-return JaxTuple(map(zeros_like_impl, xs))

jaxval_zeros_likers = {}

def zeros_like_aval(aval):
21 changes: 9 additions & 12 deletions jax/api.py
@@ -938,7 +938,7 @@ def f_pmapped(*args, **kwargs):
msg = ("soft_pmap mapped axis size must be divisble by the number of "
"XLA devices (or be less than or equal to that number), but got "
"an axis size of {} with {} devices.")
-raise ValueError(msg.format(axis_size, pxla.pxla.unmapped_device_count()))
+raise ValueError(msg.format(axis_size, pxla.unmapped_device_count()))
num_chunks = axis_size // chunk_size

reshaped_args = [_reshape_split(num_chunks, x) for x in args_flat]
@@ -1922,15 +1922,12 @@ def jaxpr_to_graphviz(jaxpr, consts):
fragment.extend(map(constant_node, jaxpr.constvars, consts))

for eqn in jaxpr.eqns:
-if eqn.destructure:
-id_name = next(id_names)
-fragment.append(function_node(id_name, eqn.primitive.name))
-fragment.extend(edge(invar, id_name) for invar in eqn.invars)
-fragment.extend(edge(id_name, outvar) for outvar in eqn.outvars)
-else:
-fragment.append(function_node(eqn.outvars[0], eqn.primitive.name))
-fragment.extend(edge(invar, eqn.outvars[0]) for invar in eqn.invars)
-fragment.append(outvar_node(jaxpr.outvar, "out"))
+id_name = next(id_names)
+fragment.append(function_node(id_name, eqn.primitive.name))
+fragment.extend(edge(invar, id_name) for invar in eqn.invars)
+fragment.extend(edge(id_name, outvar) for outvar in eqn.outvars)
+for ov in jaxpr.outvars:
+fragment.append(outvar_node(ov, "out"))
return graph(''.join(fragment))

edge = '{} -> {} [color=gray30];\n'.format
@@ -1944,8 +1941,8 @@ def jaxpr_to_graphviz(jaxpr, consts):
@wraps(fun)
def graphviz_maker(*args, **kwargs):
wrapped = lu.wrap_init(fun, kwargs)
-jax_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
-jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(wrapped, in_trees)
+jax_args, in_tree = tree_flatten((args, kwargs))
+jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
pvals = map(pv_like, jax_args)
jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)
return jaxpr_to_graphviz(jaxpr, consts)
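This rewrite ports the graphviz helper from the removed jaxtuple helpers (`pytree_to_jaxtupletree`, `pytree_fun_to_jaxtupletree_fun`) onto the pytree flattening API the rest of JAX uses. A minimal sketch of the round trip behind `tree_flatten`, using the public `jax.tree_util` names:

```python
# Flatten an arbitrarily nested argument structure into a flat list of
# leaves plus a treedef that can rebuild it.
from jax.tree_util import tree_flatten, tree_unflatten

args = ({'w': 1.0, 'b': 2.0}, [3.0, 4.0])
leaves, treedef = tree_flatten(args)        # leaves: [2.0, 1.0, 3.0, 4.0]
assert tree_unflatten(treedef, leaves) == args
```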
2 changes: 1 addition & 1 deletion jax/core.py
@@ -136,7 +136,7 @@ def __eq__(self, other):

def __repr__(self):
if self.hash is None:
-return 'Literal(val={}, hashable={})'.format(self.val, self.hashable)
+return 'Literal(val={})'.format(self.val)
else:
return '{}'.format(self.val)

4 changes: 3 additions & 1 deletion jax/experimental/loops.py
@@ -112,6 +112,7 @@ def loop_body(i, acc_arr):
import itertools
import numpy as onp
import traceback
+from typing import Any, List, cast

from jax import abstract_arrays
from jax import lax, core
@@ -287,7 +288,8 @@ def __init__(self, scope, loop_builder):
self.loop_builder = loop_builder
self.first_iteration = True # If we are tracing the first iteration
# Stack trace, without this line and the s.range function
-self.stack = traceback.StackSummary.from_list(traceback.extract_stack()[:-2])
+self.stack = traceback.StackSummary.from_list(
+cast(List[Any], traceback.extract_stack()[:-2]))

# Next are state kept from the start of the first iteration to the end of the iteration.
self.carried_state_initial = {}
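`typing.cast` is a no-op at runtime, so tracing pays nothing for it; it only overrides the type pytype infers for the sliced stack (slicing yields a plain list, which the checker's stub for `StackSummary.from_list` apparently rejected here). A self-contained sketch of the pattern:

```python
# cast() returns its argument unchanged; it exists purely for the checker.
import traceback
from typing import Any, List, cast

def calling_stack():
  frames = cast(List[Any], traceback.extract_stack()[:-1])  # drop this frame
  return traceback.StackSummary.from_list(frames)

print(len(calling_stack()), "frame(s) above this call")
```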
4 changes: 3 additions & 1 deletion jax/interpreters/ad.py
@@ -18,6 +18,7 @@

import functools
import itertools as it
+from typing import Any

from . import partial_eval as pe
from .. import core as core
@@ -36,13 +37,14 @@
def identity(x): return x


-def jvp(fun, has_aux=False, instantiate=True):
+def jvp(fun, has_aux=False, instantiate=True) -> Any:
if not has_aux:
return jvpfun(jvp_subtrace(fun), instantiate)
else:
fun, aux = jvp_subtrace_aux(fun)
return jvpfun(fun, instantiate), aux


@lu.transformation
def jvpfun(instantiate, primals, tangents):
with new_master(JVPTrace) as master:
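The `-> Any` annotation is the pragmatic fix for a function whose return type depends on a flag: with `has_aux`, `jvp` returns a `(transform, aux)` pair instead of a single transform, so its precise type is a Union that every call site would have to narrow. A toy model of the shape of this signature (not the real transform):

```python
# Returning Any keeps the checker from committing to either branch's type.
from typing import Any, Callable

def transform(fun: Callable, has_aux: bool = False) -> Any:
  if not has_aux:
    return fun                      # a single callable...
  return fun, (lambda: "aux")       # ...or a (callable, aux) pair

f = transform(lambda x: x + 1)
g, aux = transform(lambda x: x + 1, has_aux=True)
print(f(1), g(1), aux())
```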
4 changes: 2 additions & 2 deletions jax/interpreters/masking.py
@@ -142,11 +142,11 @@ def __rsub__(self, other):
return self + -other

def __floordiv__(self, divisor):
-q, _ = divmod(self, divisor)
+q, _ = divmod(self, divisor) # pytype: disable=wrong-arg-types
return q

def __mod__(self, divisor):
-_, r = divmod(self, divisor)
+_, r = divmod(self, divisor) # pytype: disable=wrong-arg-types
return r

def __divmod__(self, divisor):
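The `# pytype: disable=...` comments are line-scoped pragmas: the named check is skipped for exactly that line, which is the narrowest way to quiet a false positive (pytype's stubs type `divmod` numerically, while `self` here is a symbolic shape polynomial). A toy illustration, with a hypothetical `Sym` class standing in for masking's `Poly`:

```python
# The pragma turns off one named check for one line only.
class Sym:
  def __divmod__(self, other):
    return Sym(), Sym()

q, r = divmod(Sym(), 3)  # pytype: disable=wrong-arg-types
print(type(q).__name__, type(r).__name__)  # Sym Sym
```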
25 changes: 12 additions & 13 deletions jax/interpreters/pxla.py
@@ -286,19 +286,7 @@ def _axis_index_translation_rule(c, nreps, sizes, soft_size, axis_name):

### lazy device-memory persistence and result handling

-class ShardedDeviceValue(xla.DeviceValue):
-def _check_if_deleted(self):
-if self.device_buffers is None:
-raise ValueError("ShardedDeviceValue has been deleted.")
-
-def block_until_ready(self):
-self._check_if_deleted()
-for buf in self.device_buffers:
-buf.block_host_until_ready()
-return self
-
-
-class ShardedDeviceArray(ShardedDeviceValue, xla.DeviceArray):
+class ShardedDeviceArray(xla.DeviceArray):
"""A ShardedDeviceArray is an ndarray sharded across devices.

The purpose of a ShardedDeviceArray is to reduce the number of transfers when
@@ -346,6 +334,16 @@ def delete(self):
self.device_buffers = None
self._npy_value = None

+def _check_if_deleted(self):
+if self.device_buffers is None:
+raise ValueError("ShardedDeviceArray has been deleted.")
+
+def block_until_ready(self):
+self._check_if_deleted()
+for buf in self.device_buffers:
+buf.block_host_until_ready()
+return self

@property
def _value(self):
if self._npy_value is None:
@@ -757,6 +755,7 @@ def aval(self):
if self.axis_name is not_mapped:
return aval
else:
+assert isinstance(aval, ShapedArray)
return ShapedArray(aval.shape[1:], aval.dtype)

def full_lower(self):
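The added `assert isinstance(...)` is doing type narrowing, not just sanity checking: after the assert, pytype knows `aval` is a `ShapedArray` and therefore has `.shape` and `.dtype`. A self-contained sketch of the pattern with stand-in classes:

```python
# assert isinstance narrows the declared type for the code below it.
from typing import Union

class AbstractValue: pass

class ShapedArray(AbstractValue):
  def __init__(self, shape, dtype):
    self.shape, self.dtype = shape, dtype

def unmap_aval(aval: Union[AbstractValue, ShapedArray]) -> ShapedArray:
  assert isinstance(aval, ShapedArray)   # AbstractValue -> ShapedArray
  return ShapedArray(aval.shape[1:], aval.dtype)

print(unmap_aval(ShapedArray((8, 3), 'float32')).shape)  # (3,)
```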
9 changes: 6 additions & 3 deletions jax/interpreters/xla.py
@@ -825,9 +825,9 @@ def __array__(self, dtype=None, context=None):

__str__ = partialmethod(_forward_to_value, str)
__bool__ = __nonzero__ = partialmethod(_forward_to_value, bool)
-__float__ = partialmethod(_forward_to_value, float)
-__int__ = partialmethod(_forward_to_value, int)
-__complex__ = partialmethod(_forward_to_value, complex)
+def __float__(self): return self._value.__float__()
+def __int__(self): return self._value.__int__()
+def __complex__(self): return self._value.__complex__()
__hex__ = partialmethod(_forward_to_value, hex)
__oct__ = partialmethod(_forward_to_value, oct)
__index__ = partialmethod(_forward_to_value, op.index)
@@ -841,6 +841,9 @@ def __eq__(self, other): return self._value == other
def __hash__(self):
raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.")

+# The following methods are dynamically overridden in lax_numpy.py.
+def __getitem__(self, i): raise NotImplementedError

class DeletedBuffer(object): pass
deleted_buffer = DeletedBuffer()
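Two things happen in this file. First, dunders built with `partialmethod` are opaque to pytype, while an explicit `def __float__(self)` has a signature it can read, which is presumably why only the numeric-conversion methods were rewritten. Second, the `__getitem__` stub tells the checker that `DeviceArray` is subscriptable at all; `lax_numpy.py` swaps in the real implementation at import time. A toy illustration of the first point (names illustrative):

```python
# Explicit defs give the checker a signature; partialmethod does not.
from functools import partialmethod

class Wrapped:
  def __init__(self, value):
    self._value = value
  def _forward_to_value(self, fn, *args):
    return fn(self._value, *args)
  __str__ = partialmethod(_forward_to_value, str)  # opaque to the checker
  def __float__(self):                             # visible to the checker
    return self._value.__float__()

x = Wrapped(3)
print(str(x), float(x))  # 3 3.0
```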

9 changes: 5 additions & 4 deletions jax/lax/lax.py
@@ -23,6 +23,7 @@
import itertools
import operator
import string
+from typing import Any
import warnings

import numpy as onp
@@ -678,7 +679,7 @@ def select(pred, on_true, on_false):
"""
return select_p.bind(pred, on_true, on_false)

-def slice(operand, start_indices, limit_indices, strides=None):
+def slice(operand: Any, start_indices, limit_indices, strides=None):
"""Wraps XLA's `Slice
<https://www.tensorflow.org/xla/operation_semantics#slice>`_
operator.
@@ -4180,8 +4181,8 @@ def infeed(token, shape=None):
flat_shapes, treedef = pytree.flatten(shape)
for shape in flat_shapes:
if not isinstance(shape, ShapedArray):
raise TypeError("shapes argument to infeed must be a pytree of "
"ShapedArray values, got {}".format(shapes))
raise TypeError("shape argument to infeed must be a pytree of "
"ShapedArray values, got {}".format(shape))
xs_and_token = infeed_p.bind(token, shapes=tuple(flat_shapes))
return (treedef.unflatten(xs_and_token[:-1]), xs_and_token[-1])

@@ -4387,7 +4388,7 @@ def _dynamic_slice_indices(operand, start_indices):
if len(start_indices) != operand.ndim:
msg = ("Length of slice indices must match number of operand dimensions ({} "
"vs {})")
-raise ValueError(msg.format(len(start_indices, operand.shape)))
+raise ValueError(msg.format(len(start_indices), operand.shape))
# map int over operand.shape to raise any dynamic-shape errors
return [select(lt(i, _const(i, 0)), add(i, _const(i, int(d))), i)
for i, d in zip(start_indices, operand.shape)]
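The `_dynamic_slice_indices` change is a genuine bug the type check surfaced: a misplaced parenthesis made the error path call `len()` with two arguments, which would only have blown up at runtime while the error message was being built. A toy reproduction of the before/after:

```python
# pytype flags len(xs, ys) statically as wrong-arg-count; at runtime it
# would raise TypeError, masking the intended ValueError.
start_indices, shape = [0, 0], (3, 4, 5)
msg = ("Length of slice indices must match number of operand dimensions "
       "({} vs {})")

# msg.format(len(start_indices, shape))   # broken: len() takes one argument
print(msg.format(len(start_indices), shape))  # fixed
```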
12 changes: 7 additions & 5 deletions jax/lax/lax_control_flow.py
@@ -362,8 +362,9 @@ def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
body_jaxpr=body_jvp_rearranged)

out_carry, out_carry_dot = split_list(out, [num_carry])
-out_tangents = iter(out_carry_dot)
-out_tangents = [next(out_tangents) if nz else ad_util.zero for nz in nonzeros_out]
+out_tangents_iter = iter(out_carry_dot)
+out_tangents = [next(out_tangents_iter) if nz else ad_util.zero
+for nz in nonzeros_out]
return out_carry, out_tangents

while_p = lax.Primitive('while')
@@ -701,8 +702,9 @@ def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts, num_carry,

carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys])
primals_out = carry + ys
-tangents_out = iter(carry_dot + ys_dot)
-tangents_out = [next(tangents_out) if nz else ad_util.zero for nz in nonzeros_out]
+tangents_out_iter = iter(carry_dot + ys_dot)
+tangents_out = [next(tangents_out_iter) if nz else ad_util.zero
+for nz in nonzeros_out]
return primals_out, tangents_out

def _prune_zeros(ts):
@@ -919,7 +921,7 @@ def _scan_shape_rule(shapes, forward, length, jaxpr,
num_consts, num_carry, linear):
const_shexprs, init_shexprs, xs_shexprs = split_list(shapes, [num_consts, num_carry])
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
-ys_shapes = [tuple(length, *y_aval.shape) for y_aval in y_avals]
+ys_shapes = [(length,) + tuple(y_aval.shape) for y_aval in y_avals]
return init_shexprs + ys_shapes

def _scan_masking_rule(shape_envs, padded_vals, shape_exprs, forward, length,
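Both JVP rules get the same mechanical fix: the old code rebound one name (`out_tangents`, `tangents_out`) from an iterator to the list built from that iterator, so the variable's type changed mid-function; giving the iterator its own name keeps each binding at a single type. The `_scan_shape_rule` change just above is a real bug fix as well, since `tuple(length, *y_aval.shape)` is an invalid call (`tuple()` takes a single iterable). A toy version of the renamed-iterator pattern:

```python
# One name per type: the iterator and the resulting list stay distinct.
nonzeros_out = [True, False, True]
carry_dot = [1.0, 2.0]   # one tangent per nonzero output

tangents_out_iter = iter(carry_dot)
tangents_out = [next(tangents_out_iter) if nz else 0.0  # 0.0 stands in for ad_util.zero
                for nz in nonzeros_out]
print(tangents_out)  # [1.0, 0.0, 2.0]
```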
12 changes: 7 additions & 5 deletions jax/lax/lax_parallel.py
@@ -619,12 +619,14 @@ def _add_jaxvals_papply_rule(name, size, vals, dims):
xdim, ydim = dims
if xdim == ydim:
out_dim = xdim
-elif ydim is None:
-y = lax.psplit_like(y, x, name)
-out_dim = xdim
else:
-x = lax.psplit_like(x, y, name)
-out_dim = ydim
+raise NotImplementedError
+# elif ydim is None:
+#   y = lax.psplit_like(y, x, name)
+#   out_dim = xdim
+# else:
+#   x = lax.psplit_like(x, y, name)
+#   out_dim = ydim
return ad_util.add_jaxvals_p.bind(x, y), out_dim


5 changes: 4 additions & 1 deletion jax/lazy.py
@@ -19,6 +19,7 @@
from collections import namedtuple
import functools
import operator as op
+from typing import Any, Callable

import numpy as onp

@@ -32,7 +33,7 @@
### util

# TODO(mattjj): replace with dataclass when Python 2 support is removed
-def taggedtuple(name, fields):
+def taggedtuple(name, fields) -> Callable[..., Any]:
"""Lightweight version of namedtuple where equality depends on the type."""
def __new__(cls, *xs):
return tuple.__new__(cls, (cls,) + xs)
@@ -99,12 +100,14 @@ def __str__(self):
# hash(A(1, 2)) == hash(B(1, 2)) # True
# but we want hashes to be sensitive to the type tag (while still being fast).

+# pytype: disable=wrong-arg-count
LazyExpr = namedtuple('LazyExpr', ['input', 'shape', 'dims'])
ArrayVar = taggedtuple('ArrayVar', [])
Iota = taggedtuple('Iota', ['dtype', 'size']) # like np.arange(N)
Eye = taggedtuple('Eye', ['dtype', 'shape', 'offset']) # like np.eye
Tri = taggedtuple('Tri', ['dtype', 'shape', 'offset']) # like np.tri
Delta = taggedtuple('Delta', ['dtype', 'shape']) # kronecker delta arrays
+# pytype: enable=wrong-arg-count

def array(shape):
return LazyExpr(ArrayVar(), shape, tuple(range(len(shape))))
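Annotating `taggedtuple` as returning `Callable[..., Any]` is honest but coarse: pytype then cannot know the arity of `Iota`, `Eye`, and friends, hence the `wrong-arg-count` disable/enable fence around their construction above. A usage sketch of what taggedtuple buys over a plain namedtuple:

```python
# Plain namedtuples with identical fields compare equal across types:
from collections import namedtuple

Eye = namedtuple('Eye', ['dtype', 'shape', 'offset'])
Tri = namedtuple('Tri', ['dtype', 'shape', 'offset'])
print(Eye('float32', (3, 3), 0) == Tri('float32', (3, 3), 0))  # True

# taggedtuple (defined earlier in this file) stores the class itself as the
# first tuple element, so the analogous comparison between its classes is
# False and the hashes differ, which is what a cache keyed on lazy
# expressions needs.
```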
2 changes: 1 addition & 1 deletion jax/lib/__init__.py
@@ -45,7 +45,7 @@ def _check_jaxlib_version():


try:
-from jaxlib import tpu_client
+from jaxlib import tpu_client # pytype: disable=import-error
except:
tpu_client = None
from jaxlib import xla_client
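This is the standard optional-dependency pattern: jaxlib builds without TPU support simply lack the module, so the import is allowed to fail and the name is set to `None`. The pragma stops pytype, which resolves imports statically, from reporting the module as missing. A sketch (narrowing the original's bare `except:` to `ImportError`):

```python
# Optional import: absence of the module is an expected configuration.
try:
  from jaxlib import tpu_client  # pytype: disable=import-error
except ImportError:
  tpu_client = None

print("TPU support:", "unavailable" if tpu_client is None else "available")
```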