Skip to content

Commit 7396711

Browse files
nkovela1edward-bot
authored andcommitted
Creates non-breaking changes where necessary in preparation for switching all of Keras to new serialization format.
PiperOrigin-RevId: 507864605
1 parent 2a85212 commit 7396711

File tree

4 files changed

+17
-13
lines changed

4 files changed

+17
-13
lines changed

edward2/tensorflow/constraints.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,16 @@ def get_config(self):
7878

7979

8080
def serialize(initializer):
81-
return tf.keras.utils.serialize_keras_object(initializer)
81+
return tf.keras.utils.legacy.serialize_keras_object(initializer)
8282

8383

8484
def deserialize(config, custom_objects=None):
85-
return tf.keras.utils.deserialize_keras_object(
85+
return tf.keras.utils.legacy.deserialize_keras_object(
8686
config,
8787
module_objects=globals(),
8888
custom_objects=custom_objects,
89-
printable_module_name='constraints')
89+
printable_module_name='constraints',
90+
)
9091

9192

9293
def get(identifier, value=None):

edward2/tensorflow/initializers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -851,15 +851,16 @@ def get_config(self):
851851

852852

853853
def serialize(initializer):
854-
return tf.keras.utils.serialize_keras_object(initializer)
854+
return tf.keras.utils.legacy.serialize_keras_object(initializer)
855855

856856

857857
def deserialize(config, custom_objects=None):
858-
return tf.keras.utils.deserialize_keras_object(
858+
return tf.keras.utils.legacy.deserialize_keras_object(
859859
config,
860860
module_objects=globals(),
861861
custom_objects=custom_objects,
862-
printable_module_name='initializers')
862+
printable_module_name='initializers',
863+
)
863864

864865

865866
def get(identifier, value=None):

edward2/tensorflow/layers/gaussian_process.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def get_config(self):
104104
return {
105105
'variance': self.variance,
106106
'bias': self.bias,
107-
'encoder': tf.keras.utils.serialize_keras_object(self.encoder),
107+
'encoder': tf.keras.utils.legacy.serialize_keras_object(self.encoder),
108108
}
109109

110110

@@ -250,9 +250,10 @@ def compute_output_shape(self, input_shape):
250250
def get_config(self):
251251
config = {
252252
'units': self.units,
253-
'mean_fn': tf.keras.utils.serialize_keras_object(self.mean_fn),
254-
'covariance_fn': tf.keras.utils.serialize_keras_object(
255-
self.covariance_fn),
253+
'mean_fn': tf.keras.utils.legacy.serialize_keras_object(self.mean_fn),
254+
'covariance_fn': tf.keras.utils.legacy.serialize_keras_object(
255+
self.covariance_fn
256+
),
256257
'conditional_inputs': None, # don't serialize as it can be large
257258
'conditional_outputs': None, # don't serialize as it can be large
258259
}

edward2/tensorflow/regularizers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -388,15 +388,16 @@ def get_config(self):
388388

389389

390390
def serialize(initializer):
391-
return tf.keras.utils.serialize_keras_object(initializer)
391+
return tf.keras.utils.legacy.serialize_keras_object(initializer)
392392

393393

394394
def deserialize(config, custom_objects=None):
395-
return tf.keras.utils.deserialize_keras_object(
395+
return tf.keras.utils.legacy.deserialize_keras_object(
396396
config,
397397
module_objects=globals(),
398398
custom_objects=custom_objects,
399-
printable_module_name='regularizers')
399+
printable_module_name='regularizers',
400+
)
400401

401402

402403
def get(identifier, value=None):

0 commit comments

Comments
 (0)