Commit c992d3a

XFAIL pytensor tests for uint64 data type

1 parent 8c29314 commit c992d3a

9 files changed (+25, -76)

pytensor/compile/function/__init__.py (-1)

@@ -312,7 +312,6 @@ def opt_log1p(node):
     else:
         # note: pfunc will also call orig_function -- orig_function is
         # a choke point that all compilation must pass through
-
         fn = pfunc(
             params=inputs,
             outputs=outputs,
pytensor/compile/function/types.py (-1)

@@ -1758,7 +1758,6 @@ def orig_function(
             name=name,
             fgraph=fgraph,
         )
-        print(m)
         with config.change_flags(compute_test_value="off"):
             fn = m.create(defaults)
     finally:
pytensor/graph/op.py (-1)

@@ -291,7 +291,6 @@ def __call__(

         """
         node = self.make_node(*inputs, **kwargs)
-
         if name is not None:
             if len(node.outputs) == 1:
                 node.outputs[0].name = name

pytensor/link/jax/dispatch/nlinalg.py (+2, -47)

@@ -104,56 +104,11 @@ def batched_dot(a, b):
     return batched_dot


-# @jax_funcify.register(Max)
-# @jax_funcify.register(Argmax)
-# def jax_funcify_MaxAndArgmax(op, **kwargs):
-#     axis = op.axis
-
-#     def maxandargmax(x, axis=axis):
-#         if axis is None:
-#             axes = tuple(range(x.ndim))
-#         else:
-#             axes = tuple(int(ax) for ax in axis)
-
-#         max_res = jnp.max(x, axis)
-
-#         # NumPy does not support multiple axes for argmax; this is a
-#         # work-around
-#         keep_axes = jnp.array(
-#             [i for i in range(x.ndim) if i not in axes], dtype="int64"
-#         )
-#         # Not-reduced axes in front
-#         transposed_x = jnp.transpose(
-#             x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
-#         )
-#         kept_shape = transposed_x.shape[: len(keep_axes)]
-#         reduced_shape = transposed_x.shape[len(keep_axes) :]
-
-#         # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
-#         # Otherwise reshape would complain citing float arg
-#         new_shape = (
-#             *kept_shape,
-#             jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
-#         )
-#         reshaped_x = transposed_x.reshape(new_shape)
-
-#         max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
-
-#         return max_res, max_idx_res
-
-#     return maxandargmax
-
-
 @jax_funcify.register(Max)
 def jax_funcify_Max(op, **kwargs):
     axis = op.axis

-    def max(x, axis=axis):
-        # if axis is None:
-        #     axes = tuple(range(x.ndim))
-        # else:
-        #     axes = tuple(int(ax) for ax in axis)
-
+    def max(x):
         max_res = jnp.max(x, axis)

         return max_res

@@ -165,7 +120,7 @@ def max(x, axis=axis):
 def jax_funcify_Argmax(op, **kwargs):
     axis = op.axis

-    def argmax(x, axis=axis):
+    def argmax(x):
         if axis is None:
             axes = tuple(range(x.ndim))
         else:
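A note on the change to jax_funcify_Max and jax_funcify_Argmax above: `axis` is now read from the op once and captured by the inner function's closure, instead of being threaded through as a default argument. A minimal standalone sketch of that pattern, with illustrative names (`make_max`, `reduce_rows`) that are not part of the commit:

    import jax.numpy as jnp

    def make_max(axis):
        # `axis` is fixed when the callable is built, mirroring how
        # jax_funcify_Max reads `op.axis` once at dispatch time.
        def max_fn(x):
            return jnp.max(x, axis)

        return max_fn

    reduce_rows = make_max(axis=0)
    print(reduce_rows(jnp.arange(6).reshape(2, 3)))  # -> [3 4 5]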

pytensor/tensor/math.py (+15, -17)

@@ -107,6 +107,14 @@
 float64_atol = 1e-8


+def __getattr__(name):
+    if name == "MaxandArgmax":
+        raise AttributeError(
+            "The class `MaxandArgmax` has been deprecated. "
+            "Call `Max` and `Argmax` separately as an alternative."
+        )
+
+
 def _get_atol_rtol(a, b):
     tiny = ("float16",)
     narrow = ("float32", "complex64")

@@ -134,15 +142,6 @@ def _allclose(a, b, rtol=None, atol=None):
     return np.allclose(a, b, atol=atol_, rtol=rtol_)


-def __getattr__(name):
-    if name == "MaxandArgmax":
-        warnings.warn(
-            "The class `MaxandArgmax` has been deprecated. "
-            "Call `Max` and `Argmax` seperately as an alternative.",
-            FutureWarning,
-        )
-
-
 class Argmax(COp):
     """
     Calculate the argmax over a given axis or over all axes.

@@ -193,10 +192,8 @@ def perform(self, node, inp, outs):
         (x,) = inp
         axes = self.axis
         (max_idx,) = outs
-
         if axes is None:
             axes = tuple(range(x.ndim))
-
         # Numpy does not support multiple axes for argmax
         # Work around
         keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")

@@ -398,21 +395,21 @@ def max_and_argmax(a, axis=None, keepdims=False):

     """
     # Check axis and convert it to a Python list of integers.
-    # Axis will be used as an op param of MaxAndArgmax.
+    # Axis will be used as an op param of Max and Argmax.
     a = as_tensor_variable(a)

-    flag = False
+    is_axis_empty = False
     if axis == ():
-        flag = True
+        is_axis_empty = True

     axis = check_and_normalize_axes(a, axis)

-    if len(axis) == 0 and not flag:
+    if len(axis) == 0 and not is_axis_empty:
         axis = None

     out = Max(axis)(a)

-    if not flag:
+    if not is_axis_empty:
         argout = Argmax(axis)(a)
     else:
         argout = zeros_like(a, dtype="int64")

@@ -495,7 +492,8 @@ def grad(self, inp, grads):
         if g_max_disconnected:
             return [DisconnectedType()()]

-        if NoneConst.equals(axis):
+        # if NoneConst.equals(axis):
+        if axis is None:
             axis_ = list(range(x.ndim))
         else:
             axis_ = axis
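The relocated `__getattr__` above relies on PEP 562 module-level attribute hooks: Python consults a module's `__getattr__` only after normal lookup fails, so existing names are unaffected while the removed `MaxandArgmax` now raises an AttributeError outright instead of emitting a FutureWarning. A minimal sketch of the mechanism, using a hypothetical module and names:

    # mymodule.py (hypothetical)
    def __getattr__(name):
        # Called only when `name` is not found by normal module lookup.
        if name == "OldClass":
            raise AttributeError("`OldClass` has been removed; use `NewClass`.")
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

Accessing `mymodule.OldClass` then raises the custom AttributeError, which is the behavior the updated test in tests/tensor/test_math.py asserts for `pytensor.tensor.math.MaxandArgmax`.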

pytensor/tensor/rewriting/uncanonicalize.py (-1)

@@ -57,7 +57,6 @@ def local_max_to_min(fgraph, node):
     """
     if node.op == neg and node.inputs[0].owner:
         max = node.inputs[0]
-        # print(max.owner.op.scalar_op)
         if (
             max.owner
             and isinstance(max.owner.op, CAReduce)

tests/link/numba/test_basic.py (-2)

@@ -256,8 +256,6 @@ def compare_numba_and_py(
     if assert_fn is None:

         def assert_fn(x, y):
-            print(x)
-            print(y)
             return np.testing.assert_allclose(x, y, rtol=1e-4) and compare_shape_dtype(
                 x, y
             )

tests/tensor/test_math.py (+8, -5)

@@ -808,7 +808,11 @@ def test_basic_2(self, axis, np_axis):
         v_shape, i_shape = eval_outputs([vt.shape, it.shape])
         assert tuple(v_shape) == vt.type.shape
         assert tuple(i_shape) == it.type.shape
-        # Test valuesgi
+        # Test values
+        v, i = eval_outputs([vt, it])
+        assert i.dtype == "int64"
+        assert np.all(v == np_max)
+        assert np.all(i == np_argm)

     @pytest.mark.parametrize(
         "axis,np_axis",

@@ -1029,9 +1033,7 @@ def test_vectorize(self, core_axis, batch_axis):
         batch_x = tensor(shape=(3, 5, 5, 5, 5))

         # Test MaxAndArgmax
-        max_x, argmax_x = max_and_argmax(x, axis=core_axis)
-        max_node = max_x.owner
-        assert isinstance(max_node.op, Max)
+        argmax_x = argmax(x, axis=core_axis)

         arg_max_node = argmax_x.owner
         new_node = vectorize_node(arg_max_node, batch_x)

@@ -1394,6 +1396,7 @@ def _grad_list(self):
         # check_grad_max(data, eval_outputs(grad(max_and_argmax(n,
         # axis=1)[0], n)),axis=1)

+    @pytest.mark.xfail(reason="Fails due to #770")
     def test_uint(self):
         for dtype in ("uint8", "uint16", "uint32", "uint64"):
             itype = np.iinfo(dtype)

@@ -1420,7 +1423,7 @@ def test_bool(self):


 def test_MaxandArgmax_deprecated():
-    with pytest.warns(FutureWarning, match=".*deprecated.*"):
+    with pytest.raises(AttributeError):
         pytensor.tensor.math.MaxandArgmax
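For reference, the two pytest idioms this test file now relies on, in a minimal self-contained sketch (the test names and bodies here are placeholders, not the real tests):

    import pytest

    @pytest.mark.xfail(reason="Fails due to #770")
    def test_expected_failure():
        # A failing body is reported as XFAIL rather than an error;
        # an unexpected pass is reported as XPASS.
        assert False

    def test_removed_attribute():
        # pytest.raises passes only if the block raises AttributeError.
        with pytest.raises(AttributeError):
            raise AttributeError("stand-in for the deprecated attribute access")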

tests/tensor/utils.py (-1)

@@ -107,7 +107,6 @@ def inplace_func(
 def eval_outputs(outputs, ops=(), mode=None):
     f = inplace_func([], outputs, mode=mode)
     variables = f()
-
     if ops:
         assert any(isinstance(node.op, ops) for node in f.maker.fgraph.apply_nodes)
     if isinstance(variables, tuple | list) and len(variables) == 1:
