@@ -773,8 +773,9 @@ def backend_cupy(device=None, dtype=None):
773773 ----------
774774 device : int | cupy.cuda.Device | None, optional
775775 Target CUDA device. If ``None``, use CuPy's current device.
776- dtype : dtype-like | None, optional
777- Target CuPy dtype. If ``None``, infer from input.
776+ dtype : dtype-like | torch.dtype | None, optional
777+ Target CuPy dtype. If ``None``, infer from input. Torch dtypes are
778+ accepted and internally mapped to CuPy-compatible dtypes.
778779 """
779780 try :
780781 import cupy as cp # pylint: disable=import-outside-toplevel
@@ -790,6 +791,26 @@ def backend_cupy(device=None, dtype=None):
790791 if isinstance (target_device , int ):
791792 target_device = cp .cuda .Device (target_device )
792793
794+ if torch is not None and isinstance (dtype , torch .dtype ):
795+ torch_to_cupy = {
796+ torch .complex128 : cp .complex128 ,
797+ torch .complex64 : cp .complex64 ,
798+ torch .float64 : cp .float64 ,
799+ torch .float32 : cp .float32 ,
800+ torch .float16 : cp .float16 ,
801+ torch .int64 : cp .int64 ,
802+ torch .int32 : cp .int32 ,
803+ torch .int16 : cp .int16 ,
804+ torch .int8 : cp .int8 ,
805+ torch .uint8 : cp .uint8 ,
806+ torch .bool : cp .bool_ ,
807+ }
808+ if dtype not in torch_to_cupy :
809+ raise ValueError (
810+ f"backend_cupy does not support torch dtype { dtype !r} ."
811+ )
812+ dtype = torch_to_cupy [dtype ]
813+
793814 def cast_array (x , device = target_device , dtype = dtype ):
794815 if device is None :
795816 return cp .asarray (x , dtype = dtype )
0 commit comments