
not working with jax (and cupy) #52

@pedrozudo

Description


Hi,

I was playing around a little with your great library and found that it does not work with JAX, nor with CuPy, even though the README says it should. I'm wondering whether you have dropped support for those libraries or whether something is going wrong. A minimal example to reproduce the error is basically the README example, but with jax instead of torch.

from cholespy import CholeskySolverF, MatrixType
import jax.numpy as jnp

# Identity matrix
n_rows = 20
rows = jnp.arange(n_rows)
cols = jnp.arange(n_rows)
data = jnp.ones(n_rows)

solver = CholeskySolverF(n_rows, rows, cols, data, MatrixType.COO)

b = jnp.ones(n_rows)
x = jnp.zeros_like(b)

solver.solve(b, x)
# expected: x = [1, ..., 1]

And here is the error:

solver = CholeskySolverF(n_rows, rows, cols, data, MatrixType.COO)
TypeError: __init__(): incompatible function arguments. The following argument types are supported: 1. __init__(self, n_rows: int, ii: ndarray[dtype=int32, shape=(*), order='C'], jj: ndarray[dtype=int32, shape=(*), order='C'], x: ndarray[dtype=float64, shape=(*), order='C'], type: cholespy.MatrixType) -> None

Invoked with types: cholespy.CholeskySolverF, int, jaxlib.xla_extension.ArrayImpl, jaxlib.xla_extension.ArrayImpl, jaxlib.xla_extension.ArrayImpl, cholespy.MatrixType
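The error message spells out what the binding accepts: 1-D, C-ordered arrays with `int32` row/column indices and `float64` values. Until the bindings accept JAX arrays directly, one possible workaround is to convert the inputs to NumPy arrays of exactly those types before constructing the solver. The sketch below assumes the NumPy (CPU) path of cholespy works as documented; `to_cholespy_coo` is a hypothetical helper, not part of the library, and it has not been tested against cholespy itself.

```python
import numpy as np


def _as_1d_contiguous(a, dtype):
    # np.ascontiguousarray goes through __array__, which JAX arrays implement
    # (device data is copied to host memory along the way).
    arr = np.ascontiguousarray(a, dtype=dtype)
    if arr.ndim != 1:
        raise ValueError(f"expected a 1-D array, got shape {arr.shape}")
    return arr


def to_cholespy_coo(rows, cols, data, data_dtype=np.float64):
    """Hypothetical helper: coerce COO triplets to the argument types listed
    in the TypeError (int32 indices, float64 values, shape (*), order 'C').
    """
    ii = _as_1d_contiguous(rows, np.int32)
    jj = _as_1d_contiguous(cols, np.int32)
    x = _as_1d_contiguous(data, data_dtype)
    if not (ii.shape == jj.shape == x.shape):
        raise ValueError("rows, cols and data must have the same length")
    return ii, jj, x
```

With the repro above this would be used as `CholeskySolverF(n_rows, *to_cholespy_coo(rows, cols, data), MatrixType.COO)`. The copy to host memory gives up any GPU residency, so this is a stopgap rather than a fix for the missing JAX support.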

A similar error also happens with cupy.
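Converting the constructor inputs is only half of the README example: it passes a preallocated `x` to `solver.solve(b, x)`, which suggests the solution is written into `x` in place. JAX arrays are immutable, so a JAX buffer cannot serve as that output even if the types matched. A minimal sketch, assuming the in-place semantics above: run the solve on host NumPy buffers and return the result as a new array. `solve_functional` is a hypothetical wrapper, and the fake solver in the usage is only a stand-in for `solver.solve`. For CuPy, note that CuPy deliberately refuses implicit conversion to NumPy, so an explicit `b.get()` (or `cupy.asnumpy(b)`) would be needed before calling it.

```python
import numpy as np


def solve_functional(solve_inplace, b, dtype=None):
    """Hypothetical wrapper around an in-place solve(b, x) call.

    `solve_inplace` stands in for cholespy's `solver.solve` and is assumed to
    write the solution into its second argument. Because JAX arrays are
    immutable, the output buffer is a NumPy array and the result is returned.
    """
    b_host = np.ascontiguousarray(b, dtype=dtype)  # host copy of the right-hand side
    x_host = np.zeros_like(b_host)                  # writable, C-contiguous output buffer
    solve_inplace(b_host, x_host)
    return x_host
```

With JAX this might look like `x = jnp.asarray(solve_functional(solver.solve, b))`, which hands back a fresh JAX array instead of mutating one.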

Metadata


Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
