1+ from scipy .sparse .linalg import LinearOperator
2+ from sparse_dot_mkl ._mkl_interface ._cfunctions import (
3+ MKL ,
4+ mkl_library_name ,
5+ mkl_get_version_string ,
6+ mkl_get_max_threads
7+ )
8+
9+ from sparse_dot_mkl ._mkl_interface ._constants import (
10+ LAYOUT_CODE_C ,
11+ LAYOUT_CODE_F ,
12+ SPARSE_INDEX_BASE_ZERO ,
13+ RETURN_CODES ,
14+ ILP64_MSG ,
15+ SPARSE_MATRIX_TYPE_HERMITIAN ,
16+ SPARSE_MATRIX_TYPE_SYMMETRIC ,
17+ SPARSE_FILL_MODE_LOWER ,
18+ SPARSE_FILL_MODE_UPPER ,
19+ SPARSE_FILL_MODE_FULL ,
20+ SPARSE_DIAG_NON_UNIT ,
21+ SPARSE_DIAG_UNIT
22+ )
23+
24+ from sparse_dot_mkl ._mkl_interface ._structs import (
25+ sparse_matrix_t ,
26+ matrix_descr ,
27+ MKL_Complex8 ,
28+ MKL_Complex16
29+ )
30+
31+ from sparse_dot_mkl ._mkl_interface ._common import (
32+ _check_return_value ,
33+ _out_matrix ,
34+ _get_numpy_layout ,
35+ _export_mkl ,
36+ )
37+
38+ import numpy as np
39+ import scipy .sparse as sp
40+ import ctypes as _ctypes
41+
42+
43+
44+ class MKLQuantumOperator (LinearOperator ):
45+ def __init__ (self , data , indices , indptr , shape , upper = True , unit = False ):
46+ super (MKLQuantumOperator , self ).__init__ (shape = shape , dtype = np .complex128 )
47+
48+ ref = sparse_matrix_t ()
49+
50+ ret_val = MKL ._mkl_sparse_z_create_csr (
51+ _ctypes .byref (ref ),
52+ _ctypes .c_int (SPARSE_INDEX_BASE_ZERO ),
53+ MKL .MKL_INT (shape [0 ]),
54+ MKL .MKL_INT (shape [1 ]),
55+ indptr [0 :- 1 ],
56+ indptr [1 :],
57+ indices ,
58+ data ,
59+ )
60+
61+ self .indptr = indptr
62+ self .indices = indices
63+ self .data = data
64+ self .upper = upper
65+
66+ self .ref = ref
67+
68+ # Check return
69+ _check_return_value (ret_val , MKL ._mkl_sparse_z_create_csr .__name__ )
70+
71+ if unit :
72+ unit_flag = SPARSE_DIAG_UNIT
73+ else :
74+ unit_flag = SPARSE_DIAG_NON_UNIT
75+
76+ if upper :
77+ upper_flag = SPARSE_FILL_MODE_UPPER
78+ else :
79+ upper_flag = SPARSE_FILL_MODE_FULL
80+
81+
82+ self .descr = matrix_descr (
83+ sparse_matrix_type_t = SPARSE_MATRIX_TYPE_HERMITIAN ,
84+ sparse_fill_mode_t = upper_flag ,
85+ sparse_diag_type_t = unit_flag
86+ )
87+
88+ def _matvec (self , v , out = None ):
89+
90+ if not v .flags .contiguous :
91+ raise ValueError ("vector v is not contiguous" )
92+
93+ output_shape = (self .shape [0 ],) if v .ndim == 1 else (self .shape [0 ], 1 )
94+
95+ out = _out_matrix (
96+ output_shape ,
97+ np .cdouble ,
98+ out_arr = out ,
99+ out_t = False
100+ )
101+ ret_val = MKL ._mkl_sparse_z_mv (
102+ 10 ,
103+ MKL_Complex16 (1. ),
104+ self .ref ,
105+ self .descr ,
106+ v ,
107+ MKL_Complex16 (0. ),
108+ out ,
109+ )
110+
111+ _check_return_value (ret_val , MKL ._mkl_sparse_z_mv .__name__ )
112+
113+ return out
114+
115+ def _matmat (self , V , out = None ):
116+ output_shape = (self .shape [0 ], V .shape [1 ])
117+
118+ layout_b , ld_b = _get_numpy_layout (V , out )
119+
120+ output_order = "C" if layout_b == LAYOUT_CODE_C else "F"
121+
122+ out = _out_matrix (
123+ output_shape ,
124+ np .cdouble ,
125+ output_order ,
126+ out_arr = out ,
127+ out_t = False
128+ )
129+
130+ ret_val = MKL ._mkl_sparse_z_mm (
131+ 10 ,
132+ MKL_Complex16 (1. ),
133+ self .ref ,
134+ self .descr ,
135+ layout_b ,
136+ V ,
137+ output_shape [1 ],
138+ ld_b ,
139+ MKL_Complex16 (0. ),
140+ out .ctypes .data_as (_ctypes .POINTER (_ctypes .c_double )),
141+ ld_b ,
142+
143+ )
144+
145+ _check_return_value (ret_val , MKL ._mkl_sparse_z_mm .__name__ )
146+
147+ return out
148+
149+ @classmethod
150+ def from_csr (cls , csr_mat , upper = True , unit = False ):
151+ if upper :
152+ mat = sp .triu (csr_mat , format = "csr" )
153+ else :
154+ mat = csr_mat
155+
156+ return cls (
157+ mat .data ,
158+ mat .indices ,
159+ mat .indptr ,
160+ mat .shape ,
161+ upper ,
162+ unit
163+ )
164+
165+ def to_csr (self ):
166+ mat = sp .csr_matrix ((self .data , self .indices , self .indptr ), shape = self .shape , copy = False )
167+ if self .upper :
168+ mat += sp .triu (mat , k = 1 , format = "csr" ).conj ().T
169+
170+ return mat
0 commit comments