Skip to content

Commit 3c2d26e

Browse files
author
Saurabh Singh
committed
fix
1 parent 3372ca0 commit 3c2d26e

File tree

1 file changed

+22
-36
lines changed

1 file changed

+22
-36
lines changed

keras/src/backend/openvino/numpy.py

+22-36
Original file line numberDiff line numberDiff line change
@@ -988,83 +988,69 @@ def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):
988988
)
989989

990990

991-
def maximum(x1, x2):
992-
x1 = get_ov_output(x1)
993-
x2 = get_ov_output(x2)
994-
x1, x2 = _align_operand_types(x1, x2, "maximum()")
995-
return OpenVINOKerasTensor(ov_opset.maximum(x1, x2).output(0))
996-
997-
998991
def median(x, axis=None, keepdims=False):
999992
x = get_ov_output(x)
1000993
original_type = x.get_element_type()
1001994
is_bool = original_type == Type.boolean
1002-
1003995
if is_bool:
1004996
x = ov_opset.convert(x, Type.i32).output(0)
1005997
elif original_type not in (Type.f32, Type.f64):
1006998
x = ov_opset.convert(x, Type.f32).output(0)
1007-
1008999
if axis is None:
10091000
x = ov_opset.reshape(x, ov_opset.constant([-1], Type.i32).output(0), False).output(0)
10101001
axis = 0
1011-
10121002
shape = ov_opset.convert(ov_opset.shape_of(x).output(0), Type.i64).output(0)
1003+
rank = x.get_partial_shape().rank.get_length()
10131004
axis_const = ov_opset.constant([axis], Type.i32).output(0)
10141005
axis_length = ov_opset.reshape(
10151006
ov_opset.gather(shape, axis_const, 0).output(0),
10161007
ov_opset.constant([], Type.i32).output(0),
10171008
False
10181009
).output(0)
1019-
10201010
const_zero = ov_opset.constant(0, Type.i64).output(0)
10211011
is_empty = ov_opset.equal(axis_length, const_zero).output(0)
10221012
zero_value = ov_opset.constant(0.0 if not is_bool else 0, x.get_element_type()).output(0)
1023-
1024-
result_shape = shape
1025-
if keepdims:
1026-
result_shape = ov_opset.select(
1027-
is_empty,
1028-
shape,
1029-
ov_opset.scatter_elements_update(shape, axis_const, ov_opset.constant([1], Type.i64).output(0), 0).output(0)
1030-
).output(0)
1031-
elif axis is None and x.get_partial_shape().rank.get_length() > 1:
1032-
result_shape = ov_opset.constant([], Type.i32).output(0)
1033-
1013+
if axis is None:
1014+
if keepdims:
1015+
result_shape = ov_opset.constant([1] * rank, Type.i32).output(0)
1016+
else:
1017+
result_shape = ov_opset.constant([], Type.i32).output(0)
1018+
else:
1019+
if keepdims:
1020+
one_i64 = ov_opset.constant([1], Type.i64).output(0)
1021+
result_shape = ov_opset.scatter_elements_update(shape, axis_const, one_i64, 0).output(0)
1022+
else:
1023+
kept_axes = [i for i in range(rank) if i != axis]
1024+
kept_const = ov_opset.constant(kept_axes, Type.i32).output(0)
1025+
result_shape = ov_opset.gather(shape, kept_const, 0).output(0)
10341026
empty_result = ov_opset.reshape(zero_value, result_shape, False).output(0)
1035-
10361027
sorted_values = ov_opset.topk(x, axis_length, axis, "min", "value").output(0)
1037-
10381028
const_one = ov_opset.constant(1, Type.i64).output(0)
1039-
is_odd = ov_opset.equal(ov_opset.floor_mod(axis_length, ov_opset.constant(2, Type.i64).output(0)).output(0), const_one).output(0)
1040-
1029+
mod_two = ov_opset.floor_mod(axis_length, ov_opset.constant(2, Type.i64).output(0)).output(0)
1030+
is_odd = ov_opset.equal(mod_two, const_one).output(0)
10411031
half = ov_opset.floor(ov_opset.divide(axis_length, ov_opset.constant(2, Type.i64).output(0)).output(0)).output(0)
10421032
half = ov_opset.convert(half, Type.i64).output(0)
1043-
10441033
mid_index = ov_opset.convert(half, Type.i32).output(0)
10451034
prev_index = ov_opset.convert(ov_opset.subtract(half, const_one).output(0), Type.i32).output(0)
1046-
10471035
mid_elem = ov_opset.gather(sorted_values, mid_index, axis).output(0)
10481036
prev_elem = ov_opset.gather(sorted_values, prev_index, axis).output(0)
1049-
10501037
if is_bool:
10511038
sum_middle = ov_opset.add(mid_elem, prev_elem).output(0)
10521039
is_two = ov_opset.equal(sum_middle, ov_opset.constant(2, Type.i32).output(0)).output(0)
10531040
is_one = ov_opset.equal(sum_middle, ov_opset.constant(1, Type.i32).output(0)).output(0)
1054-
even_result = ov_opset.select(is_two, ov_opset.constant(1, Type.i32).output(0),
1055-
ov_opset.select(is_one, ov_opset.constant(1, Type.i32).output(0),
1056-
ov_opset.constant(0, Type.i32).output(0))).output(0)
1041+
even_result = ov_opset.select(
1042+
is_two,
1043+
ov_opset.constant(1, Type.i32).output(0),
1044+
ov_opset.select(is_one, ov_opset.constant(1, Type.i32).output(0), ov_opset.constant(0, Type.i32).output(0))
1045+
).output(0)
10571046
else:
10581047
even_result = ov_opset.divide(
10591048
ov_opset.add(mid_elem, prev_elem).output(0),
10601049
ov_opset.constant(2.0, x.get_element_type()).output(0)
10611050
).output(0)
1062-
10631051
median_result = ov_opset.select(is_odd, mid_elem, even_result).output(0)
1064-
1065-
if keepdims or (axis is None and x.get_partial_shape().rank.get_length() > 1):
1052+
if keepdims or (axis is None and rank > 1):
10661053
median_result = ov_opset.reshape(median_result, result_shape, False).output(0)
1067-
10681054
final_result = ov_opset.select(is_empty, empty_result, median_result).output(0)
10691055
return OpenVINOKerasTensor(ov_opset.convert(final_result, original_type).output(0))
10701056

0 commit comments

Comments
 (0)