Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/qutip_jax/binops.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def mul_jaxarray(matrix, value):
# We don't want to check values type in case jax pass a tracer etc.
# But we want to ensure the output is a matrix, thus don't use the
# fast constructor.
return JaxArray(matrix._jxa * value)
return JaxArray._fast_constructor(matrix._jxa * value, shape=matrix.shape)


def matmul_jaxarray(left, right, scale=1, out=None):
Expand Down Expand Up @@ -119,7 +119,7 @@ def kron_jaxarray(left, right):
Compute the Kronecker product of two matrices. This is used to represent
quantum tensor products of vector spaces.
"""
return JaxArray(jnp.kron(left._jxa, right._jxa))
return JaxArray._fast_constructor(jnp.kron(left._jxa, right._jxa))


def pow_jaxarray(matrix, n):
Expand All @@ -138,7 +138,7 @@ def pow_jaxarray(matrix, n):
"""
if matrix.shape[0] != matrix.shape[1]:
raise ValueError("matrix power only works with square matrices")
return JaxArray(jnp.linalg.matrix_power(matrix._jxa, n))
return JaxArray._fast_constructor(jnp.linalg.matrix_power(matrix._jxa, n))


qutip.data.add.add_specialisations(
Expand Down
22 changes: 11 additions & 11 deletions src/qutip_jax/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
]


def zeros_jaxarray(rows, cols):
def zeros_jaxarray(rows, cols, *, dtype=jnp.complex128):
"""
Creates a matrix representation of zeros with the given dimensions.

Expand All @@ -25,10 +25,10 @@ def zeros_jaxarray(rows, cols):
rows, cols : int
The number of rows and columns in the output matrix.
"""
return JaxArray(jnp.zeros((rows, cols), dtype=jnp.complex128))
return JaxArray._fast_constructor(jnp.zeros((rows, cols), dtype=dtype))


def identity_jaxarray(dimensions, scale=None):
def identity_jaxarray(dimensions, scale=None, *, dtype=jnp.complex128):
"""
Creates a square identity matrix of the given dimension.

Expand All @@ -43,11 +43,11 @@ def identity_jaxarray(dimensions, scale=None):
The element which should be placed on the diagonal.
"""
if scale is None:
return JaxArray(jnp.eye(dimensions, dtype=jnp.complex128))
return JaxArray(jnp.eye(dimensions, dtype=jnp.complex128) * scale)
return JaxArray._fast_constructor(jnp.eye(dimensions, dtype=dtype))
return JaxArray._fast_constructor(jnp.eye(dimensions, dtype=dtype) * scale)


def diag_jaxarray(diagonals, offsets=None, shape=None):
def diag_jaxarray(diagonals, offsets=None, shape=None, *, dtype=jnp.complex128):
"""
Constructs a matrix from diagonals and their offsets.

Expand Down Expand Up @@ -108,10 +108,10 @@ def diag_jaxarray(diagonals, offsets=None, shape=None):

if n_rows == n_cols:
# jax diag only create square matrix
out = jnp.zeros((n_rows, n_cols), dtype=jnp.complex128)
out = jnp.zeros((n_rows, n_cols), dtype=dtype)
for offset, diag in zip(offsets, diagonals):
out += jnp.diag(jnp.array(diag), offset)
out = JaxArray(out)
out = JaxArray._fast_constructor(out)
else:
out = jax_from_dense(
qutip.core.data.dense.diags(diagonals, offsets, shape)
Expand All @@ -120,7 +120,7 @@ def diag_jaxarray(diagonals, offsets=None, shape=None):
return out


def one_element_jaxarray(shape, position, value=None):
def one_element_jaxarray(shape, position, value=None, *, dtype=jnp.complex128):
"""
Creates a matrix with only one nonzero element.

Expand All @@ -141,8 +141,8 @@ def one_element_jaxarray(shape, position, value=None):
)
if value is None:
value = 1.0
out = jnp.zeros(shape, dtype=jnp.complex128)
return JaxArray(out.at[position].set(value))
out = jnp.zeros(shape, dtype=dtype)
return JaxArray._fast_constructor(out.at[position].set(value))


qutip.data.zeros.add_specialisations(
Expand Down
16 changes: 9 additions & 7 deletions src/qutip_jax/jaxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ class JaxArray(Data):
_jxa: jnp.ndarray
shape: tuple

def __init__(self, data, shape=None, copy=None):
jxa = jnp.array(data, dtype=jnp.complex128)
def __init__(self, data, shape=None, copy=None, *, dtype=jnp.complex128):
jxa = jnp.array(data, dtype=dtype)

if shape is None:
shape = data.shape
Expand All @@ -45,19 +45,19 @@ def __init__(self, data, shape=None, copy=None):
Data.__init__(self, shape)

def copy(self):
return self.__class__(self._jxa, copy=True)
return JaxArray._fast_constructor(self._jxa.copy(), shape=self.shape)

def to_array(self):
return np.array(self._jxa)

def conj(self):
return self.__class__(self._jxa.conj())
return JaxArray._fast_constructor(self._jxa.conj(), shape=self.shape)

def transpose(self):
return self.__class__(self._jxa.T)
return JaxArray._fast_constructor(self._jxa.T, shape=self.shape[::-1])

def adjoint(self):
return self.__class__(self._jxa.T.conj())
return JaxArray._fast_constructor(self._jxa.T.conj(), shape=self.shape[::-1])

def trace(self):
return jnp.trace(self._jxa)
Expand All @@ -81,8 +81,10 @@ def __matmul__(self, other):
return NotImplemented

@classmethod
def _fast_constructor(cls, array, shape):
def _fast_constructor(cls, array, shape=None):
out = cls.__new__(cls)
if shape is None:
shape = array.shape
Data.__init__(out, shape)
out._jxa = array
return out
Expand Down
6 changes: 3 additions & 3 deletions src/qutip_jax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def eigs_jaxarray(data, isherm=None, vecs=True, sort='low', eigvals=0):

evals, evecs = _eigs_jaxarray(data._jxa, isherm, vecs, eigvals, low_first)

return (evals, JaxArray(evecs, copy=False)) if vecs else evals
return (evals, JaxArray._fast_constructor(evecs)) if vecs else evals


qutip.data.eigs.add_specialisations(
Expand Down Expand Up @@ -109,7 +109,7 @@ def svd_jaxarray(data, vecs=True, full_matrices=True, hermitian=False):
)
if vecs:
u, s, vh = out
return JaxArray(u, copy=False), s, JaxArray(vh, copy=False)
return JaxArray._fast_constructor(u), s, JaxArray._fast_constructor(vh)
return out


Expand Down Expand Up @@ -160,7 +160,7 @@ def solve_jaxarray(matrix: JaxArray, target: JaxArray, method=None,
else:
raise ValueError(f"Unknown solver {method},"
" 'solve' and 'lstsq' are supported.")
return JaxArray(out, copy=False)
return JaxArray._fast_constructor(out)


qutip.data.solve.add_specialisations(
Expand Down
12 changes: 8 additions & 4 deletions src/qutip_jax/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@

@jax.jit
def _cplx2float(arr):
return jnp.stack([arr.real, arr.imag])
if jnp.iscomplexobj(arr):
return jnp.stack([arr.real, arr.imag])
return arr


@jax.jit
def _float2cplx(arr):
return arr[0] + 1j * arr[1]
if arr.ndim == 3:
return arr[0] + 1j * arr[1]
return arr
Comment on lines 15 to +26
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test

Copy link
Member Author

@Ericgig Ericgig Jul 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is covered in test_non_cplx128_Diffrax.



class DiffraxIntegrator(Integrator):
Expand Down Expand Up @@ -49,7 +53,7 @@ def _prepare(self):
def dstate(t, y, args):
state = _float2cplx(y)
H, kwargs = args
d_state = H.matmul_data(t, JaxArray(state), **kwargs)
d_state = H.matmul_data(t, JaxArray._fast_constructor(state), **kwargs)
return _cplx2float(d_state._jxa)

def set_state(self, t, state0):
Expand All @@ -61,7 +65,7 @@ def set_state(self, t, state0):
self._is_set = True

def get_state(self, copy=False):
return self.t, JaxArray(_float2cplx(self.state))
return self.t, JaxArray._fast_constructor(_float2cplx(self.state))

def integrate(self, t, copy=False, **kwargs):
sol = diffrax.diffeqsolve(
Expand Down
2 changes: 1 addition & 1 deletion src/qutip_jax/permute.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def indices_jaxarray(matrix, row_perm=None, col_perm=None):
data = data[np.argsort(row_perm), :]
if col_perm is not None:
data = data[:, np.argsort(col_perm)]
return JaxArray(data)
return JaxArray._fast_constructor(data)


def dimensions_jaxarray(matrix, dimensions, order):
Expand Down
37 changes: 28 additions & 9 deletions src/qutip_jax/qobjevo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from .jaxarray import JaxArray
from qutip.core.coefficient import coefficient_builders
from qutip.core.cy.coefficient import Coefficient
from qutip.core.cy.coefficient import Coefficient, coefficient_function_parameters
from qutip import Qobj


Expand All @@ -18,16 +18,26 @@ class JaxJitCoeff(Coefficient):

def __init__(self, func, args={}, **_):
self.func = func
_f_pythonic, _f_parameters = coefficient_function_parameters(func)
if _f_parameters is not None:
args = {key:val for key, val in args.items() if key in _f_parameters}
else:
args = args.copy()
if not _f_pythonic:
raise TypeError("Jitted coefficient should use a pythonic signature.")
Coefficient.__init__(self, args)

@eqx.filter_jit
def __call__(self, t, _args=None, **kwargs):
if _args:
kwargs.update(_args)
args = self.args.copy()
for key in kwargs:
if key in args:
args[key] = kwargs[key]
if kwargs:
args = self.args.copy()
for key in kwargs:
if key in args:
args[key] = kwargs[key]
else:
args = self.args
return self.func(t, **args)

def replace_arguments(self, _args=None, **kwargs):
Expand Down Expand Up @@ -113,25 +123,34 @@ def __init__(self, qobjevo):

constant = JaxJitCoeff(eqx.filter_jit(lambda t, **_: 1.0))

dtype = None

for part in as_list:
if isinstance(part, Qobj):
qobjs.append(part)
self.coeffs.append(constant)
if isinstance(part.data, JaxArray):
dtype = jnp.promote_types(dtype, part.data._jxa.dtype)
elif (
isinstance(part, list) and isinstance(part[0], Qobj)
):
qobjs.append(part[0])
self.coeffs.append(part[1])
if isinstance(part[0], JaxArray):
dtype = jnp.promote_types(dtype, part[0].data._jxa.dtype)
else:
# TODO:
raise NotImplementedError(
"Function based QobjEvo are not supported"
)

if dtype is None:
dtype=jnp.complex128

if qobjs:
shape = qobjs[0].shape
self.batched_data = jnp.zeros(
shape + (len(qobjs),), dtype=np.complex128
shape + (len(qobjs),), dtype=dtype
)
for i, qobj in enumerate(qobjs):
self.batched_data = self.batched_data.at[:, :, i].set(
Expand All @@ -141,7 +160,7 @@ def __init__(self, qobjevo):
@eqx.filter_jit
def _coeff(self, t, **args):
list_coeffs = [coeff(t, **args) for coeff in self.coeffs]
return jnp.array(list_coeffs, dtype=np.complex128)
return jnp.array(list_coeffs, dtype=self.batched_data.dtype)

def __call__(self, t, **kwargs):
return Qobj(self.data(t, **kwargs), dims=self.dims)
Expand All @@ -150,12 +169,12 @@ def __call__(self, t, **kwargs):
def data(self, t, **kwargs):
coeff = self._coeff(t, **kwargs)
data = jnp.dot(self.batched_data, coeff)
return JaxArray(data)
return JaxArray._fast_constructor(data)

@eqx.filter_jit
def matmul_data(self, t, y, **kwargs):
coeffs = self._coeff(t, **kwargs)
out = JaxArray(jnp.dot(jnp.dot(self.batched_data, coeffs), y._jxa))
out = JaxArray._fast_constructor(jnp.dot(jnp.dot(self.batched_data, coeffs), y._jxa))
return out

def arguments(self, args):
Expand Down
4 changes: 2 additions & 2 deletions src/qutip_jax/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def column_unstack_jaxarray(matrix, rows):
@jit
def split_columns_jaxarray(matrix):
return [
JaxArray(matrix._jxa[:, k]) for k in range(matrix.shape[1])
JaxArray._fast_constructor(matrix._jxa[:, k:k+1]) for k in range(matrix.shape[1])
]


Expand Down Expand Up @@ -119,7 +119,7 @@ def ptrace_jaxarray(matrix, dims, sel):
+ sel + [nd + q for q in sel]
)

return JaxArray(
return JaxArray._fast_constructor(
_ptrace_core(matrix._jxa, dims2, transpose_idx, dtrace, dkeep)
)

Expand Down
6 changes: 3 additions & 3 deletions src/qutip_jax/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ def neg_jaxarray(matrix):
@jit
def adjoint_jaxarray(matrix):
"""Hermitian adjoint (matrix conjugate transpose)."""
return JaxArray(matrix._jxa.T.conj())
return JaxArray._fast_constructor(matrix._jxa.T.conj())


def transpose_jaxarray(matrix):
"""Transpose of a matrix."""
return JaxArray(matrix._jxa.T)
return JaxArray._fast_constructor(matrix._jxa.T)


def conj_jaxarray(matrix):
Expand Down Expand Up @@ -79,7 +79,7 @@ def project_jaxarray(state):
out = _project_bra(state._jxa)
else:
raise ValueError("state must be a ket or a bra.")
return JaxArray(out)
return JaxArray._fast_constructor(out)


qutip.data.neg.add_specialisations(
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
key = random.PRNGKey(1234)

def _random_cplx(shape):
return qutip_jax.JaxArray(
return qutip_jax.JaxArray._fast_constructor(
random.normal(key, shape) + 1j*random.normal(key, shape)
)
13 changes: 12 additions & 1 deletion tests/test_jaxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def test_init(backend, shape, dtype):
array = backend.array(array)
jax_a = JaxArray(array)
assert isinstance(jax_a, JaxArray)
assert jax_a._jxa.dtype == jax.numpy.complex128
if len(shape) == 1:
shape = shape + (1,)
assert jax_a.shape == shape
Expand Down Expand Up @@ -93,6 +92,18 @@ def test_convert():
assert isinstance(sx.data, JaxArray)


def test_alternative_dtype():
ones = jnp.ones((3, 3))
real_array = JaxArray(ones, dtype=jnp.float64)
cplx_array = JaxArray(ones*1j, dtype=jnp.complex64)
assert (real_array * 5.)._jxa.dtype == jnp.float64
assert (cplx_array + cplx_array)._jxa.dtype == jnp.complex64

cplx_array = JaxArray(ones*1j, dtype=jnp.complex64)
real_array = JaxArray(ones, dtype=jnp.float32)
assert (cplx_array @ real_array)._jxa.dtype == jnp.complex64


def test_extract():
ones = jnp.ones((3, 3))
qobj = qutip.Qobj(ones)
Expand Down
Loading