
Commit 1d9a484

Finalise changes to separate MaxAndArgmax Op
1 parent 98266c8 commit 1d9a484

File tree

pytensor/tensor/math.py
tests/link/numba/test_elemwise.py
tests/tensor/test_math.py

3 files changed: +43 −39 lines changed


pytensor/tensor/math.py (+16 −8)

```diff
@@ -28,6 +28,7 @@
     constant,
     stack,
     switch,
+    zeros_like,
 )
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
 from pytensor.tensor.elemwise import (
```
```diff
@@ -192,10 +193,9 @@ def perform(self, node, inp, outs):
         (x,) = inp
         axes = self.axis
         (max_idx,) = outs
+
         if axes is None:
             axes = tuple(range(x.ndim))
-        # else:
-        #     axes = tuple(int(ax) for ax in axes)
 
         # Numpy does not support multiple axes for argmax
         # Work around
```
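The surviving comment marks a genuine NumPy limitation: `np.argmax` accepts at most one axis, so reducing over several axes has to be emulated. A minimal sketch of that kind of workaround, with a hypothetical helper name (`argmax_multi_axis` is not part of PyTensor):

```python
import numpy as np


def argmax_multi_axis(x, axes):
    # Hypothetical helper: np.argmax accepts a single axis only, so
    # move the reduced axes to the end and collapse them into one.
    keep = tuple(i for i in range(x.ndim) if i not in axes)
    transposed = x.transpose(keep + tuple(axes))
    collapsed = transposed.reshape(transposed.shape[: len(keep)] + (-1,))
    # Each entry is the flat index of the max over the collapsed axes.
    return np.argmax(collapsed, axis=-1)


x = np.random.normal(size=(2, 3, 5))
idx = argmax_multi_axis(x, axes=(1, 2))  # shape (2,), flat indices in [0, 15)
```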
```diff
@@ -400,14 +400,22 @@ def max_and_argmax(a, axis=None, keepdims=False):
     # Check axis and convert it to a Python list of integers.
     # Axis will be used as an op param of MaxAndArgmax.
     a = as_tensor_variable(a)
+
+    flag = False
+    if axis == ():
+        flag = True
+
     axis = check_and_normalize_axes(a, axis)
-    if len(axis) == 0:
-        # axis = list(range(a.type.ndim))
+
+    if len(axis) == 0 and not flag:
         axis = None
-    # out = TensorMax(axis)(a)
+
     out = Max(axis)(a)
-    argout = Argmax(axis)(a)
-    # _, argout = MaxAndArgmax(axis)(a)
+
+    if not flag:
+        argout = Argmax(axis)(a)
+    else:
+        argout = zeros_like(a, dtype="int64")
 
     if keepdims:
         out = makeKeepDims(a, out, axis)
```
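The `flag` path above gives `axis=()` its NumPy meaning: reducing over no axes leaves the values untouched, and the index of the maximum of each singleton slice is trivially 0, hence the `zeros_like(a, dtype="int64")`. A minimal sketch of the resulting behaviour, assuming a working PyTensor install:

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.math import max_and_argmax

x = pt.matrix("x")
# axis=() reduces over no axes: every element is the max of its own
# singleton slice, so the corresponding argmax index is always 0.
mx, amx = max_and_argmax(x, axis=())

x_val = np.array([[1.0, 3.0], [2.0, 0.5]])
print(mx.eval({x: x_val}))   # identical to x_val
print(amx.eval({x: x_val}))  # int64 zeros with the shape of x_val
```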
```diff
@@ -483,7 +491,7 @@ def grad(self, inp, grads):
 
         g_max_disconnected = isinstance(g_max.type, DisconnectedType)
 
-        # if the op is tota lly disconnected, so are its inputs
+        # if the op is totally disconnected, so are its inputs
         if g_max_disconnected:
             return [DisconnectedType()()]
 
```

tests/link/numba/test_elemwise.py (−18)

```diff
@@ -552,24 +552,6 @@ def test_LogSoftmax(x, axis, exc):
         ),
     ],
 )
-# def test_MaxAndArgmax(x, axes, exc):
-#     g = ptm.MaxAndArgmax(axes)(x)
-
-#     if isinstance(g, list):
-#         g_fg = FunctionGraph(outputs=g)
-#     else:
-#         g_fg = FunctionGraph(outputs=[g])
-
-#     cm = contextlib.suppress() if exc is None else pytest.warns(exc)
-#     with cm:
-#         compare_numba_and_py(
-#             g_fg,
-#             [
-#                 i.tag.test_value
-#                 for i in g_fg.inputs
-#                 if not isinstance(i, SharedVariable | Constant)
-#             ],
-#         )
 def test_Max(x, axes, exc):
     g = ptm.Max(axes)(x)
 
```

tests/tensor/test_math.py (+27 −13)

```diff
@@ -764,8 +764,7 @@ def setup_method(self):
         Max.debug = 0
         Argmax.debug = 0
 
-    def test_basic_0(self):
-        # dbt: for some reason, Argmax does not work when I pass: n = as_tensor_variable(5.0)
+    def test_basic(self):
         n = as_tensor_variable(5)
         v, i = eval_outputs(max_and_argmax(n, axis=()))
         assert v == 5.0
```
```diff
@@ -1040,6 +1039,29 @@ def test_vectorize(self, core_axis, batch_axis):
         assert isinstance(new_node.op, Argmax)
         assert new_node.op.axis == batch_axis
 
+    def test_max_empty_axis(self):
+        x = np.random.normal(size=(2, 3, 5, 7))
+        axis = ()
+
+        non_axis = tuple(i for i in range(x.ndim) if i not in axis)
+        shape_axis = tuple(x.shape[dim] for dim in axis)
+        shape_non_axis = tuple(x.shape[dim] for dim in non_axis)
+        x_transposed = x.transpose(*axis, *non_axis)
+
+        x_axis_raveled = x_transposed.reshape(
+            np.prod(shape_axis, dtype=int), np.prod(shape_non_axis, dtype=int)
+        )
+        max_x = max_and_argmax(x, axis=axis)[0].eval()
+        argmax_x = max_and_argmax(x, axis=axis)[1].eval()
+
+        raveled_max = x_axis_raveled[
+            argmax_x.ravel(), np.arange(np.prod(shape_non_axis, dtype=int))
+        ]
+        indirect_max = raveled_max.reshape(shape_non_axis)
+
+        np.testing.assert_allclose(max_x, x.max(axis=axis))
+        np.testing.assert_allclose(indirect_max, x.max(axis=axis))
+
 
 class TestArgminArgmax:
     def setup_method(self):
```
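`test_max_empty_axis` checks the maximum twice: directly, and indirectly by routing through the argmax indices; with `axis=()` the collapsed axis has length one and every index is 0, so the indirect path degenerates. A standalone NumPy sketch of the same transpose-and-ravel pattern with a non-empty axis, where the indexing does real work (all names are illustrative):

```python
import numpy as np

x = np.random.normal(size=(2, 3, 5))
axis = (1,)

non_axis = tuple(i for i in range(x.ndim) if i not in axis)
shape_axis = tuple(x.shape[d] for d in axis)
shape_non_axis = tuple(x.shape[d] for d in non_axis)

# Move the reduced axes to the front and flatten them into a single
# leading dimension: column k then lists every candidate value for
# the k-th kept position.
x_raveled = x.transpose(*axis, *non_axis).reshape(
    np.prod(shape_axis, dtype=int), np.prod(shape_non_axis, dtype=int)
)

argmax_x = np.argmax(x, axis=axis[0])  # shape == shape_non_axis
raveled_max = x_raveled[
    argmax_x.ravel(), np.arange(np.prod(shape_non_axis, dtype=int))
]

# The max recovered through the argmax indices must match the direct max.
np.testing.assert_allclose(
    raveled_max.reshape(shape_non_axis), x.max(axis=axis)
)
```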
```diff
@@ -1379,19 +1401,11 @@ def test_uint(self):
             data = np.array([itype.min + 3, itype.min, itype.max - 5, itype.max], dtype)
             n = as_tensor_variable(data)
             assert min(n).dtype == dtype
-            # print(min(n).owner.inputs[1].acc_dtype)
             i = eval_outputs(min(n))
-            # pytensor.dprint(n)
-            # for x in n:
-            #     print(x.eval())
-            print(i)
-            print(itype.min)
-            print()
             assert i == itype.min
-            # assert max(n).dtype == dtype
-            # i = eval_outputs(max(n))
-            # assert i == itype.max
-            # assert 0
+            assert max(n).dtype == dtype
+            i = eval_outputs(max(n))
+            assert i == itype.max
 
     def test_bool(self):
         data = np.array([True, False], "bool")
```
