1212import random
1313import tempfile
1414import uuid
15- from typing import Any , Callable , Dict , List , Optional , Tuple , Type
15+ from typing import Any , Callable , Dict , List , Optional , Tuple , Type , Union
1616from unittest .mock import Mock , patch
1717
1818import torch
@@ -45,7 +45,9 @@ def gen_test_batch(
4545 mask : Optional [torch .Tensor ] = None ,
4646 n_classes : Optional [int ] = None ,
4747 seed : Optional [int ] = None ,
48+ device : Optional [Union [str , torch .device ]] = None ,
4849) -> Dict [str , torch .Tensor ]:
50+ device = torch .device (device or "cpu" )
4951 if seed is not None :
5052 torch .manual_seed (seed )
5153 if label_value is not None :
@@ -65,14 +67,14 @@ def gen_test_batch(
6567 else :
6668 weight = torch .rand (batch_size , dtype = torch .double )
6769 test_batch = {
68- label_name : label ,
69- prediction_name : prediction ,
70- weight_name : weight ,
71- tensor_name : torch .rand (batch_size , dtype = torch .double ),
70+ label_name : label . to ( device ) ,
71+ prediction_name : prediction . to ( device ) ,
72+ weight_name : weight . to ( device ) ,
73+ tensor_name : torch .rand (batch_size , dtype = torch .double ). to ( device ) ,
7274 }
7375 if mask_tensor_name is not None :
7476 if mask is None :
75- mask = torch .ones (batch_size , dtype = torch .double )
77+ mask = torch .ones (batch_size , dtype = torch .double ). to ( device )
7678 test_batch [mask_tensor_name ] = mask
7779
7880 return test_batch
@@ -240,8 +242,10 @@ def rec_metric_value_test_helper(
240242 n_classes : Optional [int ] = None ,
241243 zero_weights : bool = False ,
242244 zero_labels : bool = False ,
245+ device : Optional [Union [str , torch .device ]] = None ,
243246 ** kwargs : Any ,
244247) -> Tuple [Dict [str , torch .Tensor ], Tuple [Dict [str , torch .Tensor ], ...]]:
248+ device = torch .device (device or "cpu" )
245249 tasks = gen_test_tasks (task_names )
246250 model_outs = []
247251 for _ in range (nsteps ):
@@ -263,6 +267,7 @@ def rec_metric_value_test_helper(
263267 n_classes = n_classes ,
264268 weight_value = weight_value ,
265269 label_value = label_value ,
270+ device = device ,
266271 )
267272 for task in tasks
268273 ]
@@ -293,7 +298,8 @@ def get_target_rec_metric_value(
293298 compute_on_all_ranks = compute_on_all_ranks ,
294299 should_validate_update = should_validate_update ,
295300 ** kwargs ,
296- )
301+ ).to (device )
302+
297303 for i in range (nsteps ):
298304 # Get required_inputs_list from the target metric
299305 required_inputs_list = list (target_metric_obj .get_required_inputs ())
@@ -381,6 +387,7 @@ def rec_metric_gpu_sync_test_launcher(
381387 entry_point : Callable [..., None ],
382388 batch_size : int = BATCH_SIZE ,
383389 batch_window_size : int = BATCH_WINDOW_SIZE ,
390+ device : Optional [Union [str , torch .device ]] = None ,
384391 ** kwargs : Dict [str , Any ],
385392) -> None :
386393 with tempfile .TemporaryDirectory () as tmpdir :
@@ -402,6 +409,8 @@ def rec_metric_gpu_sync_test_launcher(
402409 batch_size ,
403410 batch_window_size ,
404411 kwargs .get ("n_classes" , None ),
412+ False ,
413+ torch .device (device or "cpu" ),
405414 )
406415
407416
@@ -419,8 +428,10 @@ def sync_test_helper(
419428 batch_window_size : int = BATCH_WINDOW_SIZE ,
420429 n_classes : Optional [int ] = None ,
421430 zero_weights : bool = False ,
431+ device : Optional [Union [str , torch .device ]] = None ,
422432 ** kwargs : Dict [str , Any ],
423433) -> None :
434+ device = torch .device (device or "cpu" )
424435 rank = int (os .environ ["RANK" ])
425436 world_size = int (os .environ ["WORLD_SIZE" ])
426437 dist .init_process_group (
@@ -444,7 +455,7 @@ def sync_test_helper(
444455 window_size = batch_window_size * world_size ,
445456 # pyre-ignore[6]: Incompatible parameter type
446457 ** kwargs ,
447- )
458+ ). to ( device )
448459
449460 weight_value : Optional [torch .Tensor ] = None
450461
@@ -458,6 +469,7 @@ def sync_test_helper(
458469 n_classes = n_classes ,
459470 weight_value = weight_value ,
460471 seed = 42 , # we set seed because of how test metric places tensors on ranks
472+ device = device ,
461473 )
462474 for task in tasks
463475 ]
@@ -575,6 +587,7 @@ def rec_metric_value_test_launcher(
575587 n_classes : Optional [int ] = None ,
576588 zero_weights : bool = False ,
577589 zero_labels : bool = False ,
590+ device : Optional [Union [str , torch .device ]] = None ,
578591 ** kwargs : Any ,
579592) -> None :
580593 with tempfile .TemporaryDirectory () as tmpdir :
@@ -600,6 +613,7 @@ def rec_metric_value_test_launcher(
600613 n_classes = n_classes ,
601614 zero_weights = zero_weights ,
602615 zero_labels = zero_labels ,
616+ device = device ,
603617 ** kwargs ,
604618 )
605619
@@ -616,6 +630,7 @@ def rec_metric_value_test_launcher(
616630 n_classes ,
617631 test_nsteps ,
618632 zero_weights ,
633+ device ,
619634 )
620635
621636
@@ -642,6 +657,7 @@ def metric_test_helper(
642657 n_classes : Optional [int ] = None ,
643658 nsteps : int = 1 ,
644659 zero_weights : bool = False ,
660+ device : Optional [Union [str , torch .device ]] = None ,
645661 is_time_dependent : bool = False ,
646662 time_dependent_metric : Optional [Dict [Type [RecMetric ], str ]] = None ,
647663 ** kwargs : Any ,
@@ -670,6 +686,7 @@ def metric_test_helper(
670686 is_time_dependent = is_time_dependent ,
671687 time_dependent_metric = time_dependent_metric ,
672688 zero_weights = zero_weights ,
689+ device = device ,
673690 ** kwargs ,
674691 )
675692
0 commit comments