@@ -104,10 +104,7 @@ def main():
 
     args.distributed = args.world_size > 1 or args.multiprocessing_distributed
 
-    if torch.cuda.is_available():
-        ngpus_per_node = torch.cuda.device_count()
-    else:
-        ngpus_per_node = 1
+    ngpus_per_node = torch.cuda.device_count()
     if args.multiprocessing_distributed:
         # Since we have ngpus_per_node processes per node, the total world_size
         # needs to be adjusted accordingly
@@ -144,33 +141,29 @@ def main_worker(gpu, ngpus_per_node, args):
     print("=> creating model '{}'".format(args.arch))
     model = models.__dict__[args.arch]()
 
-    if not torch.cuda.is_available() and not torch.backends.mps.is_available():
+    if not torch.cuda.is_available():
         print('using CPU, this will be slow')
     elif args.distributed:
         # For multiprocessing distributed, DistributedDataParallel constructor
         # should always set the single device scope, otherwise,
         # DistributedDataParallel will use all available devices.
-        if torch.cuda.is_available():
-            if args.gpu is not None:
-                torch.cuda.set_device(args.gpu)
-                model.cuda(args.gpu)
-                # When using a single GPU per process and per
-                # DistributedDataParallel, we need to divide the batch size
-                # ourselves based on the total number of GPUs of the current node.
-                args.batch_size = int(args.batch_size / ngpus_per_node)
-                args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
-                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
-            else:
-                model.cuda()
-                # DistributedDataParallel will divide and allocate batch_size to all
-                # available GPUs if device_ids are not set
-                model = torch.nn.parallel.DistributedDataParallel(model)
-    elif args.gpu is not None and torch.cuda.is_available():
+        if args.gpu is not None:
+            torch.cuda.set_device(args.gpu)
+            model.cuda(args.gpu)
+            # When using a single GPU per process and per
+            # DistributedDataParallel, we need to divide the batch size
+            # ourselves based on the total number of GPUs of the current node.
+            args.batch_size = int(args.batch_size / ngpus_per_node)
+            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
+            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        else:
+            model.cuda()
+            # DistributedDataParallel will divide and allocate batch_size to all
+            # available GPUs if device_ids are not set
+            model = torch.nn.parallel.DistributedDataParallel(model)
+    elif args.gpu is not None:
         torch.cuda.set_device(args.gpu)
         model = model.cuda(args.gpu)
-    elif torch.backends.mps.is_available():
-        device = torch.device("mps")
-        model = model.to(device)
     else:
         # DataParallel will divide and allocate batch_size to all available GPUs
         if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
@@ -179,17 +172,8 @@ def main_worker(gpu, ngpus_per_node, args):
         else:
             model = torch.nn.DataParallel(model).cuda()
 
-    if torch.cuda.is_available():
-        if args.gpu:
-            device = torch.device('cuda:{}'.format(args.gpu))
-        else:
-            device = torch.device("cuda")
-    elif torch.backends.mps.is_available():
-        device = torch.device("mps")
-    else:
-        device = torch.device("cpu")
     # define loss function (criterion), optimizer, and learning rate scheduler
-    criterion = nn.CrossEntropyLoss().to(device)
+    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
 
     optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                 momentum=args.momentum,
@@ -204,7 +188,7 @@ def main_worker(gpu, ngpus_per_node, args):
             print("=> loading checkpoint '{}'".format(args.resume))
             if args.gpu is None:
                 checkpoint = torch.load(args.resume)
-            elif torch.cuda.is_available():
+            else:
                 # Map model to be loaded to specified single gpu.
                 loc = 'cuda:{}'.format(args.gpu)
                 checkpoint = torch.load(args.resume, map_location=loc)
@@ -318,13 +302,10 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
         # measure data loading time
         data_time.update(time.time() - end)
 
-        if args.gpu is not None and torch.cuda.is_available():
+        if args.gpu is not None:
             images = images.cuda(args.gpu, non_blocking=True)
-        elif not args.gpu and torch.cuda.is_available():
+        if torch.cuda.is_available():
             target = target.cuda(args.gpu, non_blocking=True)
-        elif torch.backends.mps.is_available():
-            images = images.to('mps')
-            target = target.to('mps')
 
         # compute output
         output = model(images)
@@ -356,11 +337,8 @@ def run_validate(loader, base_progress=0):
             end = time.time()
             for i, (images, target) in enumerate(loader):
                 i = base_progress + i
-                if args.gpu is not None and torch.cuda.is_available():
+                if args.gpu is not None:
                     images = images.cuda(args.gpu, non_blocking=True)
-                if torch.backends.mps.is_available():
-                    images = images.to('mps')
-                    target = target.to('mps')
                 if torch.cuda.is_available():
                     target = target.cuda(args.gpu, non_blocking=True)
 
@@ -443,12 +421,7 @@ def update(self, val, n=1):
         self.avg = self.sum / self.count
 
     def all_reduce(self):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        elif torch.backends.mps.is_available():
-            device = torch.device("mps")
-        else:
-            device = torch.device("cpu")
+        device = "cuda" if torch.cuda.is_available() else "cpu"
         total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
         dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
         self.sum, self.count = total.tolist()
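
Note: the device-selection logic that this commit strips out (CUDA first, then MPS, then CPU) can be summarized as a small standalone helper. The sketch below is illustrative only and is not part of the commit; the helper name select_device is made up, and it uses `gpu is not None` rather than the truthiness check in the removed lines.

import torch

def select_device(gpu=None):
    # Mirrors the fallback the removed lines implemented inline in main_worker():
    # prefer a specific CUDA device, then any CUDA device, then MPS, then CPU.
    if torch.cuda.is_available():
        return torch.device('cuda:{}'.format(gpu)) if gpu is not None else torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

After this change the example no longer builds a `device` object at all: tensors and the criterion are moved with `.cuda(args.gpu, ...)` directly, so MPS is never selected automatically.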