Skip to content

Commit 819842b

Browse files
authored
Merge pull request #174 from asmeurer/torch-conj
Use conj_physical for torch.conj
2 parents f905d8c + b754016 commit 819842b

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

array_api_compat/torch/_aliases.py

+16-13
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
145145
# Basic renames
146146
bitwise_invert = torch.bitwise_not
147147
newaxis = None
148+
# torch.conj sets the conjugation bit, which breaks conversion to other
149+
# libraries. See https://github.com/data-apis/array-api-compat/issues/173
150+
conj = torch.conj_physical
148151

149152
# Two-arg elementwise functions
150153
# These require a wrapper to do the correct type promotion on 0-D tensors
@@ -704,18 +707,18 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
704707
return torch.index_select(x, axis, indices, **kwargs)
705708

706709
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert',
707-
'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift',
708-
'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign',
709-
'divide', 'equal', 'floor_divide', 'greater', 'greater_equal',
710-
'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow',
711-
'remainder', 'subtract', 'max', 'min', 'clip', 'sort', 'prod',
712-
'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
713-
'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
714-
'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
715-
'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays',
716-
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
717-
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
718-
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
719-
'take']
710+
'newaxis', 'conj', 'add', 'atan2', 'bitwise_and',
711+
'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift',
712+
'bitwise_xor', 'copysign', 'divide', 'equal', 'floor_divide',
713+
'greater', 'greater_equal', 'less', 'less_equal', 'logaddexp',
714+
'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',
715+
'min', 'clip', 'sort', 'prod', 'sum', 'any', 'all', 'mean', 'std',
716+
'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll',
717+
'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full',
718+
'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
719+
'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult',
720+
'UniqueInverseResult', 'unique_all', 'unique_counts',
721+
'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
722+
'vecdot', 'tensordot', 'isdtype', 'take']
720723

721724
_all_ignore = ['torch', 'get_xp']

0 commit comments

Comments
 (0)