Skip to content

Commit d794015

Browse files
committed
ENH: use torch.clamp for wrapped_torch.clip
Otherwise, the version which emulates "clip" fails with torch.vmap (see gh-350)
1 parent 5938c3f commit d794015

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,6 @@ def min(x: Array, /, *, axis: int | tuple[int, ...] |None = None, keepdims: bool
220220
return torch.clone(x)
221221
return torch.amin(x, axis, keepdims=keepdims)
222222

223-
clip = get_xp(torch)(_aliases.clip)
224223
unstack = get_xp(torch)(_aliases.unstack)
225224
cumulative_sum = get_xp(torch)(_aliases.cumulative_sum)
226225
cumulative_prod = get_xp(torch)(_aliases.cumulative_prod)
@@ -808,6 +807,38 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
808807
return torch.take_along_dim(x, indices, dim=axis)
809808

810809

810+
def clip(
    x: Array,
    /,
    min: int | float | Array | None = None,
    max: int | float | Array | None = None,
    **kwargs
) -> Array:
    """Clamp each element of ``x`` to the interval ``[min, max]``.

    Delegates to ``torch.clamp`` rather than the generic emulation in
    ``common/_aliases.py``, because the emulated version breaks under
    ``torch.vmap`` (see data-apis/array-api-compat gh-350).
    """
    # cf clip in common/_aliases.py
    # For integer dtypes, an int bound at or beyond the dtype's own range
    # cannot exclude any representable value — treat it as "no bound".
    # (`type(...) is int` deliberately excludes bool and tensor bounds.)
    if not x.is_floating_point():
        if type(min) is int and min <= torch.iinfo(x.dtype).min:
            min = None
        if type(max) is int and max >= torch.iinfo(x.dtype).max:
            max = None

    # Nothing left to clamp against: return a copy, mirroring the copying
    # semantics of torch.clamp.
    if min is None and max is None:
        return torch.clone(x)

    # torch.clamp rejects a mix of one scalar and one tensor bound, so when
    # both bounds are present, promote a lone Python scalar to a tensor with
    # x's dtype and device.
    if min is not None and max is not None:
        min_is_plain = isinstance(min, (int, float))
        max_is_plain = isinstance(max, (int, float))
        if min_is_plain and not max_is_plain:
            min = torch.as_tensor(min, dtype=x.dtype, device=x.device)
        elif max_is_plain and not min_is_plain:
            max = torch.as_tensor(max, dtype=x.dtype, device=x.device)

    return torch.clamp(x, min, max, **kwargs)
840+
841+
811842
def sign(x: Array, /) -> Array:
812843
# torch sign() does not support complex numbers and does not propagate
813844
# nans. See https://github.com/data-apis/array-api-compat/issues/136

tests/test_torch.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,18 @@ def test_gh_273(self, default_dt, dtype_a, dtype_b):
102102
torch.set_default_dtype(prev_default)
103103

104104

105+
def test_clip_vmap():
    # Regression test: clip must survive torch.vmap, which the generic
    # emulated clip did not. See
    # https://github.com/data-apis/array-api-compat/issues/350
    def clip_compat(arr):
        return xp.clip(arr, min=0, max=30)

    data = xp.asarray([[5.1, 2.0, 64.1, -1.5]])

    expected = clip_compat(data)
    vmapped_clip = torch.vmap(clip_compat)
    assert xp.all(vmapped_clip(data) == expected)
115+
116+
105117
def test_meshgrid():
106118
"""Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'."""
107119

0 commit comments

Comments
 (0)