Commit 55ba23c

Merge pull request #17 from floatingCatty/mpi_tools
add mkl sparse operator and MPI tools class
2 parents: a7a611d + 013ff34

File tree

2 files changed: +717 -0 lines

dpnegf/utils/mkl_operator.py

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
from scipy.sparse.linalg import LinearOperator
from sparse_dot_mkl._mkl_interface._cfunctions import (
    MKL,
    mkl_library_name,
    mkl_get_version_string,
    mkl_get_max_threads
)

from sparse_dot_mkl._mkl_interface._constants import (
    LAYOUT_CODE_C,
    LAYOUT_CODE_F,
    SPARSE_INDEX_BASE_ZERO,
    RETURN_CODES,
    ILP64_MSG,
    SPARSE_MATRIX_TYPE_HERMITIAN,
    SPARSE_MATRIX_TYPE_SYMMETRIC,
    SPARSE_FILL_MODE_LOWER,
    SPARSE_FILL_MODE_UPPER,
    SPARSE_FILL_MODE_FULL,
    SPARSE_DIAG_NON_UNIT,
    SPARSE_DIAG_UNIT
)

from sparse_dot_mkl._mkl_interface._structs import (
    sparse_matrix_t,
    matrix_descr,
    MKL_Complex8,
    MKL_Complex16
)

from sparse_dot_mkl._mkl_interface._common import (
    _check_return_value,
    _out_matrix,
    _get_numpy_layout,
    _export_mkl,
)

import numpy as np
import scipy.sparse as sp
import ctypes as _ctypes


class MKLQuantumOperator(LinearOperator):
    # Complex128 CSR matrix wrapped as an MKL sparse handle and exposed as a
    # scipy LinearOperator. The matrix is described to MKL as Hermitian; with
    # upper=True only the upper triangle is stored.
    def __init__(self, data, indices, indptr, shape, upper=True, unit=False):
        super(MKLQuantumOperator, self).__init__(shape=shape, dtype=np.complex128)

        ref = sparse_matrix_t()

        # Create the MKL CSR handle from the raw CSR arrays
        ret_val = MKL._mkl_sparse_z_create_csr(
            _ctypes.byref(ref),
            _ctypes.c_int(SPARSE_INDEX_BASE_ZERO),
            MKL.MKL_INT(shape[0]),
            MKL.MKL_INT(shape[1]),
            indptr[0:-1],
            indptr[1:],
            indices,
            data,
        )

        self.indptr = indptr
        self.indices = indices
        self.data = data
        self.upper = upper

        self.ref = ref

        # Check return
        _check_return_value(ret_val, MKL._mkl_sparse_z_create_csr.__name__)

        if unit:
            unit_flag = SPARSE_DIAG_UNIT
        else:
            unit_flag = SPARSE_DIAG_NON_UNIT

        if upper:
            upper_flag = SPARSE_FILL_MODE_UPPER
        else:
            upper_flag = SPARSE_FILL_MODE_FULL

        self.descr = matrix_descr(
            sparse_matrix_type_t=SPARSE_MATRIX_TYPE_HERMITIAN,
            sparse_fill_mode_t=upper_flag,
            sparse_diag_type_t=unit_flag
        )

    def _matvec(self, v, out=None):
        # Sparse matrix-vector product out = A @ v via mkl_sparse_z_mv
        if not v.flags.contiguous:
            raise ValueError("vector v is not contiguous")

        output_shape = (self.shape[0],) if v.ndim == 1 else (self.shape[0], 1)

        out = _out_matrix(
            output_shape,
            np.cdouble,
            out_arr=out,
            out_t=False
        )
        ret_val = MKL._mkl_sparse_z_mv(
            10,  # SPARSE_OPERATION_NON_TRANSPOSE
            MKL_Complex16(1.),
            self.ref,
            self.descr,
            v,
            MKL_Complex16(0.),
            out,
        )

        _check_return_value(ret_val, MKL._mkl_sparse_z_mv.__name__)

        return out

    def _matmat(self, V, out=None):
        # Sparse matrix times dense matrix, out = A @ V, via mkl_sparse_z_mm
        output_shape = (self.shape[0], V.shape[1])

        layout_b, ld_b = _get_numpy_layout(V, out)

        output_order = "C" if layout_b == LAYOUT_CODE_C else "F"

        out = _out_matrix(
            output_shape,
            np.cdouble,
            output_order,
            out_arr=out,
            out_t=False
        )

        ret_val = MKL._mkl_sparse_z_mm(
            10,  # SPARSE_OPERATION_NON_TRANSPOSE
            MKL_Complex16(1.),
            self.ref,
            self.descr,
            layout_b,
            V,
            output_shape[1],
            ld_b,
            MKL_Complex16(0.),
            out.ctypes.data_as(_ctypes.POINTER(_ctypes.c_double)),
            ld_b,
        )

        _check_return_value(ret_val, MKL._mkl_sparse_z_mm.__name__)

        return out

    @classmethod
    def from_csr(cls, csr_mat, upper=True, unit=False):
        # Build the operator from a scipy CSR matrix; with upper=True only the
        # upper triangle is kept and handed to MKL
        if upper:
            mat = sp.triu(csr_mat, format="csr")
        else:
            mat = csr_mat

        return cls(
            mat.data,
            mat.indices,
            mat.indptr,
            mat.shape,
            upper,
            unit
        )

    def to_csr(self):
        # Rebuild a scipy CSR matrix; if only the upper triangle is stored,
        # restore the strict lower triangle from the conjugate transpose
        mat = sp.csr_matrix((self.data, self.indices, self.indptr), shape=self.shape, copy=False)
        if self.upper:
            mat += sp.triu(mat, k=1, format="csr").conj().T

        return mat
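
For orientation, here is a minimal usage sketch (not part of the commit). It assumes sparse_dot_mkl with an MKL runtime is installed and that the class is importable from the file path added above; the Hermitian matrix H and all sizes are purely illustrative.

import numpy as np
import scipy.sparse as sp

from dpnegf.utils.mkl_operator import MKLQuantumOperator

# Build a small complex Hermitian CSR matrix (illustrative only)
n = 200
rng = np.random.default_rng(0)
A = sp.random(n, n, density=0.05, format="csr").astype(np.complex128)
A.data += 1j * rng.standard_normal(A.nnz)
H = (A + A.conj().T).tocsr()

# Wrap it as an MKL-backed LinearOperator, storing only the upper triangle
op = MKLQuantumOperator.from_csr(H, upper=True)

v = rng.standard_normal(n) + 1j * rng.standard_normal(n)
y = op @ v                                               # dispatches to _matvec
assert np.allclose(y, H @ v)                             # expected to match the plain scipy product
assert np.allclose(op.to_csr().toarray(), H.toarray())   # upper-triangle round trip

Because MKLQuantumOperator is a scipy LinearOperator, it can also be passed directly to iterative solvers such as scipy.sparse.linalg.gmres without materializing the full matrix.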
