
Commit 1372a56

Minimal changes to make Jax pass a pytype check.
1 parent 19fb494 commit 1372a56

25 files changed (+107 -79 lines)

jax/BUILD (+15 -7)

@@ -14,13 +14,15 @@

 # JAX is Autograd and XLA

+pytype_library = py_library
+
 licenses(["notice"])

 package(default_visibility = ["//visibility:public"])

 # top-level EF placeholder

-py_library(
+pytype_library(
     name = "jax",
     srcs = glob(
         [
@@ -43,38 +45,44 @@ py_library(
     deps = ["@org_tensorflow//tensorflow/compiler/xla/python:xla_client"],
 )

-py_library(
+pytype_library(
     name = "stax",
     srcs = ["experimental/stax.py"],
+    srcs_version = "PY3",
     deps = [":jax"],
 )

-py_library(
+pytype_library(
     name = "optimizers",
     srcs = ["experimental/optimizers.py"],
+    srcs_version = "PY3",
     deps = [":jax"],
 )

-py_library(
+pytype_library(
     name = "optix",
     srcs = ["experimental/optix.py"],
+    srcs_version = "PY3",
     deps = [":jax"],
 )

-py_library(
+pytype_library(
     name = "ode",
     srcs = ["experimental/ode.py"],
+    srcs_version = "PY3",
     deps = [":jax"],
 )

-py_library(
+pytype_library(
     name = "vectorize",
     srcs = ["experimental/vectorize.py"],
+    srcs_version = "PY3",
     deps = [":jax"],
 )

-py_library(
+pytype_library(
     name = "loops",
     srcs = ["experimental/loops.py"],
+    srcs_version = "PY3",
     deps = [":jax"],
 )
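
Note on the build change above: in the open-source tree `pytype_library` is simply aliased to `py_library`, so these targets build exactly as before; the name only becomes meaningful in a build environment that binds it to a real pytype-checking rule. A hedged Starlark sketch of that substitution (the load path and the "example" target are hypothetical):

# Open-source build: the alias makes pytype_library a no-op.
pytype_library = py_library

# Hypothetical type-checked build: bind the same name to a real rule,
# and nothing else in the file has to change.
# load("//tools/build_rules:pytype.bzl", "pytype_library")

pytype_library(
    name = "example",
    srcs = ["example.py"],
    srcs_version = "PY3",
)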

jax/ad_util.py (-3)

@@ -39,9 +39,6 @@ def add_impl(xs, ys):
 def add_abstract(xs, ys):
   return lattice_join(xs, ys)

-def zeros_like_impl_jaxtuple(xs):
-  return JaxTuple(map(zeros_like_impl, xs))
-
 jaxval_zeros_likers = {}

 def zeros_like_aval(aval):

jax/api.py (+9 -12)

@@ -938,7 +938,7 @@ def f_pmapped(*args, **kwargs):
       msg = ("soft_pmap mapped axis size must be divisble by the number of "
              "XLA devices (or be less than or equal to that number), but got "
              "an axis size of {} with {} devices.")
-      raise ValueError(msg.format(axis_size, pxla.pxla.unmapped_device_count()))
+      raise ValueError(msg.format(axis_size, pxla.unmapped_device_count()))
     num_chunks = axis_size // chunk_size

     reshaped_args = [_reshape_split(num_chunks, x) for x in args_flat]
@@ -1922,15 +1922,12 @@ def jaxpr_to_graphviz(jaxpr, consts):
   fragment.extend(map(constant_node, jaxpr.constvars, consts))

   for eqn in jaxpr.eqns:
-    if eqn.destructure:
-      id_name = next(id_names)
-      fragment.append(function_node(id_name, eqn.primitive.name))
-      fragment.extend(edge(invar, id_name) for invar in eqn.invars)
-      fragment.extend(edge(id_name, outvar) for outvar in eqn.outvars)
-    else:
-      fragment.append(function_node(eqn.outvars[0], eqn.primitive.name))
-      fragment.extend(edge(invar, eqn.outvars[0]) for invar in eqn.invars)
-  fragment.append(outvar_node(jaxpr.outvar, "out"))
+    id_name = next(id_names)
+    fragment.append(function_node(id_name, eqn.primitive.name))
+    fragment.extend(edge(invar, id_name) for invar in eqn.invars)
+    fragment.extend(edge(id_name, outvar) for outvar in eqn.outvars)
+  for ov in jaxpr.outvars:
+    fragment.append(outvar_node(ov, "out"))
   return graph(''.join(fragment))

 edge = '{} -> {} [color=gray30];\n'.format
@@ -1944,8 +1941,8 @@ def jaxpr_to_graphviz(jaxpr, consts):
   @wraps(fun)
   def graphviz_maker(*args, **kwargs):
     wrapped = lu.wrap_init(fun, kwargs)
-    jax_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
-    jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(wrapped, in_trees)
+    jax_args, in_tree = tree_flatten((args, kwargs))
+    jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
     pvals = map(pv_like, jax_args)
     jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)
     return jaxpr_to_graphviz(jaxpr, consts)

jax/core.py (+1 -1)

@@ -136,7 +136,7 @@ def __eq__(self, other):

   def __repr__(self):
     if self.hash is None:
-      return 'Literal(val={}, hashable={})'.format(self.val, self.hashable)
+      return 'Literal(val={})'.format(self.val)
     else:
       return '{}'.format(self.val)

jax/experimental/loops.py (+3 -1)

@@ -112,6 +112,7 @@ def loop_body(i, acc_arr):
 import itertools
 import numpy as onp
 import traceback
+from typing import Any, List, cast

 from jax import abstract_arrays
 from jax import lax, core
@@ -287,7 +288,8 @@ def __init__(self, scope, loop_builder):
     self.loop_builder = loop_builder
     self.first_iteration = True  # If we are tracing the first iteration
     # Stack trace, without this line and the s.range function
-    self.stack = traceback.StackSummary.from_list(traceback.extract_stack()[:-2])
+    self.stack = traceback.StackSummary.from_list(
+      cast(List[Any], traceback.extract_stack()[:-2]))

     # Next are state kept from the start of the first iteration to the end of the iteration.
     self.carried_state_initial = {}
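
The `cast` above is needed because slicing a `StackSummary` (a `list` subclass) returns a plain `list`, and pytype is stricter than the runtime about what `StackSummary.from_list` accepts; `cast` is a no-op at runtime and only widens the type for the checker. A minimal standalone sketch of the same pattern:

import traceback
from typing import Any, List, cast

def current_stack_trimmed(drop: int = 2) -> traceback.StackSummary:
  # Slicing a StackSummary yields a plain list; cast() changes nothing at
  # runtime and exists purely to satisfy the type checker.
  frames = cast(List[Any], traceback.extract_stack()[:-drop])
  return traceback.StackSummary.from_list(frames)

print(len(current_stack_trimmed()))  # 0 when called at module depth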

jax/interpreters/ad.py (+3 -1)

@@ -18,6 +18,7 @@

 import functools
 import itertools as it
+from typing import Any

 from . import partial_eval as pe
 from .. import core as core
@@ -36,13 +37,14 @@
 def identity(x): return x


-def jvp(fun, has_aux=False, instantiate=True):
+def jvp(fun, has_aux=False, instantiate=True) -> Any:
   if not has_aux:
     return jvpfun(jvp_subtrace(fun), instantiate)
   else:
     fun, aux = jvp_subtrace_aux(fun)
     return jvpfun(fun, instantiate), aux

+
 @lu.transformation
 def jvpfun(instantiate, primals, tangents):
   with new_master(JVPTrace) as master:
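
The `-> Any` on `jvp` is the light-touch fix for a function that returns either a wrapped function or a `(fun, aux)` pair depending on `has_aux`; without it, pytype propagates that union into every caller. A toy sketch of the same shape, with hypothetical names:

from typing import Any

def transform(f, has_aux: bool = False) -> Any:
  # Returns a callable, or a (callable, aux) pair; annotating the union
  # away with Any keeps the checker from flagging every call site.
  if not has_aux:
    return f
  return f, []

g = transform(abs)
h, aux = transform(abs, has_aux=True)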

jax/interpreters/masking.py (+2 -2)

@@ -142,11 +142,11 @@ def __rsub__(self, other):
     return self + -other

   def __floordiv__(self, divisor):
-    q, _ = divmod(self, divisor)
+    q, _ = divmod(self, divisor)  # pytype: disable=wrong-arg-types
     return q

   def __mod__(self, divisor):
-    _, r = divmod(self, divisor)
+    _, r = divmod(self, divisor)  # pytype: disable=wrong-arg-types
     return r

   def __divmod__(self, divisor):
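
The trailing comments above are pytype's per-line suppression syntax: `# pytype: disable=<error-class>` at the end of a line silences exactly that error class on that one line; here the issue is `divmod` applied to polynomial shape objects, which pytype cannot match against the numeric signatures in its builtin stubs. A toy example of the same suppression (the `Sym` class is hypothetical):

class Sym:
  """Hypothetical symbolic value with its own divmod behavior."""
  def __divmod__(self, other):
    return self, other

# The checker only knows divmod() for numeric types, so the call on a
# custom class is silenced on just this line:
q, r = divmod(Sym(), Sym())  # pytype: disable=wrong-arg-types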

jax/interpreters/pxla.py (+12 -13)

@@ -286,19 +286,7 @@ def _axis_index_translation_rule(c, nreps, sizes, soft_size, axis_name):

 ### lazy device-memory persistence and result handling

-class ShardedDeviceValue(xla.DeviceValue):
-  def _check_if_deleted(self):
-    if self.device_buffers is None:
-      raise ValueError("ShardedDeviceValue has been deleted.")
-
-  def block_until_ready(self):
-    self._check_if_deleted()
-    for buf in self.device_buffers:
-      buf.block_host_until_ready()
-    return self
-
-
-class ShardedDeviceArray(ShardedDeviceValue, xla.DeviceArray):
+class ShardedDeviceArray(xla.DeviceArray):
   """A ShardedDeviceArray is an ndarray sharded across devices.

   The purpose of a ShardedDeviceArray is to reduce the number of transfers when
@@ -346,6 +334,16 @@ def delete(self):
     self.device_buffers = None
     self._npy_value = None

+  def _check_if_deleted(self):
+    if self.device_buffers is None:
+      raise ValueError("ShardedDeviceArray has been deleted.")
+
+  def block_until_ready(self):
+    self._check_if_deleted()
+    for buf in self.device_buffers:
+      buf.block_host_until_ready()
+    return self
+
   @property
   def _value(self):
     if self._npy_value is None:
@@ -757,6 +755,7 @@ def aval(self):
     if self.axis_name is not_mapped:
       return aval
     else:
+      assert isinstance(aval, ShapedArray)
       return ShapedArray(aval.shape[1:], aval.dtype)

   def full_lower(self):
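
Folding `ShardedDeviceValue` into `ShardedDeviceArray` removes a mixin whose attributes pytype could not resolve, while keeping `block_until_ready` on the public type. A hedged usage sketch, forcing asynchronously dispatched per-device work to finish before timing it:

import time
import jax
import jax.numpy as jnp

# pmap results are sharded across devices; block_until_ready() waits on
# every per-device buffer, which matters when benchmarking async dispatch.
x = jnp.ones((jax.device_count(), 1000))
start = time.time()
y = jax.pmap(lambda a: a * 2)(x).block_until_ready()
print("elapsed:", time.time() - start)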

jax/interpreters/xla.py (+6 -3)

@@ -825,9 +825,9 @@ def __array__(self, dtype=None, context=None):

   __str__ = partialmethod(_forward_to_value, str)
   __bool__ = __nonzero__ = partialmethod(_forward_to_value, bool)
-  __float__ = partialmethod(_forward_to_value, float)
-  __int__ = partialmethod(_forward_to_value, int)
-  __complex__ = partialmethod(_forward_to_value, complex)
+  def __float__(self): return self._value.__float__()
+  def __int__(self): return self._value.__int__()
+  def __complex__(self): return self._value.__complex__()
   __hex__ = partialmethod(_forward_to_value, hex)
   __oct__ = partialmethod(_forward_to_value, oct)
   __index__ = partialmethod(_forward_to_value, op.index)
@@ -841,6 +841,9 @@ def __eq__(self, other): return self._value == other
   def __hash__(self):
     raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.")

+  # The following methods are dynamically overridden in lax_numpy.py.
+  def __getitem__(self, i): raise NotImplementedError
+
 class DeletedBuffer(object): pass
 deleted_buffer = DeletedBuffer()
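
`partialmethod(_forward_to_value, float)` works at runtime, but pytype checks dunder protocols structurally: `__float__` must look like a method returning `float`, and a `partialmethod` attribute does not. Writing the three conversions as explicit methods satisfies the protocol; a minimal sketch of the forwarding pattern:

class Wrapper:
  """Hypothetical stand-in for a DeviceArray-like value wrapper."""
  def __init__(self, value):
    self._value = value

  # Explicit defs (rather than partialmethod) so type checkers can see
  # the numeric-conversion protocol is implemented.
  def __float__(self): return self._value.__float__()
  def __int__(self): return self._value.__int__()
  def __complex__(self): return self._value.__complex__()

print(float(Wrapper(3)))   # 3.0
print(int(Wrapper(3.7)))   # 3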

jax/lax/lax.py (+5 -4)

@@ -23,6 +23,7 @@
 import itertools
 import operator
 import string
+from typing import Any
 import warnings

 import numpy as onp
@@ -678,7 +679,7 @@ def select(pred, on_true, on_false):
   """
   return select_p.bind(pred, on_true, on_false)

-def slice(operand, start_indices, limit_indices, strides=None):
+def slice(operand: Any, start_indices, limit_indices, strides=None):
   """Wraps XLA's `Slice
   <https://www.tensorflow.org/xla/operation_semantics#slice>`_
   operator.
@@ -4180,8 +4181,8 @@ def infeed(token, shape=None):
   flat_shapes, treedef = pytree.flatten(shape)
   for shape in flat_shapes:
     if not isinstance(shape, ShapedArray):
-      raise TypeError("shapes argument to infeed must be a pytree of "
-                      "ShapedArray values, got {}".format(shapes))
+      raise TypeError("shape argument to infeed must be a pytree of "
+                      "ShapedArray values, got {}".format(shape))
   xs_and_token = infeed_p.bind(token, shapes=tuple(flat_shapes))
   return (treedef.unflatten(xs_and_token[:-1]), xs_and_token[-1])

@@ -4387,7 +4388,7 @@ def _dynamic_slice_indices(operand, start_indices):
   if len(start_indices) != operand.ndim:
     msg = ("Length of slice indices must match number of operand dimensions ({} "
            "vs {})")
-    raise ValueError(msg.format(len(start_indices, operand.shape)))
+    raise ValueError(msg.format(len(start_indices), operand.shape))
   # map int over operand.shape to raise any dynamic-shape errors
   return [select(lt(i, _const(i, 0)), add(i, _const(i, int(d))), i)
           for i, d in zip(start_indices, operand.shape)]
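
Two of the lax.py changes are genuine bugs surfaced by the check rather than annotation tweaks: `len(start_indices, operand.shape)` passed two arguments to `len` (which takes exactly one), so this error path raised `TypeError` instead of the intended `ValueError`, and the infeed message referenced the undefined name `shapes`. The corrected format call in isolation, with hypothetical values:

start_indices = (0, 1, 2)
operand_shape = (4, 5)  # hypothetical operand with ndim == 2

if len(start_indices) != len(operand_shape):
  msg = ("Length of slice indices must match number of operand dimensions ({} "
         "vs {})")
  # len() takes one argument; the closing parenthesis previously sat in
  # the wrong place, so this line raised TypeError, never ValueError.
  raise ValueError(msg.format(len(start_indices), operand_shape))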

jax/lax/lax_control_flow.py (+7 -5)

@@ -362,8 +362,9 @@ def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
                     body_jaxpr=body_jvp_rearranged)

   out_carry, out_carry_dot = split_list(out, [num_carry])
-  out_tangents = iter(out_carry_dot)
-  out_tangents = [next(out_tangents) if nz else ad_util.zero for nz in nonzeros_out]
+  out_tangents_iter = iter(out_carry_dot)
+  out_tangents = [next(out_tangents_iter) if nz else ad_util.zero
+                  for nz in nonzeros_out]
   return out_carry, out_tangents

 while_p = lax.Primitive('while')
@@ -701,8 +702,9 @@ def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts, num_carry,

   carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys])
   primals_out = carry + ys
-  tangents_out = iter(carry_dot + ys_dot)
-  tangents_out = [next(tangents_out) if nz else ad_util.zero for nz in nonzeros_out]
+  tangents_out_iter = iter(carry_dot + ys_dot)
+  tangents_out = [next(tangents_out_iter) if nz else ad_util.zero
+                  for nz in nonzeros_out]
   return primals_out, tangents_out

 def _prune_zeros(ts):
@@ -919,7 +921,7 @@ def _scan_shape_rule(shapes, forward, length, jaxpr,
                      num_consts, num_carry, linear):
   const_shexprs, init_shexprs, xs_shexprs = split_list(shapes, [num_consts, num_carry])
   _, y_avals = split_list(jaxpr.out_avals, [num_carry])
-  ys_shapes = [tuple(length, *y_aval.shape) for y_aval in y_avals]
+  ys_shapes = [(length,) + tuple(y_aval.shape) for y_aval in y_avals]
   return init_shexprs + ys_shapes

 def _scan_masking_rule(shape_envs, padded_vals, shape_exprs, forward, length,
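
The two JVP-rule edits change no behavior: the old code bound an iterator and then a list to the same name, which is fine at runtime but gives the variable a union type that pytype rejects, so the iterator gets its own name. (The `_scan_shape_rule` edit is a real fix: `tuple(length, *y_aval.shape)` passes several arguments to `tuple`, which accepts at most one iterable.) A minimal sketch of the renaming pattern:

# Rebinding one name from Iterator to List confuses the checker; a
# distinct name for the iterator keeps each variable's type stable.
nonzeros = [True, False, True]
dots = [1.0, 2.0]

dots_iter = iter(dots)
tangents = [next(dots_iter) if nz else 0.0 for nz in nonzeros]
assert tangents == [1.0, 0.0, 2.0]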

jax/lax/lax_parallel.py (+7 -5)

@@ -619,12 +619,14 @@ def _add_jaxvals_papply_rule(name, size, vals, dims):
   xdim, ydim = dims
   if xdim == ydim:
     out_dim = xdim
-  elif ydim is None:
-    y = lax.psplit_like(y, x, name)
-    out_dim = xdim
   else:
-    x = lax.psplit_like(x, y, name)
-    out_dim = ydim
+    raise NotImplementedError
+    # elif ydim is None:
+    #   y = lax.psplit_like(y, x, name)
+    #   out_dim = xdim
+    # else:
+    #   x = lax.psplit_like(x, y, name)
+    #   out_dim = ydim
   return ad_util.add_jaxvals_p.bind(x, y), out_dim
jax/lazy.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from collections import namedtuple
2020
import functools
2121
import operator as op
22+
from typing import Any, Callable
2223

2324
import numpy as onp
2425

@@ -32,7 +33,7 @@
3233
### util
3334

3435
# TODO(mattjj): replace with dataclass when Python 2 support is removed
35-
def taggedtuple(name, fields):
36+
def taggedtuple(name, fields) -> Callable[..., Any]:
3637
"""Lightweight version of namedtuple where equality depends on the type."""
3738
def __new__(cls, *xs):
3839
return tuple.__new__(cls, (cls,) + xs)
@@ -99,12 +100,14 @@ def __str__(self):
99100
# hash(A(1, 2)) == hash(B(1, 2)) # True
100101
# but we want hashes to be sensitive to the type tag (while still being fast).
101102

103+
# pytype: disable=wrong-arg-count
102104
LazyExpr = namedtuple('LazyExpr', ['input', 'shape', 'dims'])
103105
ArrayVar = taggedtuple('ArrayVar', [])
104106
Iota = taggedtuple('Iota', ['dtype', 'size']) # like np.arange(N)
105107
Eye = taggedtuple('Eye', ['dtype', 'shape', 'offset']) # like np.eye
106108
Tri = taggedtuple('Tri', ['dtype', 'shape', 'offset']) # like np.tri
107109
Delta = taggedtuple('Delta', ['dtype', 'shape']) # kronecker delta arrays
110+
# pytype: enable=wrong-arg-count
108111

109112
def array(shape):
110113
return LazyExpr(ArrayVar(), shape, tuple(range(len(shape))))
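
This file uses pytype's block-level suppression: `# pytype: disable=...` on its own line turns an error class off until the matching `enable` line, in contrast to the per-line comments in masking.py. It is needed because `taggedtuple` classes stash the class itself as a hidden first tuple element in `__new__`, so their real arity is one more than the field list suggests. A runnable sketch (this `taggedtuple` is a hypothetical reduction of the one in the diff):

from collections import namedtuple

def taggedtuple(name, fields):
  """Hypothetical reduction of jax/lazy.py's taggedtuple."""
  base = namedtuple(name, fields)

  def __new__(cls, *xs):
    # The class tag is prepended, so instances carry one more element
    # than `fields` declares -- the mismatch behind wrong-arg-count.
    return tuple.__new__(cls, (cls,) + xs)

  return type(name, (base,), {'__new__': __new__})

# pytype: disable=wrong-arg-count
Iota = taggedtuple('Iota', ['dtype', 'size'])
i = Iota('int32', 3)
# pytype: enable=wrong-arg-count
print(len(i))  # 3: the class tag plus the two declared fields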

jax/lib/__init__.py (+1 -1)

@@ -45,7 +45,7 @@ def _check_jaxlib_version():


 try:
-  from jaxlib import tpu_client
+  from jaxlib import tpu_client  # pytype: disable=import-error
 except:
   tpu_client = None
 from jaxlib import xla_client
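
The suppression here covers the analysis-time half of an optional dependency: the bare `except` already handles a jaxlib built without TPU support at runtime, and the comment stops pytype from failing when `tpu_client` is missing from its view of jaxlib. The general pattern, with a hypothetical package name:

# Optional-dependency import: at runtime a missing module falls back to
# None; the pytype comment covers analysis environments without it.
try:
  from some_optional_pkg import client  # pytype: disable=import-error
except ImportError:
  client = None

if client is not None:
  print(getattr(client, '__version__', 'unknown'))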
