From 3d62d3872be6d72ffbc4a3b7005bbe3f0270f7f9 Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Thu, 3 Apr 2025 10:14:01 +0100
Subject: [PATCH] ENH: torch.asarray device propagation

---
 array_api_compat/torch/_aliases.py | 31 ++++++++++++++++++++++++------
 array_api_compat/torch/_typing.py  |  5 ++---
 2 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 982500b0..0891525a 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -2,12 +2,13 @@
 
 from functools import reduce as _reduce, wraps as _wraps
 from builtins import all as _builtin_all, any as _builtin_any
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import Any, List, Optional, Sequence, Tuple, Union
 
 import torch
 
 from .._internal import get_xp
 from ..common import _aliases
+from ..common._typing import NestedSequence, SupportsBufferProtocol
 from ._info import __array_namespace_info__
 from ._typing import Array, Device, DType
 
@@ -207,6 +208,28 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
 remainder = _two_arg(torch.remainder)
 subtract = _two_arg(torch.subtract)
 
+
+def asarray(
+    obj: (
+        Array
+        | bool | int | float | complex
+        | NestedSequence[bool | int | float | complex]
+        | SupportsBufferProtocol
+    ),
+    /,
+    *,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    copy: bool | None = None,
+    **kwargs: Any,
+) -> Array:
+    # torch.asarray does not respect input->output device propagation
+    # https://github.com/pytorch/pytorch/issues/150199
+    if device is None and isinstance(obj, torch.Tensor):
+        device = obj.device
+    return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs)
+
+
 # These wrappers are mostly based on the fact that pytorch uses 'dim' instead
 # of 'axis'.
 
@@ -282,7 +305,6 @@ def prod(x: Array,
          dtype: Optional[DType] = None,
          keepdims: bool = False,
          **kwargs) -> Array:
-    x = torch.asarray(x)
     ndim = x.ndim
 
     # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
@@ -318,7 +340,6 @@ def sum(x: Array,
         dtype: Optional[DType] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    x = torch.asarray(x)
     ndim = x.ndim
 
     # https://github.com/pytorch/pytorch/issues/29137.
@@ -348,7 +369,6 @@ def any(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    x = torch.asarray(x)
     ndim = x.ndim
     if axis == ():
         return x.to(torch.bool)
@@ -373,7 +393,6 @@ def all(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    x = torch.asarray(x)
     ndim = x.ndim
     if axis == ():
         return x.to(torch.bool)
@@ -816,7 +835,7 @@ def sign(x: Array, /) -> Array:
     return out
 
 
-__all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
+__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast',
            'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
            'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
            'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
diff --git a/array_api_compat/torch/_typing.py b/array_api_compat/torch/_typing.py
index 29ad3fa7..52670871 100644
--- a/array_api_compat/torch/_typing.py
+++ b/array_api_compat/torch/_typing.py
@@ -1,4 +1,3 @@
-__all__ = ["Array", "DType", "Device"]
+__all__ = ["Array", "Device", "DType"]
 
-from torch import dtype as DType, Tensor as Array
-from ..common._typing import Device
+from torch import device as Device, dtype as DType, Tensor as Array
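
For context, a minimal sketch of the behavior this patch targets. This is a hypothetical session, not part of the patch: it assumes a CUDA device is available and that array_api_compat.torch re-exports the asarray wrapper added above; the exact placement behavior of plain torch.asarray when a copy is forced is what pytorch/pytorch#150199 describes.

import torch
import array_api_compat.torch as xp  # namespace assumed to expose the wrapped asarray

x = torch.ones(3, device="cuda")

# Plain torch.asarray: with device=None and a forced copy (dtype change or
# copy=True), the result may land on the current default device rather than
# on x.device (see pytorch/pytorch#150199).
y = torch.asarray(x, dtype=torch.float64, copy=True)

# Wrapped asarray from this patch: device defaults to obj.device for tensor
# inputs, so the output stays on the same device as the input.
z = xp.asarray(x, dtype=torch.float64, copy=True)
assert z.device == x.device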