Skip to content

Commit be3c114

Browse files
committed
fixed linspace
1 parent 4e4a0b3 commit be3c114

File tree

1 file changed

+93
-79
lines changed

1 file changed

+93
-79
lines changed

keras/src/backend/openvino/numpy.py

+93-79
Original file line numberDiff line numberDiff line change
@@ -913,101 +913,115 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
913913
if num < 0:
914914
raise ValueError("num must be non-negative")
915915

916-
if num == 0:
917-
if dtype is None:
918-
ov_dtype = OPENVINO_DTYPES[config.floatx()]
919-
else:
920-
ov_dtype = OPENVINO_DTYPES[dtype]
921-
result = ov_opset.constant([], ov_dtype, shape=[0]).output(0)
922-
if retstep:
923-
step = ov_opset.constant(float("nan"), ov_dtype).output(0)
924-
return OpenVINOKerasTensor(result), OpenVINOKerasTensor(step)
925-
return OpenVINOKerasTensor(result)
926-
927-
start = get_ov_output(start)
928-
stop = get_ov_output(stop)
929-
916+
start_ov = get_ov_output(start)
917+
stop_ov = get_ov_output(stop)
918+
930919
if dtype is None:
931920
ov_dtype = OPENVINO_DTYPES[config.floatx()]
932921
else:
933922
ov_dtype = OPENVINO_DTYPES[dtype]
934-
935-
start = ov_opset.convert(start, ov_dtype).output(0)
936-
stop = ov_opset.convert(stop, ov_dtype).output(0)
937-
938-
if num == 1:
939-
if endpoint:
940-
result = ov_opset.convert(stop, ov_dtype).output(0)
923+
924+
start_ov = ov_opset.convert(start_ov, ov_dtype).output(0)
925+
stop_ov = ov_opset.convert(stop_ov, ov_dtype).output(0)
926+
927+
start_shape = start_ov.get_shape()
928+
stop_shape = stop_ov.get_shape()
929+
930+
if num == 0:
931+
if len(start_shape) == 0 and len(stop_shape) == 0:
932+
result = ov_opset.constant(np.array([], dtype=np.dtype(dtype or config.floatx()))).output(0)
941933
else:
942-
result = ov_opset.convert(start, ov_dtype).output(0)
943-
if axis != 0:
944-
axis_const = ov_opset.constant([axis], Type.i64).output(0)
945-
result = ov_opset.unsqueeze(result, axis_const).output(0)
934+
out_shape = list(np.broadcast(
935+
np.empty(start_shape, dtype=bool),
936+
np.empty(stop_shape, dtype=bool)
937+
).shape)
938+
out_shape.insert(axis, 0)
939+
940+
empty_np = np.empty(out_shape, dtype=np.dtype(dtype or config.floatx()))
941+
empty_np = np.reshape(empty_np, [-1])
942+
result = ov_opset.constant(empty_np).output(0)
943+
944+
shape_const = ov_opset.constant(np.array(out_shape, dtype=np.int64)).output(0)
945+
result = ov_opset.reshape(result, shape_const).output(0)
946+
946947
if retstep:
947-
step = ov_opset.subtract(stop, start).output(0)
948+
delta = ov_opset.subtract(stop_ov, start_ov).output(0)
949+
step = delta
948950
return OpenVINOKerasTensor(result), OpenVINOKerasTensor(step)
949951
return OpenVINOKerasTensor(result)
950-
951-
div = num - 1 if endpoint else num
952-
div_const = ov_opset.constant(div, ov_dtype).output(0)
953-
delta = ov_opset.subtract(stop, start).output(0)
954-
step = ov_opset.divide(delta, div_const).output(0)
955-
956-
type_to_str = {
957-
Type.f16: "f16",
958-
Type.f32: "f32",
959-
Type.f64: "f64",
960-
Type.bf16: "bf16",
961-
Type.i8: "i8",
962-
Type.i16: "i16",
963-
Type.i32: "i32",
964-
Type.i64: "i64",
965-
Type.u8: "u8",
966-
Type.u16: "u16",
967-
Type.u32: "u32",
968-
Type.u64: "u64"
969-
}
970952

971-
type_str = type_to_str.get(ov_dtype, "f32")
953+
is_scalar_start = len(start_shape) == 0
954+
is_scalar_stop = len(stop_shape) == 0
972955

973-
indices = ov_opset.range(
974-
ov_opset.constant(0, Type.i32).output(0),
975-
ov_opset.constant(num, Type.i32).output(0),
976-
ov_opset.constant(1, Type.i32).output(0),
977-
type_str
978-
).output(0)
956+
if not (is_scalar_start and is_scalar_stop):
957+
broadcast_shape = list(np.broadcast(
958+
np.empty(start_shape, dtype=bool),
959+
np.empty(stop_shape, dtype=bool)
960+
).shape)
961+
962+
if not is_scalar_start and tuple(start_shape) != tuple(broadcast_shape):
963+
shape_const = ov_opset.constant(np.array(broadcast_shape, dtype=np.int64)).output(0)
964+
start_ov = ov_opset.broadcast(start_ov, shape_const).output(0)
965+
966+
if not is_scalar_stop and tuple(stop_shape) != tuple(broadcast_shape):
967+
shape_const = ov_opset.constant(np.array(broadcast_shape, dtype=np.int64)).output(0)
968+
stop_ov = ov_opset.broadcast(stop_ov, shape_const).output(0)
979969

980-
scaled_indices = ov_opset.multiply(indices, step).output(0)
981-
result = ov_opset.add(start, scaled_indices).output(0)
982-
983-
if endpoint and num > 1:
984-
all_but_last = ov_opset.slice(
985-
result,
986-
ov_opset.constant([0], Type.i64).output(0),
987-
ov_opset.constant([num-1], Type.i64).output(0),
988-
ov_opset.constant([1], Type.i64).output(0),
989-
ov_opset.constant([0], Type.i64).output(0)
990-
).output(0)
970+
if num == 1:
971+
if endpoint:
972+
result = stop_ov
973+
else:
974+
result = start_ov
975+
976+
step = ov_opset.subtract(stop_ov, start_ov).output(0)
991977

992-
stop_shape = stop.get_shape()
993-
result_shape = result.get_shape()
978+
if not (is_scalar_start and is_scalar_stop):
979+
out_shape = list(result.get_shape())
980+
out_shape.insert(axis, 1)
981+
shape_const = ov_opset.constant(np.array(out_shape, dtype=np.int64)).output(0)
982+
result = ov_opset.reshape(result, shape_const).output(0)
983+
else:
984+
div = num - 1 if endpoint else num
985+
div_const = ov_opset.constant(div, ov_dtype).output(0)
986+
delta = ov_opset.subtract(stop_ov, start_ov).output(0)
987+
step = ov_opset.divide(delta, div_const).output(0)
994988

995-
if len(stop_shape) < len(result_shape):
996-
for _ in range(len(result_shape) - len(stop_shape)):
997-
stop = ov_opset.unsqueeze(
998-
stop,
999-
ov_opset.constant([0], Type.i64).output(0)
1000-
).output(0)
989+
out_shape = list(start_ov.get_shape() if not is_scalar_start else stop_ov.get_shape() if not is_scalar_stop else [])
1001990

1002-
result = ov_opset.concat([all_but_last, stop], 0).output(0)
1003-
1004-
if axis != 0:
1005-
axis_const = ov_opset.constant([axis], Type.i64).output(0)
1006-
result = ov_opset.unsqueeze(result, axis_const).output(0)
1007-
991+
indices = ov_opset.range(
992+
ov_opset.constant(0, ov_dtype).output(0),
993+
ov_opset.constant(num, ov_dtype).output(0),
994+
ov_opset.constant(1, ov_dtype).output(0)
995+
).output(0)
996+
997+
if not (is_scalar_start and is_scalar_stop):
998+
expanded_shape = list(out_shape)
999+
expanded_shape.insert(axis, 1)
1000+
shape_const = ov_opset.constant(np.array(expanded_shape, dtype=np.int64)).output(0)
1001+
1002+
start_reshaped = ov_opset.reshape(start_ov, shape_const).output(0)
1003+
step_reshaped = ov_opset.reshape(step, shape_const).output(0)
1004+
1005+
indices_shape = [1] * len(expanded_shape)
1006+
indices_shape[axis] = num
1007+
indices_shape_const = ov_opset.constant(np.array(indices_shape, dtype=np.int64)).output(0)
1008+
indices_reshaped = ov_opset.reshape(indices, indices_shape_const).output(0)
1009+
1010+
indices_times_step = ov_opset.multiply(indices_reshaped, step_reshaped).output(0)
1011+
result = ov_opset.add(start_reshaped, indices_times_step).output(0)
1012+
else:
1013+
indices_times_step = ov_opset.multiply(indices, step).output(0)
1014+
result = ov_opset.add(start_ov, indices_times_step).output(0)
1015+
1016+
if axis != 0:
1017+
out_shape = [1] * (axis+1)
1018+
out_shape[axis] = num
1019+
shape_const = ov_opset.constant(np.array(out_shape, dtype=np.int64)).output(0)
1020+
result = ov_opset.reshape(result, shape_const).output(0)
1021+
10081022
if retstep:
10091023
return OpenVINOKerasTensor(result), OpenVINOKerasTensor(step)
1010-
return OpenVINOKerasTensor(result)
1024+
return OpenVINOKerasTensor(result)
10111025

10121026

10131027
def log(x):

0 commit comments

Comments
 (0)