Skip to content

Commit f0460f4

Browse files
Scalar problem solved
1 parent 8e7f626 commit f0460f4

File tree

3 files changed

+9
-138
lines changed

3 files changed

+9
-138
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 1 addition & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
)
4545
from pytensor.scalar.basic import add as add_as
4646
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
47-
from pytensor.tensor.math import Argmax, Max, MulWithoutZeros, Sum
47+
from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
4848
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
4949
from pytensor.tensor.type import scalar
5050

@@ -985,120 +985,6 @@ def log_softmax_py_fn(x):
985985
return log_softmax
986986

987987

988-
# @numba_funcify.register(Max)
989-
# @numba_funcify.register(Argmax)
990-
# # @numba_funcify.register(MaxandArgmax)
991-
# def numba_funcify_MaxAndArgmax(op, node, **kwargs):
992-
# axis = op.axis
993-
# x_at = node.inputs[0]
994-
# x_dtype = x_at.type.numpy_dtype
995-
# x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
996-
# x_ndim = x_at.ndim
997-
998-
# if x_ndim == 0:
999-
1000-
# @numba_basic.numba_njit(inline="always")
1001-
# def maxandargmax(x):
1002-
# return x, 0
1003-
1004-
# else:
1005-
# axes = tuple(int(ax) for ax in axis)
1006-
1007-
# # NumPy does not support multiple axes for argmax; this is a
1008-
# # work-around
1009-
# keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
1010-
1011-
# reduce_max_py_fn = create_multiaxis_reducer(
1012-
# scalar_maximum,
1013-
# -np.inf,
1014-
# axes,
1015-
# x_ndim,
1016-
# x_dtype,
1017-
# return_scalar=False,
1018-
# )
1019-
# reduce_max = jit_compile_reducer(
1020-
# Apply(node.op, node.inputs, [node.outputs[0].clone()]),
1021-
# reduce_max_py_fn,
1022-
# reduce_to_scalar=False,
1023-
# )
1024-
1025-
# reduced_x_ndim = x_ndim - len(axes) + 1
1026-
# argmax_axis = create_axis_apply_fn(
1027-
# np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64
1028-
# )
1029-
1030-
# reaxis_order = keep_axes + axes
1031-
# sl1 = slice(None, len(keep_axes))
1032-
# sl2 = slice(len(keep_axes), None)
1033-
1034-
# @numba_basic.numba_njit
1035-
# def maxandargmax(x):
1036-
# max_res = reduce_max(x)
1037-
1038-
# # Not-reduced axes in front
1039-
# transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order))
1040-
# kept_shape = transposed_x.shape[sl1]
1041-
# reduced_shape = transposed_x.shape[sl2]
1042-
# reduced_size = 1
1043-
# for s in reduced_shape:
1044-
# reduced_size *= s
1045-
1046-
# # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
1047-
# # Otherwise reshape would complain citing float arg
1048-
# new_shape = (*kept_shape, reduced_size)
1049-
# reshaped_x = transposed_x.reshape(new_shape)
1050-
1051-
# max_idx_res = argmax_axis(reshaped_x)
1052-
1053-
# return max_res, max_idx_res
1054-
1055-
# return maxandargmax
1056-
1057-
1058-
@numba_funcify.register(Max)
1059-
def numba_funcify_Max(op, node, **kwargs):
1060-
axis = op.axis
1061-
x_at = node.inputs[0]
1062-
x_dtype = x_at.type.numpy_dtype
1063-
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
1064-
x_ndim = x_at.ndim
1065-
1066-
if x_ndim == 0:
1067-
1068-
@numba_basic.numba_njit(inline="always")
1069-
def max(x):
1070-
return x
1071-
1072-
else:
1073-
axes = tuple(int(ax) for ax in axis)
1074-
1075-
# NumPy does not support multiple axes for argmax; this is a
1076-
# work-around
1077-
# keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
1078-
1079-
reduce_max_py_fn = create_multiaxis_reducer(
1080-
scalar_maximum,
1081-
-np.inf,
1082-
axes,
1083-
x_ndim,
1084-
x_dtype,
1085-
return_scalar=False,
1086-
)
1087-
reduce_max = jit_compile_reducer(
1088-
Apply(node.op, node.inputs, [node.outputs[0].clone()]),
1089-
reduce_max_py_fn,
1090-
reduce_to_scalar=False,
1091-
)
1092-
1093-
@numba_basic.numba_njit
1094-
def max(x):
1095-
max_res = reduce_max(x)
1096-
1097-
return max_res
1098-
1099-
return max
1100-
1101-
1102988
@numba_funcify.register(Argmax)
1103989
def numba_funcify_Argmax(op, node, **kwargs):
1104990
axis = op.axis
@@ -1120,20 +1006,6 @@ def argmax(x):
11201006
# work-around
11211007
keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
11221008

1123-
# reduce_max_py_fn = create_multiaxis_reducer(
1124-
# scalar_maximum,
1125-
# -np.inf,
1126-
# axes,
1127-
# x_ndim,
1128-
# x_dtype,
1129-
# return_scalar=False,
1130-
# )
1131-
# reduce_max = jit_compile_reducer(
1132-
# Apply(node.op, node.inputs, [node.outputs[0].clone()]),
1133-
# reduce_max_py_fn,
1134-
# reduce_to_scalar=False,
1135-
# )
1136-
11371009
reduced_x_ndim = x_ndim - len(axes) + 1
11381010
argmax_axis = create_axis_apply_fn(
11391011
np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64

pytensor/tensor/math.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ def perform(self, node, inp, outs):
194194
(max_idx,) = outs
195195
if axes is None:
196196
axes = tuple(range(x.ndim))
197-
else:
198-
axes = tuple(int(ax) for ax in axes)
197+
# else:
198+
# axes = tuple(int(ax) for ax in axes)
199199

200200
# Numpy does not support multiple axes for argmax
201201
# Work around
@@ -402,7 +402,8 @@ def max_and_argmax(a, axis=None, keepdims=False):
402402
a = as_tensor_variable(a)
403403
axis = check_and_normalize_axes(a, axis)
404404
if len(axis) == 0:
405-
axis = list(range(a.type.ndim))
405+
# axis = list(range(a.type.ndim))
406+
axis = None
406407
# out = TensorMax(axis)(a)
407408
out = Max(axis)(a)
408409
argout = Argmax(axis)(a)
@@ -475,6 +476,8 @@ def grad(self, inp, grads):
475476
# g_max to x's shape when axis=0 the broadcasting mechanism
476477
# does it automatically
477478
x = inp[0]
479+
if self.axis is None:
480+
self.axis = tuple(range(x.ndim))
478481
axis = as_tensor_variable(self.axis)
479482
(g_max,) = grads
480483

@@ -617,12 +620,6 @@ def min(x, axis=None, keepdims=False):
617620
elif str_x_type in uint_dtypes:
618621
itype = np.iinfo(x.dtype)
619622
max_val = np.array(itype.max, dtype=itype.dtype)
620-
# print('a')
621-
# for c in (max_val - x):
622-
# print(c.eval())
623-
# print()
624-
# print(max(max_val - x, axis=axis, keepdims=keepdims).eval())
625-
# print()
626623
return max_val - max(max_val - x, axis=axis, keepdims=keepdims)
627624
elif str_x_type == "bool":
628625
return ~max(~x, axis=axis, keepdims=keepdims)

tests/link/numba/test_basic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ def compare_numba_and_py(
256256
if assert_fn is None:
257257

258258
def assert_fn(x, y):
259+
print(x)
260+
print(y)
259261
return np.testing.assert_allclose(x, y, rtol=1e-4) and compare_shape_dtype(
260262
x, y
261263
)

0 commit comments

Comments
 (0)