@@ -104,56 +104,11 @@ def batched_dot(a, b):
104
104
return batched_dot
105
105
106
106
107
- # @jax_funcify.register(Max)
108
- # @jax_funcify.register(Argmax)
109
- # def jax_funcify_MaxAndArgmax(op, **kwargs):
110
- # axis = op.axis
111
-
112
- # def maxandargmax(x, axis=axis):
113
- # if axis is None:
114
- # axes = tuple(range(x.ndim))
115
- # else:
116
- # axes = tuple(int(ax) for ax in axis)
117
-
118
- # max_res = jnp.max(x, axis)
119
-
120
- # # NumPy does not support multiple axes for argmax; this is a
121
- # # work-around
122
- # keep_axes = jnp.array(
123
- # [i for i in range(x.ndim) if i not in axes], dtype="int64"
124
- # )
125
- # # Not-reduced axes in front
126
- # transposed_x = jnp.transpose(
127
- # x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
128
- # )
129
- # kept_shape = transposed_x.shape[: len(keep_axes)]
130
- # reduced_shape = transposed_x.shape[len(keep_axes) :]
131
-
132
- # # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
133
- # # Otherwise reshape would complain citing float arg
134
- # new_shape = (
135
- # *kept_shape,
136
- # jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
137
- # )
138
- # reshaped_x = transposed_x.reshape(new_shape)
139
-
140
- # max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
141
-
142
- # return max_res, max_idx_res
143
-
144
- # return maxandargmax
145
-
146
-
147
107
@jax_funcify .register (Max )
148
108
def jax_funcify_Max (op , ** kwargs ):
149
109
axis = op .axis
150
110
151
111
def max (x , axis = axis ):
152
- # if axis is None:
153
- # axes = tuple(range(x.ndim))
154
- # else:
155
- # axes = tuple(int(ax) for ax in axis)
156
-
157
112
max_res = jnp .max (x , axis )
158
113
159
114
return max_res
0 commit comments