@@ -8,7 +8,6 @@
 
 from pytensor import config, printing
 from pytensor import scalar as ps
-from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
@@ -26,9 +25,9 @@
     cast,
     concatenate,
     constant,
+    expand_dims,
     stack,
     switch,
-    zeros_like,
 )
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
 from pytensor.tensor.elemwise import (
@@ -45,14 +44,11 @@
     continuous_dtypes,
     discrete_dtypes,
     int_dtypes,
-    integer_dtypes,
     tensor,
     uint_dtypes,
 )
-from pytensor.tensor.type_other import NoneConst
-from pytensor.tensor.utils import as_list
+from pytensor.tensor.utils import as_list, normalize_reduce_axis
 from pytensor.tensor.variable import (
-    TensorConstant,
     TensorVariable,
     _tensor_py_operators,
 )
@@ -157,7 +153,7 @@ class Argmax(COp):
 
     def __init__(self, axis):
        if axis is not None:
-            axis = tuple(axis)
+            axis = tuple(sorted(axis))
        self.axis = axis
 
    def get_params(self, node):
@@ -168,7 +164,7 @@ def get_params(self, node):
             c_axis = np.int64(-1)
         return self.params_type.get_params(c_axis=c_axis)
 
-    def make_node(self, x, axis=None):
+    def make_node(self, x):
         x = as_tensor_variable(x)
         if self.axis is None:
             all_axes = list(range(x.ndim))
@@ -198,7 +194,9 @@ def perform(self, node, inp, outs):
         # Work around
         keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
         # Not-reduced axes in front
-        transposed_x = np.transpose(x, np.concatenate((keep_axes, axes)))
+        transposed_x = np.transpose(
+            x, np.concatenate((keep_axes, np.asarray(axes, dtype="int64")))
+        )
         kept_shape = transposed_x.shape[: len(keep_axes)]
         reduced_shape = transposed_x.shape[len(keep_axes) :]
         new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64"))
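
Note: the transpose-and-reshape workaround in perform() can be sanity-checked in plain NumPy. A minimal sketch (array values and shapes are illustrative, not from the PR):

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)
    axes = (0, 2)
    # Keep the non-reduced axes in front, exactly as perform() does.
    keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
    transposed_x = np.transpose(
        x, np.concatenate((keep_axes, np.asarray(axes, dtype="int64")))
    )
    kept_shape = transposed_x.shape[: len(keep_axes)]
    # Flatten all reduced axes into one trailing axis, then argmax once over it.
    new_shape = (*kept_shape, np.prod(transposed_x.shape[len(keep_axes):], dtype="int64"))
    print(np.argmax(transposed_x.reshape(new_shape), axis=-1))  # one flat index per kept position
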
@@ -214,7 +212,7 @@ def c_code(self, node, name, inp, out, sub):
         if self.axis is None:
             axis_code = "axis = NPY_MAXDIMS;"
         else:
-            if len(self.axis) > 1:
+            if len(self.axis) != 1:
                 raise NotImplementedError()
             # params is only used here for now
             axis_code = """
@@ -253,7 +251,7 @@ def c_code(self, node, name, inp, out, sub):
         return ret % locals()
 
     def c_code_cache_version(self):
-        return (1,)
+        return (2,)
 
     def infer_shape(self, fgraph, node, shapes):
         (ishape,) = shapes
@@ -277,7 +275,7 @@ def grad(self, inp, grads):
         return [x.zeros_like()]
 
 
-def argmax(x, axis=None, keepdims=False):
+def argmax(x: TensorLike, axis=None, keepdims: bool = False):
     """
     Returns indices of maximum elements obtained by iterating over given axis.
 
@@ -286,17 +284,29 @@ def argmax(x, axis=None, keepdims=False):
 
     Parameters
     ----------
+    x: TensorLike
+        Array on which to compute argmax
+    axis:
+        Axis along which to compute argmax. Unlike numpy, multiple partial axes are supported.
     keepdims : bool
         If this is set to True, the axes which are reduced are left in
         the result as dimensions with size one. With this option, the result
         will broadcast correctly against the original tensor.
 
+    Returns
+    -------
+    TensorVariable
+        TensorVariable representing the argmax operation
+
     """
-    argout = max_and_argmax(x, axis)[1]
+    x = as_tensor_variable(x)
+    axis = normalize_reduce_axis(x, axis)
+    out = Argmax(axis)(x)
 
     if keepdims:
-        argout = makeKeepDims(x, argout, axis)
-    return argout
+        out = makeKeepDims(x, out, axis)
+
+    return out
 
 
 @_vectorize_node.register(Argmax)
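
Note: the rewritten argmax goes straight through normalize_reduce_axis and the Argmax Op. A minimal usage sketch of the resulting behavior (shapes and values are illustrative; assumes the usual pytensor.tensor namespace):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.tensor3("x")
    # Multiple partial axes are reduced jointly; keepdims reinserts them as size-1 dims.
    i = pt.argmax(x, axis=(0, 2), keepdims=True)
    f = pytensor.function([x], i)
    print(f(np.arange(24.0).reshape(2, 3, 4)).shape)  # (1, 3, 1)
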
@@ -324,59 +334,6 @@ def makeKeepDims(x, y, axis):
     return expand_dims(y, axis)
 
 
-def check_and_normalize_axes(x, axis):
-    """Check axes, normalize and convert them to a Python list of integers.
-
-    Parameters
-    ----------
-    x: TensorVariable
-    axis: int, tuple or list of integers
-
-    Returns
-    -------
-    axis: list of integers
-        Return an empty list if argument is None.
-
-    """
-    x = as_tensor_variable(x)
-    if axis is None:
-        axis = []
-    elif isinstance(axis, int | np.integer) or (
-        isinstance(axis, np.ndarray) and axis.ndim == 0
-    ):
-        axis = [int(axis)]
-    elif isinstance(axis, tuple | list | np.ndarray):
-        axis = [int(i) for i in axis]
-    elif isinstance(axis, Variable):
-        if NoneConst.equals(axis):
-            axis = []
-        elif not isinstance(axis, TensorConstant):
-            raise TypeError(f"Computation needs a constant axis. Got {axis}")
-        else:
-            assert axis.dtype in integer_dtypes
-            if isinstance(axis.data, int | np.integer) or (
-                isinstance(axis.data, np.ndarray) and axis.data.ndim == 0
-            ):
-                axis = [int(axis.data)]
-            elif isinstance(axis.data, list | np.ndarray):
-                axis = [int(i) for i in axis.data]
-    else:
-        raise TypeError(
-            f"Axis must be an integer, tuple, list of integers or a TensorVariable. Got {axis}"
-        )
-    if len(axis) > 0:
-        for i in range(len(axis)):
-            if axis[i] < 0:
-                axis[i] += x.type.ndim
-            if axis[i] < 0 or axis[i] >= x.type.ndim:
-                raise ValueError(
-                    f"Computation needs a valid axis number for {int(x.type.ndim)}-D tensor. Got {int(axis[i])}"
-                )
-        axis = list(set(axis))
-        axis.sort()
-    return axis
-
-
 def max_and_argmax(a, axis=None, keepdims=False):
     """
     Returns maximum elements and their indices obtained by iterating over
@@ -395,28 +352,10 @@ def max_and_argmax(a, axis=None, keepdims=False):
     """
     # Check axis and convert it to a Python list of integers.
     # Axis will be used as an op param of Max and Argmax.
-    a = as_tensor_variable(a)
-
-    is_axis_empty = False
-    if axis == ():
-        is_axis_empty = True
-
-    axis = check_and_normalize_axes(a, axis)
-
-    if len(axis) == 0 and not is_axis_empty:
-        axis = None
-
-    out = Max(axis)(a)
-
-    if not is_axis_empty:
-        argout = Argmax(axis)(a)
-    else:
-        argout = zeros_like(a, dtype="int64")
-
-    if keepdims:
-        out = makeKeepDims(a, out, axis)
-        argout = makeKeepDims(a, argout, axis)
-    return [out, argout]
+    return [
+        max(a, axis=axis, keepdims=keepdims),
+        argmax(a, axis=axis, keepdims=keepdims),
+    ]
 
 
 class FixedOpCAReduce(CAReduce):
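
Note: max_and_argmax is now a thin wrapper over the two independent reductions. A quick sketch of the resulting behavior (values are illustrative):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    a = pt.matrix("a")
    mx, amx = pt.max_and_argmax(a, axis=0)
    f = pytensor.function([a], [mx, amx])
    print(f(np.array([[1.0, 5.0], [4.0, 2.0]])))  # [array([4., 5.]), array([1, 0])]
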
@@ -465,7 +404,7 @@ def clone(self, **kwargs):
         axis = kwargs.get("axis", self.axis)
         return type(self)(axis=axis)
 
-    def grad(self, inp, grads):
+    def L_op(self, inputs, outputs, grads):
         # The strict sense mathematical gradient of the maximum function is
         # not calculated here for it is not defined at every point where some
         # coordinates are identical. However, since the latter set has null
@@ -479,53 +418,27 @@ def grad(self, inp, grads):
         # g_max has one less dimension than x, so you need to complete
         # g_max to x's shape when axis=0 the broadcasting mechanism
         # does it automatically
-        x = inp[0]
-        if self.axis is None:
-            self.axis = tuple(range(x.ndim))
-        axis = as_tensor_variable(self.axis)
-        (g_max,) = grads
-
-        g_max_disconnected = isinstance(g_max.type, DisconnectedType)
+        [x] = inputs
+        [out] = outputs
+        [g_out] = grads
 
-        # if the op is totally disconnected, so are its inputs
-        if g_max_disconnected:
-            return [DisconnectedType()()]
-
-        # if NoneConst.equals(axis):
-        if axis is None:
-            axis_ = list(range(x.ndim))
-        else:
-            axis_ = axis
-        xmax = max(x, axis_)
-
-        # Raise the g_max and xmax to the same number of dim as the input.
-        pattern = []
-        out_dim = 0
-        if NoneConst.equals(axis):
-            # We are taking the max/argmax over all dimensions.
-            axis = None
-        for i in range(x.ndim):
-            if axis is None or i in axis.data:
-                pattern.append("x")
-            else:
-                pattern.append(out_dim)
-                out_dim += 1
-        g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max)
-        xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax)
+        axis = tuple(range(x.ndim)) if self.axis is None else self.axis
+        out_pad = expand_dims(out, axis)
+        g_out_pad = expand_dims(g_out, axis)
 
         # Set the grad to the correct position.
-        g_x = eq(xmax_pad, x) * g_max_pad
+        g_x = eq(out_pad, x) * g_out_pad
         return (g_x,)
 
     def R_op(self, inputs, eval_points):
         if eval_points[0] is None:
             return [None, None]
         if len(self.axis) != 1:
-            raise ValueError("R_op supported for arg_max only for one axis!")
+            raise ValueError("R_op supported for max only for one axis!")
         if self.axis[0] > 1:
-            raise ValueError("R_op supported for arg_max only when axis is 0 or 1")
+            raise ValueError("R_op supported for max only when axis is 0 or 1")
         if inputs[0].ndim != 2:
-            raise ValueError("R_op supported for arg_max only when input is a matrix")
+            raise ValueError("R_op supported for max only when input is a matrix")
         max_pos = Argmax(self.axis).make_node(*inputs).outputs
         # print(eval_points[0].eval())
         if self.axis[0] == 0:
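
Note: the new L_op routes the output gradient to every position that attains the maximum, so tied positions each receive the full gradient, consistent with the null-measure argument in the comments. A NumPy sketch of the same arithmetic (hypothetical values):

    import numpy as np

    x = np.array([[1.0, 3.0], [3.0, 2.0]])
    axis = (1,)
    out = x.max(axis=axis)                    # forward output of Max
    g_out = np.ones_like(out)                 # upstream gradient
    out_pad = np.expand_dims(out, axis)       # expand_dims(out, axis)
    g_out_pad = np.expand_dims(g_out, axis)   # expand_dims(g_out, axis)
    g_x = (x == out_pad) * g_out_pad          # eq(out_pad, x) * g_out_pad
    print(g_x)                                # [[0., 1.], [1., 0.]]
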
@@ -564,7 +477,7 @@ def max(x, axis=None, keepdims=False):
     We return an error as numpy when we reduce a dim with a shape of 0.
 
     """
-    out = max_and_argmax(x, axis)[0]
+    out = Max(axis=axis)(x)
 
     if keepdims:
         out = makeKeepDims(x, out, axis)