
Commit d1b0d8a

Cleanup Max and Argmax
1 parent 0d12385 commit d1b0d8a

File tree

5 files changed (+105 additions, −182 deletions)


pytensor/tensor/elemwise.py

Lines changed: 13 additions & 20 deletions
@@ -30,7 +30,11 @@
     float_dtypes,
     lvector,
 )
-from pytensor.tensor.utils import broadcast_static_dim_lengths, import_func_from_string
+from pytensor.tensor.utils import (
+    broadcast_static_dim_lengths,
+    import_func_from_string,
+    normalize_reduce_axis,
+)
 from pytensor.tensor.variable import TensorVariable
 from pytensor.utils import uniq

@@ -1371,7 +1375,6 @@ def _acc_dtype(self, idtype):
 
     def make_node(self, input):
         input = as_tensor_variable(input)
-        inp_dims = input.type.ndim
         inp_dtype = input.type.dtype
 
         # We need to redefine make_node so that, if self.dtype is None,
@@ -1383,29 +1386,19 @@ def make_node(self, input):
         assert dtype is not None
         assert acc_dtype is not None
 
-        axis = self.axis
+        axis = normalize_reduce_axis(input, self.axis)
 
-        # scalar inputs are treated as 1D regarding axis in this `Op`
-        if axis is not None:
-            try:
-                axis = normalize_axis_tuple(axis, ndim=max(1, inp_dims))
-            except np.AxisError:
-                raise np.AxisError(axis, ndim=inp_dims)
+        if axis != self.axis or dtype != self.dtype or acc_dtype != self.acc_dtype:
+            op = self.clone(axis=axis, dtype=dtype, acc_dtype=acc_dtype)
+        else:
+            op = self
 
+        if axis is None:
+            out_shape = ()
+        else:
             out_shape = tuple(
                 s for i, s in enumerate(input.type.shape) if i not in axis
             )
-        else:
-            out_shape = ()
-
-        if (
-            (axis is not None and any(a < 0 for a in axis))
-            or dtype != self.dtype
-            or acc_dtype != self.acc_dtype
-        ):
-            op = self.clone(axis=axis, dtype=dtype, acc_dtype=acc_dtype)
-        else:
-            op = self
 
         output = TensorType(dtype=dtype, shape=out_shape)()
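
A quick sketch of the effect of this hunk (not part of the commit, assuming the `Max` op exported by `pytensor.tensor.math`): negative axes are normalized by `normalize_reduce_axis` before the clone check, so the node that ends up in the graph always carries a non-negative axis tuple.

import pytensor.tensor as pt
from pytensor.tensor.math import Max

x = pt.matrix("x")
out = Max(axis=-1)(x)
# make_node normalized the axis and cloned the op, so the graph stores (1,), not (-1,)
print(out.owner.op.axis)  # (1,)
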

pytensor/tensor/math.py

Lines changed: 41 additions & 128 deletions
@@ -8,7 +8,6 @@
 
 from pytensor import config, printing
 from pytensor import scalar as ps
-from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
@@ -26,9 +25,9 @@
     cast,
     concatenate,
     constant,
+    expand_dims,
     stack,
     switch,
-    zeros_like,
 )
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
 from pytensor.tensor.elemwise import (
@@ -45,14 +44,11 @@
     continuous_dtypes,
     discrete_dtypes,
     int_dtypes,
-    integer_dtypes,
     tensor,
     uint_dtypes,
 )
-from pytensor.tensor.type_other import NoneConst
-from pytensor.tensor.utils import as_list
+from pytensor.tensor.utils import as_list, normalize_reduce_axis
 from pytensor.tensor.variable import (
-    TensorConstant,
     TensorVariable,
     _tensor_py_operators,
 )
@@ -157,7 +153,7 @@ class Argmax(COp):
 
     def __init__(self, axis):
         if axis is not None:
-            axis = tuple(axis)
+            axis = tuple(sorted(axis))
         self.axis = axis
 
     def get_params(self, node):
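
Sketch of what the `sorted` call buys (not from the commit; it assumes `Argmax` keeps `axis` as its only op prop): ops built with the same axes in a different order now compare equal, so identical reductions can be merged.

from pytensor.tensor.math import Argmax

a = Argmax(axis=(1, 0))
b = Argmax(axis=(0, 1))
print(a.axis, b.axis)  # (0, 1) (0, 1)
print(a == b)          # True
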
@@ -168,7 +164,7 @@ def get_params(self, node):
             c_axis = np.int64(-1)
         return self.params_type.get_params(c_axis=c_axis)
 
-    def make_node(self, x, axis=None):
+    def make_node(self, x):
         x = as_tensor_variable(x)
         if self.axis is None:
             all_axes = list(range(x.ndim))
@@ -198,7 +194,9 @@ def perform(self, node, inp, outs):
         # Work around
         keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
         # Not-reduced axes in front
-        transposed_x = np.transpose(x, np.concatenate((keep_axes, axes)))
+        transposed_x = np.transpose(
+            x, np.concatenate((keep_axes, np.asarray(axes, dtype="int64")))
+        )
         kept_shape = transposed_x.shape[: len(keep_axes)]
         reduced_shape = transposed_x.shape[len(keep_axes) :]
         new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64"))
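
For context, the NumPy trick this `perform` method relies on (a standalone sketch, not code from the commit): the kept axes are moved to the front, the reduced axes are flattened into one trailing axis, and a single `argmax(-1)` yields one flat index per kept position.

import numpy as np

x = np.random.default_rng(0).normal(size=(2, 3, 4))
axes = (1, 2)  # axes to reduce
keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")

# Non-reduced axes in front, then flatten the reduced ones into a single axis
transposed = np.transpose(x, np.concatenate((keep_axes, np.asarray(axes, dtype="int64"))))
kept_shape = transposed.shape[: len(keep_axes)]
flat = transposed.reshape((*kept_shape, -1))

print(flat.argmax(-1).shape)  # (2,): one flat argmax index per kept position
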
@@ -214,7 +212,7 @@ def c_code(self, node, name, inp, out, sub):
         if self.axis is None:
             axis_code = "axis = NPY_MAXDIMS;"
         else:
-            if len(self.axis) > 1:
+            if len(self.axis) != 1:
                 raise NotImplementedError()
             # params is only used here for now
             axis_code = """
@@ -253,7 +251,7 @@ def c_code(self, node, name, inp, out, sub):
         return ret % locals()
 
     def c_code_cache_version(self):
-        return (1,)
+        return (2,)
 
     def infer_shape(self, fgraph, node, shapes):
         (ishape,) = shapes
@@ -277,7 +275,7 @@ def grad(self, inp, grads):
         return [x.zeros_like()]
 
 
-def argmax(x, axis=None, keepdims=False):
+def argmax(x: TensorLike, axis=None, keepdims: bool = False):
     """
     Returns indices of maximum elements obtained by iterating over given axis.
@@ -286,17 +284,29 @@ def argmax(x, axis=None, keepdims=False):
 
     Parameters
     ----------
+    x: TensorLike
+        Array on which to compute argmax
+    axis:
+        Axis along which to compute argmax. Unlike numpy multiple partial axis are supported.
     keepdims : bool
         If this is set to True, the axes which are reduced are left in
         the result as dimensions with size one. With this option, the result
         will broadcast correctly against the original tensor.
 
+    Returns
+    -------
+    TensorVariable
+        TensorVariable representing the argmax operation
+
     """
-    argout = max_and_argmax(x, axis)[1]
+    x = as_tensor_variable(x)
+    axis = normalize_reduce_axis(x, axis)
+    out = Argmax(axis)(x)
 
     if keepdims:
-        argout = makeKeepDims(x, argout, axis)
-    return argout
+        out = makeKeepDims(x, out, axis)
+
+    return out
 
 
 @_vectorize_node.register(Argmax)
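
Usage sketch for the rewritten `argmax` (assuming the public `pt.argmax` wrapper re-exports this function): tuple axes and `keepdims` go through `normalize_reduce_axis` and `makeKeepDims` as shown above.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x")
out = pt.argmax(x, axis=(0, -1), keepdims=True)

f = pytensor.function([x], out)
x_val = np.random.default_rng(0).normal(size=(2, 3, 4))
print(f(x_val).shape)  # (1, 3, 1): reduced axes kept as size-1 dimensions
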
@@ -324,59 +334,6 @@ def makeKeepDims(x, y, axis):
     return expand_dims(y, axis)
 
 
-def check_and_normalize_axes(x, axis):
-    """Check axes, normalize and convert them to a Python list of integers.
-
-    Parameters
-    ----------
-    x: TensorVariable
-    axis: int, tuple or list of integers
-
-    Returns
-    -------
-    axis: list of integers
-        Return an empty list if argument is None.
-
-    """
-    x = as_tensor_variable(x)
-    if axis is None:
-        axis = []
-    elif isinstance(axis, int | np.integer) or (
-        isinstance(axis, np.ndarray) and axis.ndim == 0
-    ):
-        axis = [int(axis)]
-    elif isinstance(axis, tuple | list | np.ndarray):
-        axis = [int(i) for i in axis]
-    elif isinstance(axis, Variable):
-        if NoneConst.equals(axis):
-            axis = []
-        elif not isinstance(axis, TensorConstant):
-            raise TypeError(f"Computation needs a constant axis. Got {axis}")
-        else:
-            assert axis.dtype in integer_dtypes
-            if isinstance(axis.data, int | np.integer) or (
-                isinstance(axis.data, np.ndarray) and axis.data.ndim == 0
-            ):
-                axis = [int(axis.data)]
-            elif isinstance(axis.data, list | np.ndarray):
-                axis = [int(i) for i in axis.data]
-    else:
-        raise TypeError(
-            f"Axis must be an integer, tuple, list of integers or a TensorVariable. Got {axis}"
-        )
-    if len(axis) > 0:
-        for i in range(len(axis)):
-            if axis[i] < 0:
-                axis[i] += x.type.ndim
-            if axis[i] < 0 or axis[i] >= x.type.ndim:
-                raise ValueError(
-                    f"Computation needs a valid axis number for {int(x.type.ndim)}-D tensor. Got {int(axis[i])}"
-                )
-    axis = list(set(axis))
-    axis.sort()
-    return axis
-
-
 def max_and_argmax(a, axis=None, keepdims=False):
     """
     Returns maximum elements and their indices obtained by iterating over
@@ -395,28 +352,10 @@ def max_and_argmax(a, axis=None, keepdims=False):
     """
     # Check axis and convert it to a Python list of integers.
    # Axis will be used as an op param of Max and Argmax.
-    a = as_tensor_variable(a)
-
-    is_axis_empty = False
-    if axis == ():
-        is_axis_empty = True
-
-    axis = check_and_normalize_axes(a, axis)
-
-    if len(axis) == 0 and not is_axis_empty:
-        axis = None
-
-    out = Max(axis)(a)
-
-    if not is_axis_empty:
-        argout = Argmax(axis)(a)
-    else:
-        argout = zeros_like(a, dtype="int64")
-
-    if keepdims:
-        out = makeKeepDims(a, out, axis)
-        argout = makeKeepDims(a, argout, axis)
-    return [out, argout]
+    return [
+        max(a, axis=axis, keepdims=keepdims),
+        argmax(a, axis=axis, keepdims=keepdims),
+    ]
 
 
 class FixedOpCAReduce(CAReduce):
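
Behaviour sketch (not part of the commit): `max_and_argmax` is now a thin wrapper that simply pairs `max` and `argmax`.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
mx, amx = pt.max_and_argmax(x, axis=1)

f = pytensor.function([x], [mx, amx])
x_val = np.array([[1.0, 5.0, 3.0], [4.0, 2.0, 6.0]])
print(f(x_val))  # [array([5., 6.]), array([1, 2])]
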
@@ -465,7 +404,7 @@ def clone(self, **kwargs):
         axis = kwargs.get("axis", self.axis)
         return type(self)(axis=axis)
 
-    def grad(self, inp, grads):
+    def L_op(self, inputs, outputs, grads):
         # The strict sense mathematical gradient of the maximum function is
         # not calculated here for it is not defined at every point where some
         # coordinates are identical. However, since the latter set has null
@@ -479,53 +418,27 @@ def grad(self, inp, grads):
         # g_max has one less dimension than x, so you need to complete
         # g_max to x's shape when axis=0 the broadcasting mechanism
         # does it automatically
-        x = inp[0]
-        if self.axis is None:
-            self.axis = tuple(range(x.ndim))
-        axis = as_tensor_variable(self.axis)
-        (g_max,) = grads
-
-        g_max_disconnected = isinstance(g_max.type, DisconnectedType)
+        [x] = inputs
+        [out] = outputs
+        [g_out] = grads
 
-        # if the op is totally disconnected, so are its inputs
-        if g_max_disconnected:
-            return [DisconnectedType()()]
-
-        # if NoneConst.equals(axis):
-        if axis is None:
-            axis_ = list(range(x.ndim))
-        else:
-            axis_ = axis
-        xmax = max(x, axis_)
-
-        # Raise the g_max and xmax to the same number of dim as the input.
-        pattern = []
-        out_dim = 0
-        if NoneConst.equals(axis):
-            # We are taking the max/argmax over all dimensions.
-            axis = None
-        for i in range(x.ndim):
-            if axis is None or i in axis.data:
-                pattern.append("x")
-            else:
-                pattern.append(out_dim)
-                out_dim += 1
-        g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max)
-        xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax)
+        axis = tuple(range(x.ndim)) if self.axis is None else self.axis
+        out_pad = expand_dims(out, axis)
+        g_out_pad = expand_dims(g_out, axis)
 
         # Set the grad to the correct position.
-        g_x = eq(xmax_pad, x) * g_max_pad
+        g_x = eq(out_pad, x) * g_out_pad
         return (g_x,)
 
     def R_op(self, inputs, eval_points):
         if eval_points[0] is None:
             return [None, None]
         if len(self.axis) != 1:
-            raise ValueError("R_op supported for arg_max only for one axis!")
+            raise ValueError("R_op supported for max only for one axis!")
         if self.axis[0] > 1:
-            raise ValueError("R_op supported for arg_max only when axis is 0 or 1")
+            raise ValueError("R_op supported for max only when axis is 0 or 1")
         if inputs[0].ndim != 2:
-            raise ValueError("R_op supported for arg_max only when input is a matrix")
+            raise ValueError("R_op supported for max only when input is a matrix")
         max_pos = Argmax(self.axis).make_node(*inputs).outputs
         # print(eval_points[0].eval())
         if self.axis[0] == 0:
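
A sketch of the gradient the rewritten `L_op` produces (not from the commit): the incoming gradient is broadcast back over the reduced axes and routed to every entry equal to the maximum.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
g = pytensor.grad(pt.max(x, axis=1).sum(), x)

f = pytensor.function([x], g)
x_val = np.array([[1.0, 5.0, 3.0], [4.0, 2.0, 6.0]])
print(f(x_val))
# [[0. 1. 0.]
#  [0. 0. 1.]]
# Tied maxima would each receive the full incoming gradient, since eq(out_pad, x) marks them all.
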
@@ -564,7 +477,7 @@ def max(x, axis=None, keepdims=False):
     We return an error as numpy when we reduce a dim with a shape of 0.
 
     """
-    out = max_and_argmax(x, axis)[0]
+    out = Max(axis=axis)(x)
 
     if keepdims:
         out = makeKeepDims(x, out, axis)

pytensor/tensor/utils.py

Lines changed: 23 additions & 0 deletions
@@ -1,12 +1,18 @@
 import re
+import typing
 from collections.abc import Sequence
 
 import numpy as np
+from numpy.core.numeric import normalize_axis_tuple
 
 import pytensor
 from pytensor.utils import hash_from_code
 
 
+if typing.TYPE_CHECKING:
+    from pytensor.tensor.var import TensorVariable
+
+
 def hash_from_ndarray(data):
     """
     Return a hash from an ndarray.
@@ -222,3 +228,20 @@ def operand_sig(operand_ndim: int, prefix: str) -> str:
         operand_sig(ndim, prefix=f"o{n}") for n, ndim in enumerate(core_outputs_ndim)
     )
     return f"{inputs_sig}->{outputs_sig}"
+
+
+def normalize_reduce_axis(x: "TensorVariable", axis) -> tuple[int, ...] | None:
+    """Normalize the axis parameter for reduce operations."""
+    if axis is None:
+        return None
+
+    # scalar inputs are treated as 1D regarding axis in reduce operations
+    x_ndim = x.type.ndim
+    if axis is not None:
+        try:
+            axis = normalize_axis_tuple(axis, ndim=max(1, x_ndim))
+        except np.AxisError:
+            raise np.AxisError(axis, ndim=x_ndim)
+
+    # TODO: If axis tuple is equivalent to None, return None for more canonicalization?
+    return axis
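
Behaviour sketch of the new helper (names as in the hunk above; not part of the commit): axes are normalized to a tuple of non-negative ints, scalars are treated as 1D, and None passes through.

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.utils import normalize_reduce_axis

x = pt.tensor3("x")
print(normalize_reduce_axis(x, -1))       # (2,)
print(normalize_reduce_axis(x, (0, -2)))  # (0, 1)
print(normalize_reduce_axis(x, None))     # None

s = pt.scalar("s")
print(normalize_reduce_axis(s, 0))        # (0,): scalar inputs are treated as 1D

try:
    normalize_reduce_axis(x, 3)
except np.AxisError as err:
    print(err)                            # out-of-range axes raise np.AxisError
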
