Skip to content

Commit 553521e

Browse files
Improve keras.Variable by exposing docstrings and ensuring consistency in the codebase (#20544)
* Improve `keras.Variable` by exposing docstrings and ensuring consistency in the codebase
* Fix CI
* Update docstrings
1 parent e0369f6 commit 553521e

File tree

10 files changed

+91
-53
lines changed

10 files changed

+91
-53
lines changed

keras/src/backend/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from keras.src.backend.common.symbolic_scope import SymbolicScope
2020
from keras.src.backend.common.symbolic_scope import in_symbolic_scope
2121
from keras.src.backend.common.variables import AutocastScope
22+
from keras.src.backend.common.variables import Variable
2223
from keras.src.backend.common.variables import get_autocast_scope
2324
from keras.src.backend.common.variables import is_float_dtype
2425
from keras.src.backend.common.variables import is_int_dtype
@@ -35,25 +36,26 @@
3536
# Import backend functions.
3637
if backend() == "tensorflow":
3738
from keras.src.backend.tensorflow import * # noqa: F403
39+
from keras.src.backend.tensorflow.core import Variable as BackendVariable
3840
elif backend() == "jax":
3941
from keras.src.backend.jax import * # noqa: F403
42+
from keras.src.backend.jax.core import Variable as BackendVariable
4043
elif backend() == "torch":
4144
from keras.src.backend.torch import * # noqa: F403
45+
from keras.src.backend.torch.core import Variable as BackendVariable
4246

4347
distribution_lib = None
4448
elif backend() == "numpy":
4549
from keras.src.backend.numpy import * # noqa: F403
50+
from keras.src.backend.numpy.core import Variable as BackendVariable
4651

4752
distribution_lib = None
4853
else:
4954
raise ValueError(f"Unable to import backend : {backend()}")
5055

5156

52-
BackendVariable = Variable # noqa: F405
53-
54-
5557
@keras_export("keras.Variable")
56-
class Variable(BackendVariable):
58+
class Variable(BackendVariable): # noqa: F811
5759
pass
5860

5961

keras/src/backend/common/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from keras.src.backend.common import backend_utils
22
from keras.src.backend.common.dtypes import result_type
33
from keras.src.backend.common.variables import AutocastScope
4-
from keras.src.backend.common.variables import KerasVariable
4+
from keras.src.backend.common.variables import Variable as KerasVariable
55
from keras.src.backend.common.variables import get_autocast_scope
66
from keras.src.backend.common.variables import is_float_dtype
77
from keras.src.backend.common.variables import is_int_dtype

keras/src/backend/common/stateless_scope.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ class StatelessScope:
88
99
The values of variables to be used inside the scope
1010
should be passed via the `state_mapping` argument, a
11-
list of tuples `(k, v)` where `k` is a `KerasVariable`
11+
list of tuples `(k, v)` where `k` is a `Variable`
1212
and `v` is the intended value for this variable
1313
(a backend tensor).
1414
@@ -39,21 +39,21 @@ def __init__(
3939
initialize_variables=True,
4040
):
4141
from keras.src import backend
42-
from keras.src.backend.common.variables import KerasVariable
42+
from keras.src.backend.common.variables import Variable
4343

4444
self.collect_losses = collect_losses
4545
self.initialize_variables = initialize_variables
4646
self.losses = []
4747
self.state_mapping = {}
4848
state_mapping = state_mapping or {}
4949
for k, v in state_mapping:
50-
if not isinstance(k, KerasVariable):
50+
if not isinstance(k, Variable):
5151
raise ValueError(
5252
"Invalid reference variable in StatelessScope: "
53-
"all keys in argument `mapping` must be KerasVariable "
53+
"all keys in argument `mapping` must be Variable "
5454
f"instances. Received instead: {k}"
5555
)
56-
if isinstance(v, KerasVariable):
56+
if isinstance(v, Variable):
5757
v = backend.cast(v.value, dtype=k.dtype)
5858
else:
5959
v = backend.convert_to_tensor(v, dtype=k.dtype)

keras/src/backend/common/stateless_scope_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_invalid_key_in_state_mapping(self):
4141
value1 = ops.ones(shape=(2,))
4242

4343
with self.assertRaisesRegex(
44-
ValueError, "all keys in argument `mapping` must be KerasVariable"
44+
ValueError, "all keys in argument `mapping` must be Variable"
4545
):
4646
StatelessScope(state_mapping=[(invalid_key, value1)])
4747

keras/src/backend/common/variables.py

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from keras.src.utils.naming import auto_name
1313

1414

15-
class KerasVariable:
15+
class Variable:
1616
"""Represents a backend-agnostic variable in Keras.
1717
1818
A `Variable` acts as a container for state. It holds a tensor value and can
@@ -30,17 +30,25 @@ class KerasVariable:
3030
dtype type (`"float32"` if never configured).
3131
trainable: Optional. Boolean indicating if variable is trainable.
3232
Defaults to `True`.
33+
autocast: Optional. Boolean indicating whether the variable supports
34+
autocasting. If `True`, the layer may first convert the variable
35+
to the compute data type when accessed. Defaults to `True`.
36+
aggregation: Optional. String specifying how a distributed variable will
37+
be aggregated. This serves as a semantic annotation, to be taken
38+
into account by downstream backends or users. Defaults to `"mean"`.
3339
name: Optional. A unique name for the variable. Automatically generated
3440
if not set.
3541
3642
Attributes:
37-
name: The name of the variable (string).
38-
path: The path of the variable within the Keras model or layer (string).
39-
dtype: The data type of the variable (string).
4043
shape: The shape of the variable (tuple of integers).
4144
ndim: The number of dimensions of the variable (integer).
45+
dtype: The data type of the variable (string).
4246
trainable: Whether the variable is trainable (boolean).
47+
autocast: Whether the variable supports autocasting (boolean).
48+
aggregation: How a distributed variable will be aggregated (string).
4349
value: The current value of the variable (NumPy array or tensor).
50+
name: The name of the variable (string).
51+
path: The path of the variable within the Keras model or layer (string).
4452
4553
Examples:
4654
@@ -101,20 +109,19 @@ def __init__(
101109
"one of {'none', 'mean', 'sum', 'only_first_replica'}. "
102110
f"Received: aggregation={aggregation}"
103111
)
104-
self.name = name
112+
self._name = name
105113
parent_path = current_path()
106114
if parent_path:
107-
self.path = current_path() + "/" + self.name
115+
self._path = current_path() + "/" + name
108116
else:
109-
self.path = self.name
110-
dtype = standardize_dtype(dtype)
111-
self._dtype = dtype
117+
self._path = name
118+
self._dtype = standardize_dtype(dtype)
112119
self._shape = None
113120
self._initializer = None
114121
self._regularizer = None
115122
self._constraint = None
116-
self._trainable = trainable
117-
self._autocast = autocast
123+
self._trainable = bool(trainable)
124+
self._autocast = bool(autocast)
118125
self._aggregation = aggregation
119126
# `self._overwrite_with_gradient` is an internal property to determine
120127
# whether this variable should be overwritten by the computed gradient.
@@ -163,7 +170,7 @@ def __init__(
163170
self._initialize_with_initializer(initializer)
164171
else:
165172
self._initialize(initializer)
166-
self._shape = tuple(self._value.shape)
173+
self._shape = self._validate_shape(self._value.shape)
167174
self._ndim = len(self._shape)
168175

169176
def _deferred_initialize(self):
@@ -201,10 +208,12 @@ def numpy(self):
201208

202209
@property
203210
def aggregation(self):
211+
"""The strategy for aggregating this variable."""
204212
return self._aggregation
205213

206214
@property
207215
def value(self):
216+
"""The current value of the variable (numpy array or backend tensor)."""
208217
if in_stateless_scope():
209218
scope = get_stateless_scope()
210219
value = scope.get_current_value(self)
@@ -246,30 +255,46 @@ def assign_sub(self, value):
246255

247256
@property
248257
def dtype(self):
258+
"""The data type of the variable."""
249259
autocast_scope = get_autocast_scope()
250260
if (
251261
self._autocast
252262
and autocast_scope is not None
253263
and is_float_dtype(self._dtype)
254264
):
255-
return autocast_scope.dtype
256-
return self._dtype
265+
dtype = autocast_scope.dtype
266+
else:
267+
dtype = self._dtype
268+
return backend.standardize_dtype(dtype)
257269

258270
@property
259271
def shape(self):
272+
"""The shape of the variable."""
260273
return self._shape
261274

262275
@property
263276
def ndim(self):
277+
"""The number of dimensions of the variable."""
264278
return self._ndim
265279

266280
@property
267281
def trainable(self):
282+
"""Whether the variable is trainable."""
268283
return self._trainable
269284

270285
@trainable.setter
271286
def trainable(self, value):
272-
self._trainable = value
287+
self._trainable = bool(value)
288+
289+
@property
290+
def name(self):
291+
"""The name of the variable."""
292+
return self._name
293+
294+
@property
295+
def path(self):
296+
"""The path of the variable within the Keras model or layer."""
297+
return self._path
273298

274299
@property
275300
def overwrite_with_gradient(self):
@@ -326,9 +351,13 @@ def constraint(self, value):
326351
self._constraint = value
327352

328353
def __repr__(self):
354+
value = None
355+
if hasattr(self, "_value") and self._value is not None:
356+
value = backend.core.convert_to_numpy(self._value)
357+
value_str = f", value={value}" if value is not None else ""
329358
return (
330-
f"<KerasVariable shape={self.shape}, dtype={self.dtype}, "
331-
f"path={self.path}>"
359+
f"<Variable path={self.path}, shape={self.shape}, "
360+
f"dtype={self.dtype}{value_str}>"
332361
)
333362

334363
def _initialize(self, value):
@@ -573,7 +602,7 @@ def get_autocast_scope():
573602
class AutocastScope:
574603
"""Context manager that enables the autocasting of float variables.
575604
576-
Under this context manager, float `KerasVariables`s will be cast to `dtype`
605+
Under this context manager, float `Variable`s will be cast to `dtype`
577606
(note that `dtype` must also be float).
578607
"""
579608

keras/src/backend/common/variables_test.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from keras.src import initializers
99
from keras.src.backend.common import dtypes
1010
from keras.src.backend.common.variables import AutocastScope
11-
from keras.src.backend.common.variables import KerasVariable
1211
from keras.src.backend.common.variables import shape_equal
1312
from keras.src.backend.common.variables import standardize_dtype
1413
from keras.src.backend.common.variables import standardize_shape
@@ -17,7 +16,7 @@
1716

1817

1918
class VariableInitializationTest(test_case.TestCase):
20-
"""Tests for KerasVariable.__init__()"""
19+
"""Tests for Variable.__init__()"""
2120

2221
def test_deferred_initialization(self):
2322
"""Tests deferred initialization of variables."""
@@ -73,17 +72,16 @@ def test_variable_initialize(self):
7372
self.assertAllClose(v.value, init_value)
7473

7574
def test_variable_without_shape_from_callable_initializer(self):
76-
"""Test that KerasVariable raises error
75+
"""Test that Variable raises error
7776
if shape is not provided for callable initializer."""
7877
with self.assertRaisesRegex(
7978
ValueError, "When creating a Variable from an initializer"
8079
):
81-
KerasVariable(initializer=lambda: np.ones((2, 2)))
80+
backend.Variable(initializer=lambda: np.ones((2, 2)))
8281

8382

8483
class VariablePropertiesTest(test_case.TestCase):
85-
"""Tests for KerasVariable._deferred_initialize
86-
KerasVariable._maybe_autocast"""
84+
"""Tests for Variable._deferred_initialize and Variable._maybe_autocast"""
8785

8886
def test_deferred_assignment(self):
8987
"""Tests deferred assignment to variables."""
@@ -204,10 +202,12 @@ def test_name_validation(self):
204202
with self.assertRaisesRegex(
205203
ValueError, "Argument `name` must be a string"
206204
):
207-
KerasVariable(initializer=initializers.RandomNormal(), name=12345)
205+
backend.Variable(
206+
initializer=initializers.RandomNormal(), name=12345
207+
)
208208

209209
with self.assertRaisesRegex(ValueError, "cannot contain character `/`"):
210-
KerasVariable(
210+
backend.Variable(
211211
initializer=initializers.RandomNormal(), name="invalid/name"
212212
)
213213

@@ -272,8 +272,7 @@ def test_overwrite_with_gradient_setter(self):
272272

273273

274274
class VariableNumpyValueAndAssignmentTest(test_case.TestCase):
275-
"""tests for KerasVariable.numpy(), KerasVariable.value()
276-
and KerasVariable.assign()"""
275+
"""tests for Variable.numpy(), Variable.value() and Variable.assign()"""
277276

278277
def test_variable_numpy(self):
279278
"""Test retrieving the value of a variable as a numpy array."""
@@ -373,10 +372,21 @@ def test_variable_repr(self):
373372
"""Test the string representation of a variable."""
374373
v = backend.Variable(initializer=np.array([1, 2, 3]), name="test_var")
375374
expected_repr = (
376-
"<KerasVariable shape=(3,), dtype=float32, path=test_var>"
375+
"<Variable path=test_var, shape=(3,), dtype=float32, "
376+
"value=[1. 2. 3.]>"
377377
)
378378
self.assertEqual(repr(v), expected_repr)
379379

380+
# Test with `backend.StatelessScope()`
381+
with backend.StatelessScope():
382+
v = backend.Variable(
383+
initializer="zeros", shape=(3,), name="test_var"
384+
)
385+
expected_repr = (
386+
"<Variable path=test_var, shape=(3,), dtype=float32>"
387+
)
388+
self.assertEqual(repr(v), expected_repr)
389+
380390
def test_variable_getitem(self):
381391
"""Test getting an item from a variable."""
382392
v = backend.Variable(initializer=np.array([1, 2, 3]))
@@ -408,7 +418,7 @@ def test_variable_array(self):
408418

409419

410420
class VariableOpsCorrectnessTest(test_case.TestCase):
411-
"""Tests for operations on KerasVariable."""
421+
"""Tests for operations on Variable."""
412422

413423
def test_int(self):
414424
v = backend.Variable(initializer=np.array(-1.1))

keras/src/backend/jax/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -940,7 +940,7 @@ def _purge_model_variables(
940940
941941
During JAX training, since the training functions are stateless, we have
942942
to pass in and get the model weights over and over, during which the
943-
copy of the weights that attached to the KerasVariable are still and
943+
copy of the weights that is attached to the Variable is still
944944
occupying extra memory. We remove those variables to save memory (for
945945
better memory utilization) at the beginning of the epoch, and reattach
946946
the value back to variables at the end of the epoch, via

keras/src/backend/tensorflow/optimizer.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import tensorflow as tf
1313

1414
from keras.src import backend
15-
from keras.src.backend.common import KerasVariable
1615
from keras.src.backend.tensorflow.trackable import KerasAutoTrackable
1716
from keras.src.optimizers import base_optimizer
1817

@@ -46,7 +45,7 @@ def stateless_apply(self, optimizer_variables, grads, trainable_variables):
4645
)
4746

4847
def assign(self, variable, value):
49-
if isinstance(variable, KerasVariable):
48+
if isinstance(variable, backend.Variable):
5049
variable = variable.value
5150
value = tf.cast(value, variable.dtype)
5251
if isinstance(value, tf.IndexedSlices):
@@ -55,7 +54,7 @@ def assign(self, variable, value):
5554
variable.assign(value)
5655

5756
def assign_add(self, variable, value):
58-
if isinstance(variable, KerasVariable):
57+
if isinstance(variable, backend.Variable):
5958
variable = variable.value
6059
value = tf.cast(value, variable.dtype)
6160
if isinstance(value, tf.IndexedSlices):
@@ -64,7 +63,7 @@ def assign_add(self, variable, value):
6463
variable.assign_add(value)
6564

6665
def assign_sub(self, variable, value):
67-
if isinstance(variable, KerasVariable):
66+
if isinstance(variable, backend.Variable):
6867
variable = variable.value
6968
value = tf.cast(value, variable.dtype)
7069
if isinstance(value, tf.IndexedSlices):

0 commit comments

Comments
 (0)