12
12
import random
13
13
import tempfile
14
14
import uuid
15
- from typing import Any , Callable , Dict , List , Optional , Tuple , Type
15
+ from typing import Any , Callable , Dict , List , Optional , Tuple , Type , Union
16
16
from unittest .mock import Mock , patch
17
17
18
18
import torch
@@ -45,7 +45,9 @@ def gen_test_batch(
45
45
mask : Optional [torch .Tensor ] = None ,
46
46
n_classes : Optional [int ] = None ,
47
47
seed : Optional [int ] = None ,
48
+ device : Optional [Union [str , torch .device ]] = None ,
48
49
) -> Dict [str , torch .Tensor ]:
50
+ device = torch .device (device or "cpu" )
49
51
if seed is not None :
50
52
torch .manual_seed (seed )
51
53
if label_value is not None :
@@ -65,14 +67,14 @@ def gen_test_batch(
65
67
else :
66
68
weight = torch .rand (batch_size , dtype = torch .double )
67
69
test_batch = {
68
- label_name : label ,
69
- prediction_name : prediction ,
70
- weight_name : weight ,
71
- tensor_name : torch .rand (batch_size , dtype = torch .double ),
70
+ label_name : label . to ( device ) ,
71
+ prediction_name : prediction . to ( device ) ,
72
+ weight_name : weight . to ( device ) ,
73
+ tensor_name : torch .rand (batch_size , dtype = torch .double ). to ( device ) ,
72
74
}
73
75
if mask_tensor_name is not None :
74
76
if mask is None :
75
- mask = torch .ones (batch_size , dtype = torch .double )
77
+ mask = torch .ones (batch_size , dtype = torch .double ). to ( device )
76
78
test_batch [mask_tensor_name ] = mask
77
79
78
80
return test_batch
@@ -240,8 +242,10 @@ def rec_metric_value_test_helper(
240
242
n_classes : Optional [int ] = None ,
241
243
zero_weights : bool = False ,
242
244
zero_labels : bool = False ,
245
+ device : Optional [Union [str , torch .device ]] = None ,
243
246
** kwargs : Any ,
244
247
) -> Tuple [Dict [str , torch .Tensor ], Tuple [Dict [str , torch .Tensor ], ...]]:
248
+ device = torch .device (device or "cpu" )
245
249
tasks = gen_test_tasks (task_names )
246
250
model_outs = []
247
251
for _ in range (nsteps ):
@@ -263,6 +267,7 @@ def rec_metric_value_test_helper(
263
267
n_classes = n_classes ,
264
268
weight_value = weight_value ,
265
269
label_value = label_value ,
270
+ device = device ,
266
271
)
267
272
for task in tasks
268
273
]
@@ -293,7 +298,8 @@ def get_target_rec_metric_value(
293
298
compute_on_all_ranks = compute_on_all_ranks ,
294
299
should_validate_update = should_validate_update ,
295
300
** kwargs ,
296
- )
301
+ ).to (device )
302
+
297
303
for i in range (nsteps ):
298
304
# Get required_inputs_list from the target metric
299
305
required_inputs_list = list (target_metric_obj .get_required_inputs ())
@@ -381,6 +387,7 @@ def rec_metric_gpu_sync_test_launcher(
381
387
entry_point : Callable [..., None ],
382
388
batch_size : int = BATCH_SIZE ,
383
389
batch_window_size : int = BATCH_WINDOW_SIZE ,
390
+ device : Optional [Union [str , torch .device ]] = None ,
384
391
** kwargs : Dict [str , Any ],
385
392
) -> None :
386
393
with tempfile .TemporaryDirectory () as tmpdir :
@@ -402,6 +409,8 @@ def rec_metric_gpu_sync_test_launcher(
402
409
batch_size ,
403
410
batch_window_size ,
404
411
kwargs .get ("n_classes" , None ),
412
+ False ,
413
+ torch .device (device or "cpu" ),
405
414
)
406
415
407
416
@@ -419,8 +428,10 @@ def sync_test_helper(
419
428
batch_window_size : int = BATCH_WINDOW_SIZE ,
420
429
n_classes : Optional [int ] = None ,
421
430
zero_weights : bool = False ,
431
+ device : Optional [Union [str , torch .device ]] = None ,
422
432
** kwargs : Dict [str , Any ],
423
433
) -> None :
434
+ device = torch .device (device or "cpu" )
424
435
rank = int (os .environ ["RANK" ])
425
436
world_size = int (os .environ ["WORLD_SIZE" ])
426
437
dist .init_process_group (
@@ -444,7 +455,7 @@ def sync_test_helper(
444
455
window_size = batch_window_size * world_size ,
445
456
# pyre-ignore[6]: Incompatible parameter type
446
457
** kwargs ,
447
- )
458
+ ). to ( device )
448
459
449
460
weight_value : Optional [torch .Tensor ] = None
450
461
@@ -458,6 +469,7 @@ def sync_test_helper(
458
469
n_classes = n_classes ,
459
470
weight_value = weight_value ,
460
471
seed = 42 , # we set seed because of how test metric places tensors on ranks
472
+ device = device ,
461
473
)
462
474
for task in tasks
463
475
]
@@ -575,6 +587,7 @@ def rec_metric_value_test_launcher(
575
587
n_classes : Optional [int ] = None ,
576
588
zero_weights : bool = False ,
577
589
zero_labels : bool = False ,
590
+ device : Optional [Union [str , torch .device ]] = None ,
578
591
** kwargs : Any ,
579
592
) -> None :
580
593
with tempfile .TemporaryDirectory () as tmpdir :
@@ -600,6 +613,7 @@ def rec_metric_value_test_launcher(
600
613
n_classes = n_classes ,
601
614
zero_weights = zero_weights ,
602
615
zero_labels = zero_labels ,
616
+ device = device ,
603
617
** kwargs ,
604
618
)
605
619
@@ -616,6 +630,7 @@ def rec_metric_value_test_launcher(
616
630
n_classes ,
617
631
test_nsteps ,
618
632
zero_weights ,
633
+ device ,
619
634
)
620
635
621
636
@@ -642,6 +657,7 @@ def metric_test_helper(
642
657
n_classes : Optional [int ] = None ,
643
658
nsteps : int = 1 ,
644
659
zero_weights : bool = False ,
660
+ device : Optional [Union [str , torch .device ]] = None ,
645
661
is_time_dependent : bool = False ,
646
662
time_dependent_metric : Optional [Dict [Type [RecMetric ], str ]] = None ,
647
663
** kwargs : Any ,
@@ -670,6 +686,7 @@ def metric_test_helper(
670
686
is_time_dependent = is_time_dependent ,
671
687
time_dependent_metric = time_dependent_metric ,
672
688
zero_weights = zero_weights ,
689
+ device = device ,
673
690
** kwargs ,
674
691
)
675
692
0 commit comments