diff --git a/ml-agents/mlagents/torch_utils/torch.py b/ml-agents/mlagents/torch_utils/torch.py index 311304ef54..cf819fb807 100644 --- a/ml-agents/mlagents/torch_utils/torch.py +++ b/ml-agents/mlagents/torch_utils/torch.py @@ -52,7 +52,7 @@ def set_torch_config(torch_settings: TorchSettings) -> None: _device = torch.device(device_str) if _device.type == "cuda": - torch.set_default_device(_device.type) + torch.set_default_device(_device) torch.set_default_dtype(torch.float32) else: torch.set_default_dtype(torch.float32)