diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index c5770a93c49d..da80815a3c65 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -18,6 +18,7 @@ from keras.src.ops.core import fori_loop as fori_loop from keras.src.ops.core import is_tensor as is_tensor from keras.src.ops.core import map as map +from keras.src.ops.core import print as print from keras.src.ops.core import saturate_cast as saturate_cast from keras.src.ops.core import scan as scan from keras.src.ops.core import scatter as scatter diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index c5770a93c49d..da80815a3c65 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -18,6 +18,7 @@ from keras.src.ops.core import fori_loop as fori_loop from keras.src.ops.core import is_tensor as is_tensor from keras.src.ops.core import map as map +from keras.src.ops.core import print as print from keras.src.ops.core import saturate_cast as saturate_cast from keras.src.ops.core import scan as scan from keras.src.ops.core import scatter as scatter diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 747c5881106b..7db9fae3ea16 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -1,3 +1,4 @@ +from jax.debug import print # noqa import jax import jax.experimental.sparse as jax_sparse import jax.numpy as jnp diff --git a/keras/src/backend/numpy/core.py b/keras/src/backend/numpy/core.py index 16b2303e5e43..a9441af891cc 100644 --- a/keras/src/backend/numpy/core.py +++ b/keras/src/backend/numpy/core.py @@ -18,6 +18,8 @@ SUPPORTS_RAGGED_TENSORS = False IS_THREAD_SAFE = True +_print = print + class Variable(KerasVariable): def _initialize(self, value): @@ -452,3 +454,10 @@ def remat(f): "utilize this feature." ) return f + + +def print(*args, print_options=None, **kwargs): + np.set_printoptions( + **{"threshold": 1000} if print_options is None else print_options + ) + return _print(*args, **kwargs) diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py index ec990c376bf3..0595f600bc0f 100644 --- a/keras/src/backend/openvino/core.py +++ b/keras/src/backend/openvino/core.py @@ -22,6 +22,8 @@ SUPPORTS_RAGGED_TENSORS = False IS_THREAD_SAFE = True +_print = print + OPENVINO_DTYPES = { "float16": ov.Type.f16, "float32": ov.Type.f32, @@ -664,3 +666,7 @@ def remat(f): "utilize this feature." ) return f + + +def print(*args, **kwargs): + return _print(*args, **kwargs) diff --git a/keras/src/backend/tensorflow/core.py b/keras/src/backend/tensorflow/core.py index 6896b74c519c..7843c309cb7e 100644 --- a/keras/src/backend/tensorflow/core.py +++ b/keras/src/backend/tensorflow/core.py @@ -2,6 +2,7 @@ import numpy as np import tensorflow as tf +from tensorflow import print # noqa from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice from keras.src import tree diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 6fb2ab4eeebb..c747595c8796 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -23,6 +23,8 @@ SUPPORTS_RAGGED_TENSORS = False IS_THREAD_SAFE = True +_print = print + # Some operators such as 'aten::_foreach_mul_.Scalar' # are not currently implemented for the MPS device. # check https://github.com/pytorch/pytorch/issues/77764. @@ -733,3 +735,10 @@ def backward(ctx, grad_output): if not isinstance(grads, tuple): grads = (grads,) return (None,) + grads + + +def print(*args, print_options=None, **kwargs): + torch.set_printoptions( + **{"threshold": 1000} if print_options is None else print_options + ) + return _print(*args, **kwargs) diff --git a/keras/src/ops/core.py b/keras/src/ops/core.py index 74807b280eae..cc1076f3febf 100644 --- a/keras/src/ops/core.py +++ b/keras/src/ops/core.py @@ -1183,3 +1183,10 @@ def grad(*args, upstream): ``` """ return backend.core.custom_gradient(f) + + +@keras_export("keras.ops.print") +def print(*args, **kwargs): + """Backend-specialised print function, oft handles tensors and + other backend-specific types.""" + return backend.core.print(*args, **kwargs)