@@ -220,7 +220,6 @@ def min(x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool
         return torch.clone(x)
     return torch.amin(x, axis, keepdims=keepdims)
 
-clip = get_xp(torch)(_aliases.clip)
 unstack = get_xp(torch)(_aliases.unstack)
 cumulative_sum = get_xp(torch)(_aliases.cumulative_sum)
 cumulative_prod = get_xp(torch)(_aliases.cumulative_prod)
@@ -808,6 +807,38 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
     return torch.take_along_dim(x, indices, dim=axis)
 
 
+def clip(
+    x: Array,
+    /,
+    min: int | float | Array | None = None,
+    max: int | float | Array | None = None,
+    **kwargs
+) -> Array:
+    def _isscalar(a: object):
+        return isinstance(a, int | float) or a is None
+
+    # cf clip in common/_aliases.py
+    if not x.is_floating_point():
+        if type(min) is int and min <= torch.iinfo(x.dtype).min:
+            min = None
+        if type(max) is int and max >= torch.iinfo(x.dtype).max:
+            max = None
+
+    if min is None and max is None:
+        return torch.clone(x)
+
+    min_is_scalar = _isscalar(min)
+    max_is_scalar = _isscalar(max)
+
+    if min is not None and max is not None:
+        if min_is_scalar and not max_is_scalar:
+            min = torch.as_tensor(min, dtype=x.dtype, device=x.device)
+        if max_is_scalar and not min_is_scalar:
+            max = torch.as_tensor(max, dtype=x.dtype, device=x.device)
+
+    return torch.clamp(x, min, max, **kwargs)
+
+
 def sign(x: Array, /) -> Array:
     # torch sign() does not support complex numbers and does not propagate
     # nans. See https://github.com/data-apis/array-api-compat/issues/136
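
A minimal usage sketch (not part of the diff) of why the torch-specific clip replaces the generic alias, assuming the module is exposed as array_api_compat.torch; the input values here are illustrative only:

import torch
import array_api_compat.torch as xp

x = torch.tensor([1, 5, 200], dtype=torch.uint8)

# max=300 exceeds torch.iinfo(torch.uint8).max (255), so the new clip drops
# it rather than passing an out-of-range bound to torch.clamp; only the
# lower bound is applied.
print(xp.clip(x, min=2, max=300))  # tensor([  2,   5, 200], dtype=torch.uint8)

# With one scalar bound and one Array bound, the scalar is promoted to a
# tensor with x's dtype and device before torch.clamp is called.
hi = torch.tensor([10, 10, 100], dtype=torch.uint8)
print(xp.clip(x, min=2, max=hi))   # tensor([  2,   5, 100], dtype=torch.uint8)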