Skip to content

Commit 50dae30

Browse files
authored
Add sparse support to ops.ones_like and ops.zeros_like. (#21181)
`ops.zeros_like` is in particular useful for creating a mask of the populated values in the sparse tensor.
1 parent 65f7e6d commit 50dae30

File tree

4 files changed

+16
-6
lines changed

4 files changed

+16
-6
lines changed

keras/src/backend/jax/numpy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,10 +895,12 @@ def not_equal(x1, x2):
895895
return jnp.not_equal(x1, x2)
896896

897897

898+
@sparse.elementwise_unary(linear=False)
898899
def ones_like(x, dtype=None):
899900
return jnp.ones_like(x, dtype=dtype)
900901

901902

903+
@sparse.elementwise_unary(linear=True)
902904
def zeros_like(x, dtype=None):
903905
return jnp.zeros_like(x, dtype=dtype)
904906

keras/src/backend/tensorflow/numpy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1798,10 +1798,12 @@ def not_equal(x1, x2):
17981798
return tf.not_equal(x1, x2)
17991799

18001800

1801+
@sparse.elementwise_unary
18011802
def ones_like(x, dtype=None):
18021803
return tf.ones_like(x, dtype=dtype)
18031804

18041805

1806+
@sparse.elementwise_unary
18051807
def zeros_like(x, dtype=None):
18061808
return tf.zeros_like(x, dtype=dtype)
18071809

keras/src/ops/numpy.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4391,7 +4391,8 @@ def call(self, x, dtype=None):
43914391
def compute_output_spec(self, x, dtype=None):
43924392
if dtype is None:
43934393
dtype = x.dtype
4394-
return KerasTensor(x.shape, dtype=dtype)
4394+
sparse = getattr(x, "sparse", False)
4395+
return KerasTensor(x.shape, dtype=dtype, sparse=sparse)
43954396

43964397

43974398
@keras_export(["keras.ops.ones_like", "keras.ops.numpy.ones_like"])
@@ -4417,7 +4418,8 @@ def call(self, x, dtype=None):
44174418
def compute_output_spec(self, x, dtype=None):
44184419
if dtype is None:
44194420
dtype = x.dtype
4420-
return KerasTensor(x.shape, dtype=dtype)
4421+
sparse = getattr(x, "sparse", False)
4422+
return KerasTensor(x.shape, dtype=dtype, sparse=sparse)
44214423

44224424

44234425
@keras_export(

keras/src/ops/numpy_test.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5097,6 +5097,7 @@ class SparseTest(testing.TestCase):
50975097
"imag",
50985098
"log1p",
50995099
"negative",
5100+
"ones_like",
51005101
"real",
51015102
"round",
51025103
"sign",
@@ -5106,6 +5107,7 @@ class SparseTest(testing.TestCase):
51065107
"square",
51075108
"tan",
51085109
"tanh",
5110+
"zeros_like",
51095111
]
51105112
ELEMENTWISE_UNARY_OPS_TESTS = [
51115113
{
@@ -5287,10 +5289,11 @@ def test_elementwise_unary_sparse_correctness(
52875289
x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])
52885290
x = create_sparse_tensor(x)
52895291
x_np = backend.convert_to_numpy(x)
5292+
expected = np_op(x_np) * backend.convert_to_numpy(knp.ones_like(x))
52905293

5291-
self.assertAllClose(op_function(x), np_op(x_np))
5294+
self.assertAllClose(op_function(x), expected)
52925295
self.assertSameSparseness(op_function(x), x)
5293-
self.assertAllClose(op_class()(x), np_op(x_np))
5296+
self.assertAllClose(op_class()(x), expected)
52945297
self.assertSameSparseness(op_class()(x), x)
52955298

52965299
@parameterized.named_parameters(ELEMENTWISE_UNARY_OPS_TESTS)
@@ -5303,10 +5306,11 @@ def test_elementwise_unary_indexed_slices_correctness(
53035306
x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])
53045307
x = create_indexed_slices(x)
53055308
x_np = backend.convert_to_numpy(x)
5309+
expected = np_op(x_np) * backend.convert_to_numpy(knp.ones_like(x))
53065310

5307-
self.assertAllClose(op_function(x), np_op(x_np))
5311+
self.assertAllClose(op_function(x), expected)
53085312
self.assertSameSparseness(op_function(x), x)
5309-
self.assertAllClose(op_class()(x), np_op(x_np))
5313+
self.assertAllClose(op_class()(x), expected)
53105314
self.assertSameSparseness(op_class()(x), x)
53115315

53125316
@parameterized.named_parameters(OTHER_UNARY_OPS_TESTS)

0 commit comments

Comments
 (0)