Skip to content

Commit 3bf075b

Browse files
Merge branch 'keras-team:master' into fix_tpu_tests
2 parents 575b01e + 46813a3 commit 3bf075b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+1037
-528
lines changed
Lines changed: 9 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,3 @@
1-
AdditiveAttentionTest::test_attention_correctness
2-
AttentionTest::test_attention_calculate_scores_with_scale
3-
AttentionTest::test_attention_correctness
4-
CircleTest::test_correctness
5-
CircleTest::test_correctness_weighted
6-
CircleTest::test_mean_with_sample_weight_reduction
7-
CircleTest::test_no_reduction
8-
CircleTest::test_sum_reduction
9-
ConvBasicTest::test_enable_lora_with_alpha
10-
ConvCorrectnessTest::test_conv1d0
11-
ConvCorrectnessTest::test_conv1d1
12-
ConvCorrectnessTest::test_conv1d2
13-
ConvCorrectnessTest::test_conv1d3
14-
ConvCorrectnessTest::test_conv1d4
15-
ConvCorrectnessTest::test_conv2d0
16-
ConvCorrectnessTest::test_conv2d1
17-
ConvCorrectnessTest::test_conv2d2
18-
ConvCorrectnessTest::test_conv2d3
19-
ConvCorrectnessTest::test_conv2d4
20-
ConvCorrectnessTest::test_conv2d5
21-
ConvCorrectnessTest::test_conv3d0
22-
ConvCorrectnessTest::test_conv3d1
23-
ConvCorrectnessTest::test_conv3d2
24-
ConvCorrectnessTest::test_conv3d3
25-
ConvCorrectnessTest::test_conv3d4
26-
ConvLSTM1DTest::test_correctness
27-
ConvLSTM1DTest::test_correctness
28-
ConvLSTM2DTest::test_correctness
29-
ConvLSTMCellTest::test_correctness
30-
ConvLSTMTest::test_correctness
31-
ConvTransposeCorrectnessTest::test_conv1d_transpose0
32-
ConvTransposeCorrectnessTest::test_conv1d_transpose1
33-
ConvTransposeCorrectnessTest::test_conv1d_transpose2
34-
ConvTransposeCorrectnessTest::test_conv2d_transpose0
35-
ConvTransposeCorrectnessTest::test_conv2d_transpose1
36-
ConvTransposeCorrectnessTest::test_conv2d_transpose2
37-
ConvTransposeCorrectnessTest::test_conv2d_transpose3
38-
ConvTransposeCorrectnessTest::test_conv3d_transpose0
39-
ConvTransposeCorrectnessTest::test_conv3d_transpose1
40-
ConvTransposeCorrectnessTest::test_conv3d_transpose2
41-
CTCTest::test_correctness
42-
DenseTest::test_dense_sparse
43-
DepthwiseConvCorrectnessTest::test_depthwise_conv1d0
44-
DepthwiseConvCorrectnessTest::test_depthwise_conv1d1
45-
DepthwiseConvCorrectnessTest::test_depthwise_conv1d2
46-
DepthwiseConvCorrectnessTest::test_depthwise_conv2d0
47-
DepthwiseConvCorrectnessTest::test_depthwise_conv2d1
48-
DepthwiseConvCorrectnessTest::test_depthwise_conv2d2
49-
EinsumDenseTest::test_enable_lora_with_alpha
50-
EmbeddingTest::test_enable_lora_with_alpha
511
ExportArchiveTest::test_jax_endpoint_registration_tf_function
522
ExportArchiveTest::test_jax_multi_unknown_endpoint_registration
533
ExportArchiveTest::test_layer_export
@@ -71,10 +21,18 @@ ExportArchiveTest::test_track_multiple_layers
7121
ExportONNXTest::test_export_with_input_names
7222
ExportONNXTest::test_export_with_opset_version_18
7323
ExportONNXTest::test_export_with_opset_version_none
24+
ExportONNXTest::test_model_with_input_structure_array
25+
ExportONNXTest::test_model_with_input_structure_dict
26+
ExportONNXTest::test_model_with_input_structure_tuple
27+
ExportONNXTest::test_model_with_multiple_inputs
7428
ExportONNXTest::test_standard_model_export_functional
7529
ExportONNXTest::test_standard_model_export_lstm
7630
ExportONNXTest::test_standard_model_export_sequential
7731
ExportONNXTest::test_standard_model_export_subclass
32+
ExportOpenVINOTest::test_model_with_input_structure_array
33+
ExportOpenVINOTest::test_model_with_input_structure_dict
34+
ExportOpenVINOTest::test_model_with_input_structure_tuple
35+
ExportOpenVINOTest::test_model_with_multiple_inputs
7836
ExportOpenVINOTest::test_standard_model_export_functional
7937
ExportOpenVINOTest::test_standard_model_export_sequential
8038
ExportOpenVINOTest::test_standard_model_export_subclass
@@ -118,117 +76,10 @@ ExportSavedModelTest::test_model_with_tf_data_layer_subclass
11876
ExportSavedModelTest::test_standard_model_export_functional
11977
ExportSavedModelTest::test_standard_model_export_sequential
12078
ExportSavedModelTest::test_standard_model_export_subclass
121-
GRUTest::test_correctness0
122-
GRUTest::test_correctness1
123-
GRUTest::test_legacy_implementation_argument
124-
GRUTest::test_masking
125-
GRUTest::test_pass_initial_state
126-
GRUTest::test_pass_return_state
127-
GRUTest::test_statefulness
128-
ImageOpsCorrectnessTest::test_affine_transform_bilinear_constant
129-
ImageOpsCorrectnessTest::test_affine_transform_bilinear_mirror
130-
ImageOpsCorrectnessTest::test_affine_transform_bilinear_nearest
131-
ImageOpsCorrectnessTest::test_affine_transform_bilinear_reflect
132-
ImageOpsCorrectnessTest::test_affine_transform_bilinear_wrap
133-
LinalgOpsCorrectnessTest::test_cholesky_inverse_lower
134-
LinalgOpsCorrectnessTest::test_cholesky_inverse_upper
135-
LinalgOpsCorrectnessTest::test_eig
136-
LinalgOpsCorrectnessTest::test_svd
137-
LSTMTest::test_correctness0
138-
LSTMTest::test_correctness1
139-
LSTMTest::test_masking
140-
LSTMTest::test_pass_initial_state
141-
LSTMTest::test_statefulness
142-
MathOpsCorrectnessTest::test_extract_sequences
143-
MergingLayersTest::test_correctness_dynamic_dot_3d
144-
MergingLayersTest::test_correctness_static_dot_3d
145-
MuonTest::test_Newton_Schulz
146-
NNOpsCorrectnessTest::test_conv_2d0
147-
NNOpsCorrectnessTest::test_conv_2d1
148-
NNOpsCorrectnessTest::test_conv_2d2
149-
NNOpsCorrectnessTest::test_conv_2d3
150-
NNOpsCorrectnessTest::test_conv_2d4
151-
NNOpsCorrectnessTest::test_conv_2d5
152-
NNOpsCorrectnessTest::test_conv_3d0
153-
NNOpsCorrectnessTest::test_conv_3d1
154-
NNOpsCorrectnessTest::test_conv_3d10
155-
NNOpsCorrectnessTest::test_conv_3d11
156-
NNOpsCorrectnessTest::test_conv_3d2
157-
NNOpsCorrectnessTest::test_conv_3d3
158-
NNOpsCorrectnessTest::test_conv_3d4
159-
NNOpsCorrectnessTest::test_conv_3d5
160-
NNOpsCorrectnessTest::test_conv_3d6
161-
NNOpsCorrectnessTest::test_conv_3d7
162-
NNOpsCorrectnessTest::test_conv_3d8
163-
NNOpsCorrectnessTest::test_conv_3d9
164-
NNOpsCorrectnessTest::test_ctc_loss
165-
NNOpsCorrectnessTest::test_depthwise_conv_2d0
166-
NNOpsCorrectnessTest::test_depthwise_conv_2d1
167-
NNOpsCorrectnessTest::test_depthwise_conv_2d10
168-
NNOpsCorrectnessTest::test_depthwise_conv_2d11
169-
NNOpsCorrectnessTest::test_depthwise_conv_2d2
170-
NNOpsCorrectnessTest::test_depthwise_conv_2d3
171-
NNOpsCorrectnessTest::test_depthwise_conv_2d4
172-
NNOpsCorrectnessTest::test_depthwise_conv_2d5
173-
NNOpsCorrectnessTest::test_depthwise_conv_2d6
174-
NNOpsCorrectnessTest::test_depthwise_conv_2d7
175-
NNOpsCorrectnessTest::test_depthwise_conv_2d8
176-
NNOpsCorrectnessTest::test_depthwise_conv_2d9
177-
NNOpsCorrectnessTest::test_separable_conv_2d0
178-
NNOpsCorrectnessTest::test_separable_conv_2d1
179-
NNOpsCorrectnessTest::test_separable_conv_2d2
180-
NNOpsCorrectnessTest::test_separable_conv_2d3
181-
NNOpsCorrectnessTest::test_separable_conv_2d4
182-
NNOpsCorrectnessTest::test_separable_conv_2d5
183-
NNOpsCorrectnessTest::test_separable_conv_2d6
184-
NNOpsCorrectnessTest::test_separable_conv_2d7
185-
NumpyOneInputOpsDynamicShapeTest::test_argmax_negative_zero
186-
NumpyOneInputOpsDynamicShapeTest::test_argmin_negative_zero
187-
NumpyTwoInputOpsCorrectnessTest::test_logspace
188-
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank2_float32_false_false
189-
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank2_float64_false_false
190-
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float16_false_false
191-
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float32_false_false
192-
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float64_false_false
193-
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float16_false_false
194-
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float32_false_false
195-
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float64_false_false
196-
RandomGaussianBlurTest::test_random_erasing_basic
197-
RandomPerspectiveTest::test_random_perspective_bounding_boxes_with_large_scale
198-
RandomPerspectiveTest::test_random_perspective_bounding_boxes_with_small_scale
199-
RandomZoomTest::test_random_zoom_out_correctness
200-
RegularizersTest::test_orthogonal_regularizer
201-
RNNTest::test_go_backwards
202-
SeparableConvCorrectnessTest::test_separable_conv1d0
203-
SeparableConvCorrectnessTest::test_separable_conv1d1
204-
SeparableConvCorrectnessTest::test_separable_conv1d2
205-
SeparableConvCorrectnessTest::test_separable_conv2d0
206-
SeparableConvCorrectnessTest::test_separable_conv2d1
207-
SeparableConvCorrectnessTest::test_separable_conv2d2
208-
SimpleRNNTest::test_correctness
209-
SimpleRNNTest::test_correctness
210-
SimpleRNNTest::test_masking
211-
SimpleRNNTest::test_masking
212-
SimpleRNNTest::test_pass_initial_state
213-
SimpleRNNTest::test_pass_initial_state
214-
SimpleRNNTest::test_return_state
215-
SimpleRNNTest::test_statefulness
216-
SimpleRNNTest::test_statefulness
217-
StackedRNNTest::test_correctness_single_state_stack
218-
StackedRNNTest::test_correctness_two_states_stack
219-
StackedRNNTest::test_statefullness_single_state_stack
220-
StackedRNNTest::test_statefullness_two_states_stack
221-
TestFitLRSchedulesFlow::test_fit_lr_correctness
22279
TestJaxLayer::test_flax_layer_training_independent_bound_method
22380
TestJaxLayer::test_flax_layer_training_rng_state_no_method
22481
TestJaxLayer::test_flax_layer_training_rng_unbound_method
22582
TestJaxLayer::test_flax_layer_training_rng_unbound_method_dtype_policy
22683
TestJaxLayer::test_jax_layer_training_independent
22784
TestJaxLayer::test_jax_layer_training_state
228-
TestJaxLayer::test_jax_layer_training_state_dtype_policy
229-
TestSpectrogram::test_spectrogram_error
230-
TestTrainer::test_loss_weights
231-
TestTrainer::test_nested_inputs
232-
TestTrainer::test_on_batch_methods_eager
233-
TestTrainer::test_on_batch_methods_graph_fn
234-
TestTrainer::test_on_batch_methods_jit
85+
TestJaxLayer::test_jax_layer_training_state_dtype_policy

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ NumpyDtypeTest::test_corrcoef
1919
NumpyDtypeTest::test_correlate
2020
NumpyDtypeTest::test_cross
2121
NumpyDtypeTest::test_cumprod
22-
NumpyDtypeTest::test_diag
22+
NumpyDtypeTest::test_diagflat
23+
NumpyDtypeTest::test_diagonal
2324
NumpyDtypeTest::test_einsum
2425
NumpyDtypeTest::test_exp2
2526
NumpyDtypeTest::test_flip
@@ -69,7 +70,7 @@ NumpyOneInputOpsCorrectnessTest::test_conj
6970
NumpyOneInputOpsCorrectnessTest::test_corrcoef
7071
NumpyOneInputOpsCorrectnessTest::test_correlate
7172
NumpyOneInputOpsCorrectnessTest::test_cumprod
72-
NumpyOneInputOpsCorrectnessTest::test_diag
73+
NumpyOneInputOpsCorrectnessTest::test_diagflat
7374
NumpyOneInputOpsCorrectnessTest::test_diagonal
7475
NumpyOneInputOpsCorrectnessTest::test_exp2
7576
NumpyOneInputOpsCorrectnessTest::test_flip

keras/src/backend/openvino/numpy.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -760,7 +760,70 @@ def deg2rad(x):
760760

761761

762762
def diag(x, k=0):
763-
raise NotImplementedError("`diag` is not supported with openvino backend")
763+
x = get_ov_output(x)
764+
x_shape = x.get_partial_shape()
765+
rank = x_shape.rank.get_length()
766+
767+
if rank == 1:
768+
N_dim = x_shape[0]
769+
if not N_dim.is_static:
770+
raise ValueError(
771+
"diag requires input with static shape for 1D input."
772+
)
773+
N = N_dim.get_length()
774+
output_size = N + np.abs(k)
775+
out_shape = ov_opset.constant(
776+
[output_size, output_size], dtype=Type.i32
777+
).output(0)
778+
zeros_const = ov_opset.constant(0, x.get_element_type()).output(0)
779+
diag_matrix = ov_opset.broadcast(zeros_const, out_shape)
780+
781+
indices = []
782+
if k >= 0:
783+
for i in range(N):
784+
indices.append([i, i + k])
785+
else:
786+
for i in range(N):
787+
indices.append([i - k, i])
788+
789+
indices = np.array(indices, dtype=np.int32)
790+
indices_const = ov_opset.constant(indices, dtype=Type.i32).output(0)
791+
updated = ov_opset.scatter_nd_update(diag_matrix, indices_const, x)
792+
return OpenVINOKerasTensor(updated.output(0))
793+
794+
elif rank == 2:
795+
M_dim = x_shape[0]
796+
N_dim = x_shape[1]
797+
if not M_dim.is_static or not N_dim.is_static:
798+
raise ValueError(
799+
"diag requires input with static shape for 2D input."
800+
)
801+
M = M_dim.get_length()
802+
N = N_dim.get_length()
803+
804+
if k >= 0:
805+
L = np.minimum(M, N - k) if (N - k) > 0 else 0
806+
indices = [[i, i + k] for i in range(L)]
807+
else:
808+
L = np.minimum(M + k, N) if (M + k) > 0 else 0
809+
indices = [[i - k, i] for i in range(L)]
810+
811+
if L <= 0:
812+
keras_dtype = ov_to_keras_type(x.get_element_type())
813+
np_dtype = np.dtype(keras_dtype)
814+
empty_np = np.empty((0,), dtype=np_dtype)
815+
empty_const = ov_opset.constant(
816+
empty_np, x.get_element_type()
817+
).output(0)
818+
return OpenVINOKerasTensor(empty_const)
819+
820+
indices = np.array(indices, dtype=np.int32)
821+
indices_const = ov_opset.constant(indices, dtype=Type.i32).output(0)
822+
diag_vec = ov_opset.gather_nd(x, indices_const)
823+
return OpenVINOKerasTensor(diag_vec.output(0))
824+
825+
else:
826+
raise ValueError("diag supports only 1D or 2D tensors")
764827

765828

766829
def diagonal(x, offset=0, axis1=0, axis2=1):

keras/src/layers/attention/additive_attention_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,16 @@ def test_attention_correctness(self):
5050
return_attention_scores=True,
5151
)
5252
self.assertAllClose(
53-
output, [[[1.727, 2.727], [2.272, 3.272]]], atol=1e-3
53+
output,
54+
[[[1.727, 2.727], [2.272, 3.272]]],
55+
atol=1e-3,
56+
tpu_atol=1e-2,
5457
)
5558
self.assertAllClose(
56-
scores, [[[0.636, 0.363], [0.363, 0.636]]], atol=1e-3
59+
scores,
60+
[[[0.636, 0.363], [0.363, 0.636]]],
61+
atol=1e-3,
62+
tpu_atol=1e-2,
5763
)
5864

5965
def test_attention_with_mask(self):

keras/src/layers/attention/attention_test.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,18 @@ def test_attention_correctness(self):
5353
return_attention_scores=True,
5454
)
5555
self.assertAllClose(
56-
output, [[[2.462, 3.462], [1.538, 2.538]]], atol=1e-3
56+
output,
57+
[[[2.462, 3.462], [1.538, 2.538]]],
58+
atol=1e-3,
59+
tpu_atol=1e-2,
60+
tpu_rtol=1e-2,
5761
)
5862
self.assertAllClose(
59-
scores, [[[0.269, 0.731], [0.731, 0.269]]], atol=1e-3
63+
scores,
64+
[[[0.269, 0.731], [0.731, 0.269]]],
65+
atol=1e-3,
66+
tpu_atol=1e-2,
67+
tpu_rtol=1e-2,
6068
)
6169

6270
# Concat.
@@ -66,10 +74,18 @@ def test_attention_correctness(self):
6674
return_attention_scores=True,
6775
)
6876
self.assertAllClose(
69-
output, [[[1.727, 2.727], [2.272, 3.272]]], atol=1e-3
77+
output,
78+
[[[1.727, 2.727], [2.272, 3.272]]],
79+
atol=1e-3,
80+
tpu_atol=1e-2,
81+
tpu_rtol=1e-2,
7082
)
7183
self.assertAllClose(
72-
scores, [[[0.636, 0.363], [0.363, 0.636]]], atol=1e-3
84+
scores,
85+
[[[0.636, 0.363], [0.363, 0.636]]],
86+
atol=1e-3,
87+
tpu_atol=1e-2,
88+
tpu_rtol=1e-2,
7389
)
7490

7591
def test_attention_with_mask(self):
@@ -149,7 +165,9 @@ def test_attention_calculate_scores_with_scale(self):
149165
expected_scores = np.matmul(query, key.transpose((0, 2, 1)))
150166
expected_scores *= layer.scale.numpy()
151167
actual_scores = layer._calculate_scores(query, key)
152-
self.assertAllClose(actual_scores, expected_scores)
168+
self.assertAllClose(
169+
actual_scores, expected_scores, tpu_atol=1e-2, tpu_rtol=1e-2
170+
)
153171

154172
def test_attention_calculate_score_mask_no_causal_no_vmask(self):
155173
scores = np.random.random((2, 3, 4))

0 commit comments

Comments
 (0)