From 7d61b89b5aa7e131adf1ed98d59499ef03a8ec11 Mon Sep 17 00:00:00 2001 From: Tewodros Deneke Date: Tue, 21 Dec 2021 15:31:17 +0200 Subject: [PATCH 1/2] add conditional gpu check and useage --- tf2_imdb.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tf2_imdb.py b/tf2_imdb.py index 8276cf1..95d1249 100644 --- a/tf2_imdb.py +++ b/tf2_imdb.py @@ -13,7 +13,15 @@ print('Using Tensorflow version: {}, and Keras version: {}.'.format( tf.__version__, tf.keras.__version__)) +print(tf.config.get_visible_devices()) +# create a distribution strategy +if tf.config.list_physical_devices('GPU'): + strategy = tf.distribute.MirroredStrategy() +else: # a default fallback strategy + strategy = tf.distribute.get_strategy() + +print('Number of devices: {}'.format(strategy.num_replicas_in_sync)) class DetectSentiment: def __init__(self): @@ -39,12 +47,14 @@ def predict(self, text): if idx >= self.nb_words: idx = oov_idx v[0, i+1] = idx + + with strategy.scope(): + p = self.model.predict(v, batch_size=1) - p = self.model.predict(v, batch_size=1) return float(p[0, 0]) if __name__ == '__main__': ds = DetectSentiment() text = ' '.join(sys.argv[1:]) - print('Prediction for "{}": {}'.format(text, ds.predict(text))) + print('Prediction for "{}": {}'.format(text, ds.predict(text))) \ No newline at end of file From b847760a741599eaca80287fc715a54339676e21 Mon Sep 17 00:00:00 2001 From: Tewodros Deneke Date: Tue, 21 Dec 2021 15:41:03 +0200 Subject: [PATCH 2/2] fix white spaces --- tf2_imdb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tf2_imdb.py b/tf2_imdb.py index 95d1249..1ffbe63 100644 --- a/tf2_imdb.py +++ b/tf2_imdb.py @@ -47,7 +47,7 @@ def predict(self, text): if idx >= self.nb_words: idx = oov_idx v[0, i+1] = idx - + with strategy.scope(): p = self.model.predict(v, batch_size=1) @@ -57,4 +57,4 @@ def predict(self, text): if __name__ == '__main__': ds = DetectSentiment() text = ' '.join(sys.argv[1:]) - print('Prediction for "{}": {}'.format(text, ds.predict(text))) \ No newline at end of file + print('Prediction for "{}": {}'.format(text, ds.predict(text)))