@@ -145,6 +145,9 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
145
145
# Basic renames
146
146
bitwise_invert = torch .bitwise_not
147
147
newaxis = None
148
+ # torch.conj sets the conjugation bit, which breaks conversion to other
149
+ # libraries. See https://github.com/data-apis/array-api-compat/issues/173
150
+ conj = torch .conj_physical
148
151
149
152
# Two-arg elementwise functions
150
153
# These require a wrapper to do the correct type promotion on 0-D tensors
@@ -704,18 +707,18 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
704
707
return torch .index_select (x , axis , indices , ** kwargs )
705
708
706
709
__all__ = ['result_type' , 'can_cast' , 'permute_dims' , 'bitwise_invert' ,
707
- 'newaxis' , 'add ' , 'atan2 ' , 'bitwise_and ' , 'bitwise_left_shift ' ,
708
- 'bitwise_or ' , 'bitwise_right_shift ' , 'bitwise_xor' , 'copysign ' ,
709
- 'divide ' , 'equal ' , 'floor_divide ' , 'greater ' , 'greater_equal ' ,
710
- 'less ' , 'less_equal ' , 'logaddexp ' , 'multiply ' , 'not_equal' , 'pow ' ,
711
- 'remainder ' , 'subtract ' , 'max ' , 'min ' , 'clip ' , 'sort' , 'prod ' ,
712
- 'sum ' , 'any ' , 'all ' , 'mean ' , 'std ' , 'var ' , 'concat ' , 'squeeze ' ,
713
- 'broadcast_to ' , 'flip ' , 'roll ' , 'nonzero ' , 'where ' , 'reshape ' ,
714
- 'arange ' , 'eye ' , 'linspace ' , 'full ' , 'ones ' , 'zeros ' , 'empty ' ,
715
- 'tril ' , 'triu ' , 'expand_dims ' , 'astype ' , 'broadcast_arrays ' ,
716
- 'UniqueAllResult ' , 'UniqueCountsResult ' , 'UniqueInverseResult ' ,
717
- 'unique_all ' , 'unique_counts ' , 'unique_inverse' , 'unique_values ' ,
718
- 'matmul ' , 'matrix_transpose ' , 'vecdot ' , 'tensordot' , 'isdtype ' ,
719
- 'take' ]
710
+ 'newaxis' , 'conj ' , 'add ' , 'atan2 ' , 'bitwise_and ' ,
711
+ 'bitwise_left_shift ' , 'bitwise_or ' , 'bitwise_right_shift ' ,
712
+ 'bitwise_xor ' , 'copysign ' , 'divide ' , 'equal ' , 'floor_divide ' ,
713
+ 'greater ' , 'greater_equal ' , 'less ' , 'less_equal ' , 'logaddexp ' ,
714
+ 'multiply ' , 'not_equal ' , 'pow ' , 'remainder ' , 'subtract ' , 'max ' ,
715
+ 'min ' , 'clip ' , 'sort ' , 'prod ' , 'sum ' , 'any ' , 'all ' , 'mean' , 'std ' ,
716
+ 'var ' , 'concat ' , 'squeeze ' , 'broadcast_to ' , 'flip ' , 'roll ' ,
717
+ 'nonzero ' , 'where ' , 'reshape ' , 'arange ' , 'eye ' , 'linspace ' , 'full ' ,
718
+ 'ones ' , 'zeros ' , 'empty ' , 'tril ' , 'triu' , 'expand_dims' , 'astype ' ,
719
+ 'broadcast_arrays ' , 'UniqueAllResult ' , 'UniqueCountsResult ' ,
720
+ 'UniqueInverseResult ' , 'unique_all ' , 'unique_counts ' ,
721
+ 'unique_inverse ' , 'unique_values ' , 'matmul ' , 'matrix_transpose ' ,
722
+ 'vecdot' , 'tensordot' , 'isdtype' , ' take' ]
720
723
721
724
_all_ignore = ['torch' , 'get_xp' ]
0 commit comments