Skip to content

Commit 1d4e8ac

Browse files
authored
Merge pull request #1465 from hansoli68/MultiGPU
Multi GPU arguments fixes
2 parents a51bcff + c05643d commit 1d4e8ac

File tree

3 files changed

+24
-21
lines changed

3 files changed

+24
-21
lines changed

keras_retinanet/bin/evaluate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def parse_args(args):
9999
parser.add_argument('model', help='Path to RetinaNet model.')
100100
parser.add_argument('--convert-model', help='Convert the model to an inference model (ie. the input is a training model).', action='store_true')
101101
parser.add_argument('--backbone', help='The backbone of the model.', default='resnet50')
102-
parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi).', type=int)
102+
parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi).')
103103
parser.add_argument('--score-threshold', help='Threshold on score to filter detections with (defaults to 0.05).', default=0.05, type=float)
104104
parser.add_argument('--iou-threshold', help='IoU Threshold to count for a positive detection (defaults to 0.5).', default=0.5, type=float)
105105
parser.add_argument('--max-detections', help='Max Detections per image (defaults to 100).', default=100, type=int)

keras_retinanet/bin/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ def csv_list(string):
426426
group.add_argument('--no-weights', help='Don\'t initialize the model with any weights.', dest='imagenet_weights', action='store_const', const=False)
427427
parser.add_argument('--backbone', help='Backbone model used by retinanet.', default='resnet50', type=str)
428428
parser.add_argument('--batch-size', help='Size of the batches.', default=1, type=int)
429-
parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi).', type=int)
429+
parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi).')
430430
parser.add_argument('--multi-gpu', help='Number of GPUs to use for parallel processing.', type=int, default=0)
431431
parser.add_argument('--multi-gpu-force', help='Extra flag needed to enable (experimental) multi-gpu support.', action='store_true')
432432
parser.add_argument('--initial-epoch', help='Epoch from which to begin the train, useful if resuming from snapshot.', type=int, default=0)

keras_retinanet/utils/gpu.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,26 @@
1818

1919

2020
def setup_gpu(gpu_id):
21-
if gpu_id == 'cpu' or gpu_id == -1:
21+
try:
22+
visible_gpu_indices = [int(id) for id in gpu_id.split(',')]
23+
available_gpus = tf.config.list_physical_devices('GPU')
24+
visible_gpus = [gpu for idx, gpu in enumerate(available_gpus) if idx in visible_gpu_indices]
25+
26+
if visible_gpus:
27+
try:
28+
# Currently, memory growth needs to be the same across GPUs.
29+
for gpu in available_gpus:
30+
tf.config.experimental.set_memory_growth(gpu, True)
31+
32+
# Use only the selcted gpu.
33+
tf.config.set_visible_devices(visible_gpus, 'GPU')
34+
except RuntimeError as e:
35+
# Visible devices must be set before GPUs have been initialized.
36+
print(e)
37+
38+
logical_gpus = tf.config.list_logical_devices('GPU')
39+
print(len(available_gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
40+
else:
41+
tf.config.set_visible_devices([], 'GPU')
42+
except ValueError:
2243
tf.config.set_visible_devices([], 'GPU')
23-
return
24-
25-
gpus = tf.config.list_physical_devices('GPU')
26-
if gpus:
27-
# Restrict TensorFlow to only use the first GPU.
28-
try:
29-
# Currently, memory growth needs to be the same across GPUs.
30-
for gpu in gpus:
31-
tf.config.experimental.set_memory_growth(gpu, True)
32-
33-
# Use only the selcted gpu.
34-
tf.config.set_visible_devices(gpus[int(gpu_id)], 'GPU')
35-
except RuntimeError as e:
36-
# Visible devices must be set before GPUs have been initialized.
37-
print(e)
38-
39-
logical_gpus = tf.config.list_logical_devices('GPU')
40-
print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")

0 commit comments

Comments
 (0)