44
44
)
45
45
from pytensor .scalar .basic import add as add_as
46
46
from pytensor .tensor .elemwise import CAReduce , DimShuffle , Elemwise
47
- from pytensor .tensor .math import Argmax , Max , MulWithoutZeros , Sum
47
+ from pytensor .tensor .math import Argmax , MulWithoutZeros , Sum
48
48
from pytensor .tensor .special import LogSoftmax , Softmax , SoftmaxGrad
49
49
from pytensor .tensor .type import scalar
50
50
@@ -985,120 +985,6 @@ def log_softmax_py_fn(x):
985
985
return log_softmax
986
986
987
987
988
- # @numba_funcify.register(Max)
989
- # @numba_funcify.register(Argmax)
990
- # # @numba_funcify.register(MaxandArgmax)
991
- # def numba_funcify_MaxAndArgmax(op, node, **kwargs):
992
- # axis = op.axis
993
- # x_at = node.inputs[0]
994
- # x_dtype = x_at.type.numpy_dtype
995
- # x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
996
- # x_ndim = x_at.ndim
997
-
998
- # if x_ndim == 0:
999
-
1000
- # @numba_basic.numba_njit(inline="always")
1001
- # def maxandargmax(x):
1002
- # return x, 0
1003
-
1004
- # else:
1005
- # axes = tuple(int(ax) for ax in axis)
1006
-
1007
- # # NumPy does not support multiple axes for argmax; this is a
1008
- # # work-around
1009
- # keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
1010
-
1011
- # reduce_max_py_fn = create_multiaxis_reducer(
1012
- # scalar_maximum,
1013
- # -np.inf,
1014
- # axes,
1015
- # x_ndim,
1016
- # x_dtype,
1017
- # return_scalar=False,
1018
- # )
1019
- # reduce_max = jit_compile_reducer(
1020
- # Apply(node.op, node.inputs, [node.outputs[0].clone()]),
1021
- # reduce_max_py_fn,
1022
- # reduce_to_scalar=False,
1023
- # )
1024
-
1025
- # reduced_x_ndim = x_ndim - len(axes) + 1
1026
- # argmax_axis = create_axis_apply_fn(
1027
- # np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64
1028
- # )
1029
-
1030
- # reaxis_order = keep_axes + axes
1031
- # sl1 = slice(None, len(keep_axes))
1032
- # sl2 = slice(len(keep_axes), None)
1033
-
1034
- # @numba_basic.numba_njit
1035
- # def maxandargmax(x):
1036
- # max_res = reduce_max(x)
1037
-
1038
- # # Not-reduced axes in front
1039
- # transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order))
1040
- # kept_shape = transposed_x.shape[sl1]
1041
- # reduced_shape = transposed_x.shape[sl2]
1042
- # reduced_size = 1
1043
- # for s in reduced_shape:
1044
- # reduced_size *= s
1045
-
1046
- # # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
1047
- # # Otherwise reshape would complain citing float arg
1048
- # new_shape = (*kept_shape, reduced_size)
1049
- # reshaped_x = transposed_x.reshape(new_shape)
1050
-
1051
- # max_idx_res = argmax_axis(reshaped_x)
1052
-
1053
- # return max_res, max_idx_res
1054
-
1055
- # return maxandargmax
1056
-
1057
-
1058
- @numba_funcify .register (Max )
1059
- def numba_funcify_Max (op , node , ** kwargs ):
1060
- axis = op .axis
1061
- x_at = node .inputs [0 ]
1062
- x_dtype = x_at .type .numpy_dtype
1063
- x_dtype = numba .np .numpy_support .from_dtype (x_dtype )
1064
- x_ndim = x_at .ndim
1065
-
1066
- if x_ndim == 0 :
1067
-
1068
- @numba_basic .numba_njit (inline = "always" )
1069
- def max (x ):
1070
- return x
1071
-
1072
- else :
1073
- axes = tuple (int (ax ) for ax in axis )
1074
-
1075
- # NumPy does not support multiple axes for argmax; this is a
1076
- # work-around
1077
- # keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
1078
-
1079
- reduce_max_py_fn = create_multiaxis_reducer (
1080
- scalar_maximum ,
1081
- - np .inf ,
1082
- axes ,
1083
- x_ndim ,
1084
- x_dtype ,
1085
- return_scalar = False ,
1086
- )
1087
- reduce_max = jit_compile_reducer (
1088
- Apply (node .op , node .inputs , [node .outputs [0 ].clone ()]),
1089
- reduce_max_py_fn ,
1090
- reduce_to_scalar = False ,
1091
- )
1092
-
1093
- @numba_basic .numba_njit
1094
- def max (x ):
1095
- max_res = reduce_max (x )
1096
-
1097
- return max_res
1098
-
1099
- return max
1100
-
1101
-
1102
988
@numba_funcify .register (Argmax )
1103
989
def numba_funcify_Argmax (op , node , ** kwargs ):
1104
990
axis = op .axis
@@ -1120,20 +1006,6 @@ def argmax(x):
1120
1006
# work-around
1121
1007
keep_axes = tuple (i for i in range (x_ndim ) if i not in axes )
1122
1008
1123
- # reduce_max_py_fn = create_multiaxis_reducer(
1124
- # scalar_maximum,
1125
- # -np.inf,
1126
- # axes,
1127
- # x_ndim,
1128
- # x_dtype,
1129
- # return_scalar=False,
1130
- # )
1131
- # reduce_max = jit_compile_reducer(
1132
- # Apply(node.op, node.inputs, [node.outputs[0].clone()]),
1133
- # reduce_max_py_fn,
1134
- # reduce_to_scalar=False,
1135
- # )
1136
-
1137
1009
reduced_x_ndim = x_ndim - len (axes ) + 1
1138
1010
argmax_axis = create_axis_apply_fn (
1139
1011
np .argmax , reduced_x_ndim - 1 , reduced_x_ndim , np .int64
0 commit comments