Skip to content

Commit 4c004d4

Browse files
authored
Adjust demo scripts to be Keras 3 compatible (#6761)
This adjusts *_demo.py files to work with Keras 3. The hparams_demo is fully backward compatible with Keras 2 and forward compatible with Keras 3. Unfortunately the graphs_demo is not backward compatible with Keras 2. Users attempting to run it with Keras 2 will get the following error: ``` File "/usr/local/google/home/bdubois/.cache/bazel/_bazel_bdubois/079646a57be11faea0b2bfefccb2a81a/execroot/org_tensorflow_tensorboard/bazel-out/k8-fastbuild/bin/tensorboard/plugins/graph/graphs_demo.runfiles/org_tensorflow_tensorboard/tensorboard/plugins/graph/graphs_demo.py", line 128, in profile tf.summary.trace_on(profiler=True, profiler_outdir=logdir) TypeError: trace_on() got an unexpected keyword argument 'profiler_outdir' ``` Amazingly, though, the graph that the demo generates with Keras 3 can be successfully loaded in the Graph dashboard. This makes me optimistic to get the Graph plugin fully Keras 3 compatible after addressing the user-reported error in #6686. Old Keras 2 Graph: ![image](https://github.com/tensorflow/tensorboard/assets/17152369/b8745739-ac06-4171-a7bc-c97135b2dec7) New Keras 3 Graph: ![image](https://github.com/tensorflow/tensorboard/assets/17152369/04dcaf14-3464-47dc-b4a5-373cccd8370b)
1 parent 8a99668 commit 4c004d4

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

tensorboard/plugins/graph/graphs_demo.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def keras():
9090
model = tf.keras.models.Sequential(layers)
9191
model.compile(
9292
loss=tf.keras.losses.mean_squared_error,
93-
optimizer=tf.keras.optimizers.SGD(lr=0.2),
93+
optimizer=tf.keras.optimizers.SGD(learning_rate=0.2),
9494
)
9595
model.fit(
9696
x_train,
@@ -125,11 +125,9 @@ def g(i):
125125
for step in range(3):
126126
# Suppress the profiler deprecation warnings from tf.summary.trace_*.
127127
with _silence_deprecation_warnings():
128-
tf.summary.trace_on(profiler=True)
128+
tf.summary.trace_on(profiler=True, profiler_outdir=logdir)
129129
print(f(tf.constant(step)).numpy())
130-
tf.summary.trace_export(
131-
"prof_f", step=step, profiler_outdir=logdir
132-
)
130+
tf.summary.trace_export("prof_f", step=step)
133131

134132
tf.summary.trace_on(profiler=False)
135133
print(g(tf.constant(step)).numpy())

tensorboard/plugins/hparams/hparams_demo.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,11 @@ def model_fn(hparams, seed):
137137
conv_filters *= 2
138138

139139
model.add(tf.keras.layers.Flatten())
140-
model.add(tf.keras.layers.Dropout(hparams[HP_DROPOUT], seed=rng.random()))
140+
model.add(
141+
tf.keras.layers.Dropout(
142+
hparams[HP_DROPOUT], seed=rng.randrange(1 << 32)
143+
)
144+
)
141145

142146
# Add fully connected layers.
143147
dense_neurons = 32

0 commit comments

Comments
 (0)