@@ -988,83 +988,69 @@ def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):
988
988
)
989
989
990
990
991
- def maximum (x1 , x2 ):
992
- x1 = get_ov_output (x1 )
993
- x2 = get_ov_output (x2 )
994
- x1 , x2 = _align_operand_types (x1 , x2 , "maximum()" )
995
- return OpenVINOKerasTensor (ov_opset .maximum (x1 , x2 ).output (0 ))
996
-
997
-
998
991
def median (x , axis = None , keepdims = False ):
999
992
x = get_ov_output (x )
1000
993
original_type = x .get_element_type ()
1001
994
is_bool = original_type == Type .boolean
1002
-
1003
995
if is_bool :
1004
996
x = ov_opset .convert (x , Type .i32 ).output (0 )
1005
997
elif original_type not in (Type .f32 , Type .f64 ):
1006
998
x = ov_opset .convert (x , Type .f32 ).output (0 )
1007
-
1008
999
if axis is None :
1009
1000
x = ov_opset .reshape (x , ov_opset .constant ([- 1 ], Type .i32 ).output (0 ), False ).output (0 )
1010
1001
axis = 0
1011
-
1012
1002
shape = ov_opset .convert (ov_opset .shape_of (x ).output (0 ), Type .i64 ).output (0 )
1003
+ rank = x .get_partial_shape ().rank .get_length ()
1013
1004
axis_const = ov_opset .constant ([axis ], Type .i32 ).output (0 )
1014
1005
axis_length = ov_opset .reshape (
1015
1006
ov_opset .gather (shape , axis_const , 0 ).output (0 ),
1016
1007
ov_opset .constant ([], Type .i32 ).output (0 ),
1017
1008
False
1018
1009
).output (0 )
1019
-
1020
1010
const_zero = ov_opset .constant (0 , Type .i64 ).output (0 )
1021
1011
is_empty = ov_opset .equal (axis_length , const_zero ).output (0 )
1022
1012
zero_value = ov_opset .constant (0.0 if not is_bool else 0 , x .get_element_type ()).output (0 )
1023
-
1024
- result_shape = shape
1025
- if keepdims :
1026
- result_shape = ov_opset .select (
1027
- is_empty ,
1028
- shape ,
1029
- ov_opset .scatter_elements_update (shape , axis_const , ov_opset .constant ([1 ], Type .i64 ).output (0 ), 0 ).output (0 )
1030
- ).output (0 )
1031
- elif axis is None and x .get_partial_shape ().rank .get_length () > 1 :
1032
- result_shape = ov_opset .constant ([], Type .i32 ).output (0 )
1033
-
1013
+ if axis is None :
1014
+ if keepdims :
1015
+ result_shape = ov_opset .constant ([1 ] * rank , Type .i32 ).output (0 )
1016
+ else :
1017
+ result_shape = ov_opset .constant ([], Type .i32 ).output (0 )
1018
+ else :
1019
+ if keepdims :
1020
+ one_i64 = ov_opset .constant ([1 ], Type .i64 ).output (0 )
1021
+ result_shape = ov_opset .scatter_elements_update (shape , axis_const , one_i64 , 0 ).output (0 )
1022
+ else :
1023
+ kept_axes = [i for i in range (rank ) if i != axis ]
1024
+ kept_const = ov_opset .constant (kept_axes , Type .i32 ).output (0 )
1025
+ result_shape = ov_opset .gather (shape , kept_const , 0 ).output (0 )
1034
1026
empty_result = ov_opset .reshape (zero_value , result_shape , False ).output (0 )
1035
-
1036
1027
sorted_values = ov_opset .topk (x , axis_length , axis , "min" , "value" ).output (0 )
1037
-
1038
1028
const_one = ov_opset .constant (1 , Type .i64 ).output (0 )
1039
- is_odd = ov_opset .equal ( ov_opset . floor_mod (axis_length , ov_opset .constant (2 , Type .i64 ).output (0 )). output ( 0 ), const_one ).output (0 )
1040
-
1029
+ mod_two = ov_opset .floor_mod (axis_length , ov_opset .constant (2 , Type .i64 ).output (0 )).output (0 )
1030
+ is_odd = ov_opset . equal ( mod_two , const_one ). output ( 0 )
1041
1031
half = ov_opset .floor (ov_opset .divide (axis_length , ov_opset .constant (2 , Type .i64 ).output (0 )).output (0 )).output (0 )
1042
1032
half = ov_opset .convert (half , Type .i64 ).output (0 )
1043
-
1044
1033
mid_index = ov_opset .convert (half , Type .i32 ).output (0 )
1045
1034
prev_index = ov_opset .convert (ov_opset .subtract (half , const_one ).output (0 ), Type .i32 ).output (0 )
1046
-
1047
1035
mid_elem = ov_opset .gather (sorted_values , mid_index , axis ).output (0 )
1048
1036
prev_elem = ov_opset .gather (sorted_values , prev_index , axis ).output (0 )
1049
-
1050
1037
if is_bool :
1051
1038
sum_middle = ov_opset .add (mid_elem , prev_elem ).output (0 )
1052
1039
is_two = ov_opset .equal (sum_middle , ov_opset .constant (2 , Type .i32 ).output (0 )).output (0 )
1053
1040
is_one = ov_opset .equal (sum_middle , ov_opset .constant (1 , Type .i32 ).output (0 )).output (0 )
1054
- even_result = ov_opset .select (is_two , ov_opset .constant (1 , Type .i32 ).output (0 ),
1055
- ov_opset .select (is_one , ov_opset .constant (1 , Type .i32 ).output (0 ),
1056
- ov_opset .constant (0 , Type .i32 ).output (0 ))).output (0 )
1041
+ even_result = ov_opset .select (
1042
+ is_two ,
1043
+ ov_opset .constant (1 , Type .i32 ).output (0 ),
1044
+ ov_opset .select (is_one , ov_opset .constant (1 , Type .i32 ).output (0 ), ov_opset .constant (0 , Type .i32 ).output (0 ))
1045
+ ).output (0 )
1057
1046
else :
1058
1047
even_result = ov_opset .divide (
1059
1048
ov_opset .add (mid_elem , prev_elem ).output (0 ),
1060
1049
ov_opset .constant (2.0 , x .get_element_type ()).output (0 )
1061
1050
).output (0 )
1062
-
1063
1051
median_result = ov_opset .select (is_odd , mid_elem , even_result ).output (0 )
1064
-
1065
- if keepdims or (axis is None and x .get_partial_shape ().rank .get_length () > 1 ):
1052
+ if keepdims or (axis is None and rank > 1 ):
1066
1053
median_result = ov_opset .reshape (median_result , result_shape , False ).output (0 )
1067
-
1068
1054
final_result = ov_opset .select (is_empty , empty_result , median_result ).output (0 )
1069
1055
return OpenVINOKerasTensor (ov_opset .convert (final_result , original_type ).output (0 ))
1070
1056
0 commit comments