
from pytensor import config, printing
from pytensor import scalar as ps
-from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
    cast,
    concatenate,
    constant,
+    expand_dims,
    stack,
    switch,
-    zeros_like,
)
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.elemwise import (
    continuous_dtypes,
    discrete_dtypes,
    int_dtypes,
-    integer_dtypes,
    tensor,
    uint_dtypes,
)
-from pytensor.tensor.type_other import NoneConst
-from pytensor.tensor.utils import as_list
+from pytensor.tensor.utils import as_list, normalize_reduce_axis
from pytensor.tensor.variable import (
-    TensorConstant,
    TensorVariable,
    _tensor_py_operators,
)
@@ -157,7 +153,7 @@ class Argmax(COp):

    def __init__(self, axis):
        if axis is not None:
-            axis = tuple(axis)
+            axis = tuple(sorted(axis))
        self.axis = axis

    def get_params(self, node):
@@ -168,7 +164,7 @@ def get_params(self, node):
            c_axis = np.int64(-1)
        return self.params_type.get_params(c_axis=c_axis)

-    def make_node(self, x, axis=None):
+    def make_node(self, x):
        x = as_tensor_variable(x)
        if self.axis is None:
            all_axes = list(range(x.ndim))
@@ -198,7 +194,9 @@ def perform(self, node, inp, outs):
        # Work around
        keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
        # Not-reduced axes in front
-        transposed_x = np.transpose(x, np.concatenate((keep_axes, axes)))
+        transposed_x = np.transpose(
+            x, np.concatenate((keep_axes, np.asarray(axes, dtype="int64")))
+        )
        kept_shape = transposed_x.shape[: len(keep_axes)]
        reduced_shape = transposed_x.shape[len(keep_axes) :]
        new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64"))
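For context, a minimal NumPy sketch of the reduction trick used in `Argmax.perform` (the array, axes, and shapes below are illustrative, not from the diff): the non-reduced axes are moved to the front, the reduced axes are flattened into a single trailing axis, and `np.argmax` then runs over that axis.

```python
import numpy as np

# Illustrative only: mimic Argmax.perform for a (2, 3, 4) array reduced over axes (1, 2).
x = np.arange(24, dtype="float64").reshape(2, 3, 4)
axes = (1, 2)
keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
# Non-reduced axes in front, reduced axes flattened into one trailing axis.
transposed = np.transpose(x, np.concatenate((keep_axes, np.asarray(axes, dtype="int64"))))
kept_shape = transposed.shape[: len(keep_axes)]
flat = transposed.reshape((*kept_shape, -1))
idx = np.argmax(flat, axis=-1)  # index into the flattened 3*4 block, per kept row
print(idx)  # [11 11] -- the last element of each block is the largest here
```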
@@ -214,7 +212,7 @@ def c_code(self, node, name, inp, out, sub):
        if self.axis is None:
            axis_code = "axis = NPY_MAXDIMS;"
        else:
-            if len(self.axis) > 1:
+            if len(self.axis) != 1:
                raise NotImplementedError()
            # params is only used here for now
            axis_code = """
@@ -253,7 +251,7 @@ def c_code(self, node, name, inp, out, sub):
        return ret % locals()

    def c_code_cache_version(self):
-        return (1,)
+        return (2,)

    def infer_shape(self, fgraph, node, shapes):
        (ishape,) = shapes
@@ -277,7 +275,7 @@ def grad(self, inp, grads):
        return [x.zeros_like()]


-def argmax(x, axis=None, keepdims=False):
+def argmax(x: TensorLike, axis=None, keepdims: bool = False):
    """
    Returns indices of maximum elements obtained by iterating over given axis.

@@ -286,17 +284,29 @@ def argmax(x, axis=None, keepdims=False):

    Parameters
    ----------
+    x: TensorLike
+        Array on which to compute argmax
+    axis:
+        Axis along which to compute argmax. Unlike NumPy, multiple partial axes are supported.
    keepdims : bool
        If this is set to True, the axes which are reduced are left in
        the result as dimensions with size one. With this option, the result
        will broadcast correctly against the original tensor.

+    Returns
+    -------
+    TensorVariable
+        TensorVariable representing the argmax operation
+
    """
-    argout = max_and_argmax(x, axis)[1]
+    x = as_tensor_variable(x)
+    axis = normalize_reduce_axis(axis, ndim=x.type.ndim)
+    out = Argmax(axis)(x)

    if keepdims:
-        argout = makeKeepDims(x, argout, axis)
-    return argout
+        out = makeKeepDims(x, out, axis)
+
+    return out


@_vectorize_node.register(Argmax)
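A quick usage sketch of the rewritten `argmax` (the toy matrix is my own, not from the diff): `axis` can be None, an int, or a tuple of axes, and `keepdims=True` keeps the reduced dimension with size one.

```python
import numpy as np
import pytensor.tensor as pt

x = pt.matrix("x")
i = pt.argmax(x, axis=1, keepdims=True)  # row-wise argmax, result shape (n, 1)
print(i.eval({x: np.array([[1.0, 3.0, 2.0], [5.0, 4.0, 0.0]])}))
# [[1]
#  [0]]
```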
@@ -324,59 +334,6 @@ def makeKeepDims(x, y, axis):
    return expand_dims(y, axis)


-def check_and_normalize_axes(x, axis):
-    """Check axes, normalize and convert them to a Python list of integers.
-
-    Parameters
-    ----------
-    x: TensorVariable
-    axis: int, tuple or list of integers
-
-    Returns
-    -------
-    axis: list of integers
-        Return an empty list if argument is None.
-
-    """
-    x = as_tensor_variable(x)
-    if axis is None:
-        axis = []
-    elif isinstance(axis, int | np.integer) or (
-        isinstance(axis, np.ndarray) and axis.ndim == 0
-    ):
-        axis = [int(axis)]
-    elif isinstance(axis, tuple | list | np.ndarray):
-        axis = [int(i) for i in axis]
-    elif isinstance(axis, Variable):
-        if NoneConst.equals(axis):
-            axis = []
-        elif not isinstance(axis, TensorConstant):
-            raise TypeError(f"Computation needs a constant axis. Got {axis}")
-        else:
-            assert axis.dtype in integer_dtypes
-            if isinstance(axis.data, int | np.integer) or (
-                isinstance(axis.data, np.ndarray) and axis.data.ndim == 0
-            ):
-                axis = [int(axis.data)]
-            elif isinstance(axis.data, list | np.ndarray):
-                axis = [int(i) for i in axis.data]
-    else:
-        raise TypeError(
-            f"Axis must be an integer, tuple, list of integers or a TensorVariable. Got {axis}"
-        )
-    if len(axis) > 0:
-        for i in range(len(axis)):
-            if axis[i] < 0:
-                axis[i] += x.type.ndim
-            if axis[i] < 0 or axis[i] >= x.type.ndim:
-                raise ValueError(
-                    f"Computation needs a valid axis number for {int(x.type.ndim)}-D tensor. Got {int(axis[i])}"
-                )
-        axis = list(set(axis))
-        axis.sort()
-    return axis
-
-
def max_and_argmax(a, axis=None, keepdims=False):
    """
    Returns maximum elements and their indices obtained by iterating over
@@ -395,28 +352,10 @@ def max_and_argmax(a, axis=None, keepdims=False):
    """
    # Check axis and convert it to a Python list of integers.
    # Axis will be used as an op param of Max and Argmax.
-    a = as_tensor_variable(a)
-
-    is_axis_empty = False
-    if axis == ():
-        is_axis_empty = True
-
-    axis = check_and_normalize_axes(a, axis)
-
-    if len(axis) == 0 and not is_axis_empty:
-        axis = None
-
-    out = Max(axis)(a)
-
-    if not is_axis_empty:
-        argout = Argmax(axis)(a)
-    else:
-        argout = zeros_like(a, dtype="int64")
-
-    if keepdims:
-        out = makeKeepDims(a, out, axis)
-        argout = makeKeepDims(a, argout, axis)
-    return [out, argout]
+    return [
+        max(a, axis=axis, keepdims=keepdims),
+        argmax(a, axis=axis, keepdims=keepdims),
+    ]


class FixedOpCAReduce(CAReduce):
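With the refactor above, `max_and_argmax` is simply the pair of the two reductions, so both share the same axis normalization. A small usage sketch (the example data is mine, not from the diff):

```python
import numpy as np
import pytensor.tensor as pt

x = pt.matrix("x")
mx, am = pt.max_and_argmax(x, axis=0)  # column-wise maximum and its index
data = np.array([[1.0, 7.0], [4.0, 2.0]])
print(mx.eval({x: data}), am.eval({x: data}))  # [4. 7.] [1 0]
```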
@@ -465,7 +404,7 @@ def clone(self, **kwargs):
        axis = kwargs.get("axis", self.axis)
        return type(self)(axis=axis)

-    def grad(self, inp, grads):
+    def L_op(self, inputs, outputs, grads):
        # The strict sense mathematical gradient of the maximum function is
        # not calculated here for it is not defined at every point where some
        # coordinates are identical. However, since the latter set has null
@@ -479,53 +418,27 @@ def grad(self, inp, grads):
        # g_max has one less dimension than x, so you need to complete
        # g_max to x's shape when axis=0 the broadcasting mechanism
        # does it automatically
-        x = inp[0]
-        if self.axis is None:
-            self.axis = tuple(range(x.ndim))
-        axis = as_tensor_variable(self.axis)
-        (g_max,) = grads
-
-        g_max_disconnected = isinstance(g_max.type, DisconnectedType)
+        [x] = inputs
+        [out] = outputs
+        [g_out] = grads

-        # if the op is totally disconnected, so are its inputs
-        if g_max_disconnected:
-            return [DisconnectedType()()]
-
-        # if NoneConst.equals(axis):
-        if axis is None:
-            axis_ = list(range(x.ndim))
-        else:
-            axis_ = axis
-        xmax = max(x, axis_)
-
-        # Raise the g_max and xmax to the same number of dim as the input.
-        pattern = []
-        out_dim = 0
-        if NoneConst.equals(axis):
-            # We are taking the max/argmax over all dimensions.
-            axis = None
-        for i in range(x.ndim):
-            if axis is None or i in axis.data:
-                pattern.append("x")
-            else:
-                pattern.append(out_dim)
-                out_dim += 1
-        g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max)
-        xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax)
+        axis = tuple(range(x.ndim)) if self.axis is None else self.axis
+        out_pad = expand_dims(out, axis)
+        g_out_pad = expand_dims(g_out, axis)

        # Set the grad to the correct position.
-        g_x = eq(xmax_pad, x) * g_max_pad
+        g_x = eq(out_pad, x) * g_out_pad
        return (g_x,)

    def R_op(self, inputs, eval_points):
        if eval_points[0] is None:
            return [None, None]
        if len(self.axis) != 1:
-            raise ValueError("R_op supported for arg_max only for one axis!")
+            raise ValueError("R_op supported for max only for one axis!")
        if self.axis[0] > 1:
-            raise ValueError("R_op supported for arg_max only when axis is 0 or 1")
+            raise ValueError("R_op supported for max only when axis is 0 or 1")
        if inputs[0].ndim != 2:
-            raise ValueError("R_op supported for arg_max only when input is a matrix")
+            raise ValueError("R_op supported for max only when input is a matrix")
        max_pos = Argmax(self.axis).make_node(*inputs).outputs
        # print(eval_points[0].eval())
        if self.axis[0] == 0:
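The new `L_op` routes the output gradient back to the entries that attain the maximum, with ties receiving the full gradient at every tying position. A NumPy sketch of that rule (the array below is illustrative, not from the diff):

```python
import numpy as np

x = np.array([[1.0, 3.0, 3.0], [2.0, 0.0, 5.0]])
axis = (1,)
out = x.max(axis=axis)                 # forward max
g_out = np.ones_like(out)              # incoming gradient on the max
out_pad = np.expand_dims(out, axis)    # broadcast back to x's shape
g_out_pad = np.expand_dims(g_out, axis)
g_x = (x == out_pad) * g_out_pad       # gradient lands where x equals the max
print(g_x)
# [[0. 1. 1.]
#  [0. 0. 1.]]
```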
@@ -564,7 +477,7 @@ def max(x, axis=None, keepdims=False):
    We return an error as numpy when we reduce a dim with a shape of 0.

    """
-    out = max_and_argmax(x, axis)[0]
+    out = Max(axis=axis)(x)

    if keepdims:
        out = makeKeepDims(x, out, axis)