Commit 25c8b4b

XFAIL pytensor tests for uint64 data type
1 parent 8c29314 commit 25c8b4b

6 files changed: +17 −64 lines

pytensor/compile/function/types.py (−1)

@@ -1758,7 +1758,6 @@ def orig_function(
             name=name,
             fgraph=fgraph,
         )
-        print(m)
         with config.change_flags(compute_test_value="off"):
             fn = m.create(defaults)
     finally:

pytensor/link/jax/dispatch/nlinalg.py (−45)

@@ -104,56 +104,11 @@ def batched_dot(a, b):
     return batched_dot


-# @jax_funcify.register(Max)
-# @jax_funcify.register(Argmax)
-# def jax_funcify_MaxAndArgmax(op, **kwargs):
-#     axis = op.axis
-
-#     def maxandargmax(x, axis=axis):
-#         if axis is None:
-#             axes = tuple(range(x.ndim))
-#         else:
-#             axes = tuple(int(ax) for ax in axis)
-
-#         max_res = jnp.max(x, axis)
-
-#         # NumPy does not support multiple axes for argmax; this is a
-#         # work-around
-#         keep_axes = jnp.array(
-#             [i for i in range(x.ndim) if i not in axes], dtype="int64"
-#         )
-#         # Not-reduced axes in front
-#         transposed_x = jnp.transpose(
-#             x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
-#         )
-#         kept_shape = transposed_x.shape[: len(keep_axes)]
-#         reduced_shape = transposed_x.shape[len(keep_axes) :]
-
-#         # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
-#         # Otherwise reshape would complain citing float arg
-#         new_shape = (
-#             *kept_shape,
-#             jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
-#         )
-#         reshaped_x = transposed_x.reshape(new_shape)
-
-#         max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
-
-#         return max_res, max_idx_res
-
-#     return maxandargmax
-
-
 @jax_funcify.register(Max)
 def jax_funcify_Max(op, **kwargs):
     axis = op.axis

     def max(x, axis=axis):
-        # if axis is None:
-        #     axes = tuple(range(x.ndim))
-        # else:
-        #     axes = tuple(int(ax) for ax in axis)
-
         max_res = jnp.max(x, axis)

         return max_res
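The comments deleted above preserve the multi-axis argmax workaround that the old combined MaxAndArgmax lowering used before Max and Argmax were split. A rough sketch of that logic as a standalone dispatch, assuming an Argmax import and a jax_funcify_Argmax registration analogous to the Max one kept in this file (not part of this commit):

import jax.numpy as jnp


@jax_funcify.register(Argmax)  # assumed decorator, mirroring the Max registration
def jax_funcify_Argmax(op, **kwargs):
    axis = op.axis

    def argmax(x, axis=axis):
        if axis is None:
            axes = tuple(range(x.ndim))
        else:
            axes = tuple(int(ax) for ax in axis)

        # jnp.argmax reduces over a single axis only, so put the kept axes in
        # front, flatten the reduced axes into one, and reduce that last axis.
        keep_axes = tuple(i for i in range(x.ndim) if i not in axes)
        transposed_x = jnp.transpose(x, keep_axes + axes)
        kept_shape = transposed_x.shape[: len(keep_axes)]
        reshaped_x = transposed_x.reshape((*kept_shape, -1))

        return jnp.argmax(reshaped_x, axis=-1).astype("int64")

    return argmax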

pytensor/tensor/math.py (+16 −15)

@@ -107,6 +107,15 @@
 float64_atol = 1e-8


+def __getattr__(name):
+    if name == "MaxandArgmax":
+        warnings.warn(
+            "The class `MaxandArgmax` has been deprecated. "
+            "Call `Max` and `Argmax` seperately as an alternative.",
+            FutureWarning,
+        )
+
+
 def _get_atol_rtol(a, b):
     tiny = ("float16",)
     narrow = ("float32", "complex64")
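Relocating the deprecation hook is purely cosmetic: a module-level __getattr__ (PEP 562) is consulted only for names the module does not otherwise define, wherever it sits in the file. A minimal self-contained sketch of the mechanism, using a hypothetical module and a simplified message:

# deprecations.py -- hypothetical module illustrating PEP 562
import warnings


def __getattr__(name):
    # Runs only when `name` is not found among the module's real attributes.
    if name == "MaxandArgmax":
        warnings.warn(
            "`MaxandArgmax` is deprecated; call `Max` and `Argmax` separately.",
            FutureWarning,
        )
        return None
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

Touching deprecations.MaxandArgmax then emits the FutureWarning, while access to regular attributes is unaffected.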
@@ -134,15 +143,6 @@ def _allclose(a, b, rtol=None, atol=None):
     return np.allclose(a, b, atol=atol_, rtol=rtol_)


-def __getattr__(name):
-    if name == "MaxandArgmax":
-        warnings.warn(
-            "The class `MaxandArgmax` has been deprecated. "
-            "Call `Max` and `Argmax` seperately as an alternative.",
-            FutureWarning,
-        )
-
-
 class Argmax(COp):
     """
     Calculate the argmax over a given axis or over all axes.
@@ -398,21 +398,21 @@ def max_and_argmax(a, axis=None, keepdims=False):

     """
     # Check axis and convert it to a Python list of integers.
-    # Axis will be used as an op param of MaxAndArgmax.
+    # Axis will be used as an op param of Max and Argmax.
     a = as_tensor_variable(a)

-    flag = False
+    is_axis_empty = False
     if axis == ():
-        flag = True
+        is_axis_empty = True

     axis = check_and_normalize_axes(a, axis)

-    if len(axis) == 0 and not flag:
+    if len(axis) == 0 and not is_axis_empty:
         axis = None

     out = Max(axis)(a)

-    if not flag:
+    if not is_axis_empty:
         argout = Argmax(axis)(a)
     else:
         argout = zeros_like(a, dtype="int64")
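For context on the renamed flag: axis=() asks for a reduction over no axes, which leaves every element in place, so the matching argmax is simply zeros_like(a, dtype="int64") as shown above. NumPy treats an empty axis tuple the same way; a small illustration:

import numpy as np

x = np.array([[3, 1], [2, 5]])
# Reducing over an empty tuple of axes performs no reduction at all.
assert np.array_equal(np.max(x, axis=()), x)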
@@ -495,7 +495,8 @@ def grad(self, inp, grads):
         if g_max_disconnected:
             return [DisconnectedType()()]

-        if NoneConst.equals(axis):
+        # if NoneConst.equals(axis):
+        if axis is None:
             axis_ = list(range(x.ndim))
         else:
             axis_ = axis
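Both forms of the check express the same normalization: an op built with axis=None takes its gradient over every dimension, so the axis list is expanded to the full range. A tiny illustration with a hypothetical rank:

x_ndim = 3  # hypothetical tensor rank
axis = None
axis_ = list(range(x_ndim)) if axis is None else list(axis)
assert axis_ == [0, 1, 2]  # None expands to all dimensions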

pytensor/tensor/rewriting/uncanonicalize.py (−1)

@@ -57,7 +57,6 @@ def local_max_to_min(fgraph, node):
     """
    if node.op == neg and node.inputs[0].owner:
        max = node.inputs[0]
-        # print(max.owner.op.scalar_op)
        if (
            max.owner
            and isinstance(max.owner.op, CAReduce)

tests/link/numba/test_basic.py (−2)

@@ -256,8 +256,6 @@ def compare_numba_and_py(
     if assert_fn is None:

         def assert_fn(x, y):
-            print(x)
-            print(y)
             return np.testing.assert_allclose(x, y, rtol=1e-4) and compare_shape_dtype(
                 x, y
             )

tests/tensor/test_math.py (+1)

@@ -1394,6 +1394,7 @@ def _grad_list(self):
         # check_grad_max(data, eval_outputs(grad(max_and_argmax(n,
         # axis=1)[0], n)),axis=1)

+    @pytest.mark.xfail(reason="Fails due to #770")
     def test_uint(self):
         for dtype in ("uint8", "uint16", "uint32", "uint64"):
             itype = np.iinfo(dtype)
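The marker keeps test_uint running but records its failure as an expected one (xfail) rather than breaking the suite, and reports an unexpected pass as xpass. A minimal standalone sketch of the semantics, with a hypothetical test standing in for the real uint64 case:

import pytest


@pytest.mark.xfail(reason="Fails due to #770")
def test_known_uint64_issue():
    # Placeholder for the uint64 failure tracked in issue #770.
    assert False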
