@@ -104,56 +104,11 @@ def batched_dot(a, b):
     return batched_dot
 
 
-# @jax_funcify.register(Max)
-# @jax_funcify.register(Argmax)
-# def jax_funcify_MaxAndArgmax(op, **kwargs):
-#     axis = op.axis
-
-#     def maxandargmax(x, axis=axis):
-#         if axis is None:
-#             axes = tuple(range(x.ndim))
-#         else:
-#             axes = tuple(int(ax) for ax in axis)
-
-#         max_res = jnp.max(x, axis)
-
-#         # NumPy does not support multiple axes for argmax; this is a
-#         # work-around
-#         keep_axes = jnp.array(
-#             [i for i in range(x.ndim) if i not in axes], dtype="int64"
-#         )
-#         # Not-reduced axes in front
-#         transposed_x = jnp.transpose(
-#             x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
-#         )
-#         kept_shape = transposed_x.shape[: len(keep_axes)]
-#         reduced_shape = transposed_x.shape[len(keep_axes) :]
-
-#         # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
-#         # Otherwise reshape would complain citing float arg
-#         new_shape = (
-#             *kept_shape,
-#             jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
-#         )
-#         reshaped_x = transposed_x.reshape(new_shape)
-
-#         max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
-
-#         return max_res, max_idx_res
-
-#     return maxandargmax
-
-
 @jax_funcify.register(Max)
 def jax_funcify_Max(op, **kwargs):
     axis = op.axis
 
-    def max(x, axis=axis):
-        # if axis is None:
-        #     axes = tuple(range(x.ndim))
-        # else:
-        #     axes = tuple(int(ax) for ax in axis)
-
+    def max(x):
         max_res = jnp.max(x, axis)
 
         return max_res
@@ -165,7 +120,7 @@ def max(x, axis=axis):
 def jax_funcify_Argmax(op, **kwargs):
     axis = op.axis
 
-    def argmax(x, axis=axis):
+    def argmax(x):
         if axis is None:
             axes = tuple(range(x.ndim))
         else:
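For reference, a minimal usage sketch of the new dispatchers. It assumes this diff targets PyTensor's JAX backend (the `pytensor` import and `mode="JAX"` are assumptions, not part of the diff): a graph containing both `max` and `argmax` is now lowered through the separate `jax_funcify_Max` and `jax_funcify_Argmax` registrations instead of the old combined `MaxAndArgmax` dispatcher.

```python
# Minimal sketch, assuming the PyTensor JAX backend; `pytensor` and
# mode="JAX" are assumptions outside this diff.
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
# max/argmax are lowered by the separate Max and Argmax dispatchers
fn = pytensor.function([x], [pt.max(x, axis=0), pt.argmax(x, axis=0)], mode="JAX")

a = np.arange(6.0).reshape(2, 3)
max_vals, argmax_idx = fn(a)
print(max_vals)    # [3. 4. 5.]
print(argmax_idx)  # [1 1 1]
```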