@@ -913,101 +913,115 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
913
913
if num < 0 :
914
914
raise ValueError ("num must be non-negative" )
915
915
916
- if num == 0 :
917
- if dtype is None :
918
- ov_dtype = OPENVINO_DTYPES [config .floatx ()]
919
- else :
920
- ov_dtype = OPENVINO_DTYPES [dtype ]
921
- result = ov_opset .constant ([], ov_dtype , shape = [0 ]).output (0 )
922
- if retstep :
923
- step = ov_opset .constant (float ("nan" ), ov_dtype ).output (0 )
924
- return OpenVINOKerasTensor (result ), OpenVINOKerasTensor (step )
925
- return OpenVINOKerasTensor (result )
926
-
927
- start = get_ov_output (start )
928
- stop = get_ov_output (stop )
929
-
916
+ start_ov = get_ov_output (start )
917
+ stop_ov = get_ov_output (stop )
918
+
930
919
if dtype is None :
931
920
ov_dtype = OPENVINO_DTYPES [config .floatx ()]
932
921
else :
933
922
ov_dtype = OPENVINO_DTYPES [dtype ]
934
-
935
- start = ov_opset .convert (start , ov_dtype ).output (0 )
936
- stop = ov_opset .convert (stop , ov_dtype ).output (0 )
937
-
938
- if num == 1 :
939
- if endpoint :
940
- result = ov_opset .convert (stop , ov_dtype ).output (0 )
923
+
924
+ start_ov = ov_opset .convert (start_ov , ov_dtype ).output (0 )
925
+ stop_ov = ov_opset .convert (stop_ov , ov_dtype ).output (0 )
926
+
927
+ start_shape = start_ov .get_shape ()
928
+ stop_shape = stop_ov .get_shape ()
929
+
930
+ if num == 0 :
931
+ if len (start_shape ) == 0 and len (stop_shape ) == 0 :
932
+ result = ov_opset .constant (np .array ([], dtype = np .dtype (dtype or config .floatx ()))).output (0 )
941
933
else :
942
- result = ov_opset .convert (start , ov_dtype ).output (0 )
943
- if axis != 0 :
944
- axis_const = ov_opset .constant ([axis ], Type .i64 ).output (0 )
945
- result = ov_opset .unsqueeze (result , axis_const ).output (0 )
934
+ out_shape = list (np .broadcast (
935
+ np .empty (start_shape , dtype = bool ),
936
+ np .empty (stop_shape , dtype = bool )
937
+ ).shape )
938
+ out_shape .insert (axis , 0 )
939
+
940
+ empty_np = np .empty (out_shape , dtype = np .dtype (dtype or config .floatx ()))
941
+ empty_np = np .reshape (empty_np , [- 1 ])
942
+ result = ov_opset .constant (empty_np ).output (0 )
943
+
944
+ shape_const = ov_opset .constant (np .array (out_shape , dtype = np .int64 )).output (0 )
945
+ result = ov_opset .reshape (result , shape_const ).output (0 )
946
+
946
947
if retstep :
947
- step = ov_opset .subtract (stop , start ).output (0 )
948
+ delta = ov_opset .subtract (stop_ov , start_ov ).output (0 )
949
+ step = delta
948
950
return OpenVINOKerasTensor (result ), OpenVINOKerasTensor (step )
949
951
return OpenVINOKerasTensor (result )
950
-
951
- div = num - 1 if endpoint else num
952
- div_const = ov_opset .constant (div , ov_dtype ).output (0 )
953
- delta = ov_opset .subtract (stop , start ).output (0 )
954
- step = ov_opset .divide (delta , div_const ).output (0 )
955
-
956
- type_to_str = {
957
- Type .f16 : "f16" ,
958
- Type .f32 : "f32" ,
959
- Type .f64 : "f64" ,
960
- Type .bf16 : "bf16" ,
961
- Type .i8 : "i8" ,
962
- Type .i16 : "i16" ,
963
- Type .i32 : "i32" ,
964
- Type .i64 : "i64" ,
965
- Type .u8 : "u8" ,
966
- Type .u16 : "u16" ,
967
- Type .u32 : "u32" ,
968
- Type .u64 : "u64"
969
- }
970
952
971
- type_str = type_to_str .get (ov_dtype , "f32" )
953
+ is_scalar_start = len (start_shape ) == 0
954
+ is_scalar_stop = len (stop_shape ) == 0
972
955
973
- indices = ov_opset .range (
974
- ov_opset .constant (0 , Type .i32 ).output (0 ),
975
- ov_opset .constant (num , Type .i32 ).output (0 ),
976
- ov_opset .constant (1 , Type .i32 ).output (0 ),
977
- type_str
978
- ).output (0 )
956
+ if not (is_scalar_start and is_scalar_stop ):
957
+ broadcast_shape = list (np .broadcast (
958
+ np .empty (start_shape , dtype = bool ),
959
+ np .empty (stop_shape , dtype = bool )
960
+ ).shape )
961
+
962
+ if not is_scalar_start and tuple (start_shape ) != tuple (broadcast_shape ):
963
+ shape_const = ov_opset .constant (np .array (broadcast_shape , dtype = np .int64 )).output (0 )
964
+ start_ov = ov_opset .broadcast (start_ov , shape_const ).output (0 )
965
+
966
+ if not is_scalar_stop and tuple (stop_shape ) != tuple (broadcast_shape ):
967
+ shape_const = ov_opset .constant (np .array (broadcast_shape , dtype = np .int64 )).output (0 )
968
+ stop_ov = ov_opset .broadcast (stop_ov , shape_const ).output (0 )
979
969
980
- scaled_indices = ov_opset .multiply (indices , step ).output (0 )
981
- result = ov_opset .add (start , scaled_indices ).output (0 )
982
-
983
- if endpoint and num > 1 :
984
- all_but_last = ov_opset .slice (
985
- result ,
986
- ov_opset .constant ([0 ], Type .i64 ).output (0 ),
987
- ov_opset .constant ([num - 1 ], Type .i64 ).output (0 ),
988
- ov_opset .constant ([1 ], Type .i64 ).output (0 ),
989
- ov_opset .constant ([0 ], Type .i64 ).output (0 )
990
- ).output (0 )
970
+ if num == 1 :
971
+ if endpoint :
972
+ result = stop_ov
973
+ else :
974
+ result = start_ov
975
+
976
+ step = ov_opset .subtract (stop_ov , start_ov ).output (0 )
991
977
992
- stop_shape = stop .get_shape ()
993
- result_shape = result .get_shape ()
978
+ if not (is_scalar_start and is_scalar_stop ):
979
+ out_shape = list (result .get_shape ())
980
+ out_shape .insert (axis , 1 )
981
+ shape_const = ov_opset .constant (np .array (out_shape , dtype = np .int64 )).output (0 )
982
+ result = ov_opset .reshape (result , shape_const ).output (0 )
983
+ else :
984
+ div = num - 1 if endpoint else num
985
+ div_const = ov_opset .constant (div , ov_dtype ).output (0 )
986
+ delta = ov_opset .subtract (stop_ov , start_ov ).output (0 )
987
+ step = ov_opset .divide (delta , div_const ).output (0 )
994
988
995
- if len (stop_shape ) < len (result_shape ):
996
- for _ in range (len (result_shape ) - len (stop_shape )):
997
- stop = ov_opset .unsqueeze (
998
- stop ,
999
- ov_opset .constant ([0 ], Type .i64 ).output (0 )
1000
- ).output (0 )
989
+ out_shape = list (start_ov .get_shape () if not is_scalar_start else stop_ov .get_shape () if not is_scalar_stop else [])
1001
990
1002
- result = ov_opset .concat ([all_but_last , stop ], 0 ).output (0 )
1003
-
1004
- if axis != 0 :
1005
- axis_const = ov_opset .constant ([axis ], Type .i64 ).output (0 )
1006
- result = ov_opset .unsqueeze (result , axis_const ).output (0 )
1007
-
991
+ indices = ov_opset .range (
992
+ ov_opset .constant (0 , ov_dtype ).output (0 ),
993
+ ov_opset .constant (num , ov_dtype ).output (0 ),
994
+ ov_opset .constant (1 , ov_dtype ).output (0 )
995
+ ).output (0 )
996
+
997
+ if not (is_scalar_start and is_scalar_stop ):
998
+ expanded_shape = list (out_shape )
999
+ expanded_shape .insert (axis , 1 )
1000
+ shape_const = ov_opset .constant (np .array (expanded_shape , dtype = np .int64 )).output (0 )
1001
+
1002
+ start_reshaped = ov_opset .reshape (start_ov , shape_const ).output (0 )
1003
+ step_reshaped = ov_opset .reshape (step , shape_const ).output (0 )
1004
+
1005
+ indices_shape = [1 ] * len (expanded_shape )
1006
+ indices_shape [axis ] = num
1007
+ indices_shape_const = ov_opset .constant (np .array (indices_shape , dtype = np .int64 )).output (0 )
1008
+ indices_reshaped = ov_opset .reshape (indices , indices_shape_const ).output (0 )
1009
+
1010
+ indices_times_step = ov_opset .multiply (indices_reshaped , step_reshaped ).output (0 )
1011
+ result = ov_opset .add (start_reshaped , indices_times_step ).output (0 )
1012
+ else :
1013
+ indices_times_step = ov_opset .multiply (indices , step ).output (0 )
1014
+ result = ov_opset .add (start_ov , indices_times_step ).output (0 )
1015
+
1016
+ if axis != 0 :
1017
+ out_shape = [1 ] * (axis + 1 )
1018
+ out_shape [axis ] = num
1019
+ shape_const = ov_opset .constant (np .array (out_shape , dtype = np .int64 )).output (0 )
1020
+ result = ov_opset .reshape (result , shape_const ).output (0 )
1021
+
1008
1022
if retstep :
1009
1023
return OpenVINOKerasTensor (result ), OpenVINOKerasTensor (step )
1010
- return OpenVINOKerasTensor (result )
1024
+ return OpenVINOKerasTensor (result )
1011
1025
1012
1026
1013
1027
def log (x ):
0 commit comments