Modify level construction to allow querying #19


Draft: wants to merge 5 commits into base: main
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
@@ -1,4 +1,4 @@
-name: Publish
+name: Publish
 on:
   workflow_dispatch:
 jobs:
1 change: 1 addition & 0 deletions .gitignore
@@ -50,6 +50,7 @@ coverage.xml
 .hypothesis/
 .pytest_cache/
 cover/
+junit/
 
 # Translations
 *.mo
17 changes: 2 additions & 15 deletions src/finch/__init__.py
@@ -1,16 +1,3 @@
-from .levels import (
-    Dense,
-    Element,
-    Pattern,
-    SparseList,
-    SparseByteMap,
-    RepeatRLE,
-    SparseVBL,
-    SparseCOO,
-    SparseHash,
-    Storage,
-    DenseStorage,
-)
 from .tensor import (
     Tensor,
     asarray,
@@ -34,7 +21,7 @@
     compute,
 )
 from .dtypes import (
-    int_,
+    int,
     int8,
     int16,
     int32,
@@ -71,7 +58,7 @@
     "tensordot",
     "matmul",
     "permute_dims",
-    "int_",
+    "int",
     "int8",
     "int16",
     "int32",
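
The public alias for Julia's Int thus changes from int_ to int. A quick sketch of the downstream effect (assuming the package imports as finch):

import finch

finch.int    # was finch.int_; bound to Julia's Int, shadows the builtin only inside finch's namespace
finch.int8   # unchanged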
33 changes: 17 additions & 16 deletions src/finch/dtypes.py
@@ -1,19 +1,20 @@
 from .julia import jl
+from .typing import DType
 
 
-int_: jl.DataType = jl.Int
-int8: jl.DataType = jl.Int8
-int16: jl.DataType = jl.Int16
-int32: jl.DataType = jl.Int32
-int64: jl.DataType = jl.Int64
-uint: jl.DataType = jl.UInt
-uint8: jl.DataType = jl.UInt8
-uint16: jl.DataType = jl.UInt16
-uint32: jl.DataType = jl.UInt32
-uint64: jl.DataType = jl.UInt64
-float16: jl.DataType = jl.Float16
-float32: jl.DataType = jl.Float32
-float64: jl.DataType = jl.Float64
-complex64: jl.DataType = jl.ComplexF32
-complex128: jl.DataType = jl.ComplexF64
-bool: jl.DataType = jl.Bool
+int: DType = jl.Int
+int8: DType = jl.Int8
+int16: DType = jl.Int16
+int32: DType = jl.Int32
+int64: DType = jl.Int64
+uint: DType = jl.UInt
+uint8: DType = jl.UInt8
+uint16: DType = jl.UInt16
+uint32: DType = jl.UInt32
+uint64: DType = jl.UInt64
+float16: DType = jl.Float16
+float32: DType = jl.Float32
+float64: DType = jl.Float64
+complex64: DType = jl.ComplexF32
+complex128: DType = jl.ComplexF64
+bool: DType = jl.Bool
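
The annotations switch from jl.DataType to the DType alias from .typing. Each binding is still a Julia type object, so it remains usable as a converter, which Pattern._construct in levels.py relies on via dtype(fill_value). A minimal sketch, assuming a working juliacall setup behind finch.julia.jl:

from finch import dtypes

# Julia types are callable as converters: Bool(0) -> false, Float64(1) -> 1.0.
assert dtypes.bool(0) == dtypes.bool(False)
assert dtypes.float64(1) == 1.0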
108 changes: 108 additions & 0 deletions src/finch/formats.py
@@ -0,0 +1,108 @@
+import abc
+import typing
+
+import numpy as np
+
+from .julia import jl
+from . import levels as levels_module
+from . import utils
+from .typing import DType
+from dataclasses import dataclass
+
+
+@dataclass
+class Format:
+    levels: tuple[levels_module.AbstractLevel, ...]
+    order: tuple[int, ...]
+    leaf: levels_module.AbstractLeafLevel
+
+    def __post_init__(self) -> None:
+        utils.check_valid_order(self.order, ndim=len(self.levels))
+
+    @property
+    def ndim(self) -> int:
+        return len(self.order)
+
+    def _construct(self, *, fill_value, dtype: DType, data: np.ndarray | None = None):
+        if data is not None:
+            # Axes sorted by increasing stride, counted from the last axis.
+            data_order = tuple(
+                x[0]
+                for x in sorted(enumerate(reversed(data.strides)), key=lambda x: x[1])
+            )
+            data_inv_order = utils.get_inverse_order(data_order)
+            data_raw = data.transpose(data_inv_order)
+            if not data.flags.f_contiguous:
+                data_raw = data_raw.copy(order="F")
+        else:
+            # No input data: assume a C-ordered layout for the swizzle below.
+            data_order = tuple(range(self.ndim))
+
+        out_level = self.leaf._construct(dtype=dtype, fill_value=fill_value)
+        for level in reversed(self.levels):
+            out_level = level._construct(inner_level=out_level)
+
+        reversed_order = tuple(reversed(self.order))
+        swizzle_args = map(lambda x: reversed_order[x] + 1, data_order)
+        if data is None:
+            return jl.swizzle(jl.Tensor(out_level), *swizzle_args)
+
+        return jl.swizzle(jl.Tensor(out_level, data_raw), *swizzle_args)
+
+
+CSR = Format(
+    levels=(levels_module.Dense(), levels_module.SparseList()),
+    order=(0, 1),
+    leaf=levels_module.Element(),
+)
+CSC = Format(
+    levels=(levels_module.Dense(), levels_module.SparseList()),
+    order=(1, 0),
+    leaf=levels_module.Element(),
+)
+
+
+class FlexibleFormat(abc.ABC):
+    def _construct(self, *, ndim: int, fill_value, dtype: DType, data=None):
+        return self._get_format(ndim)._construct(
+            fill_value=fill_value, dtype=dtype, data=data
+        )
+
+    @abc.abstractmethod
+    def _get_format(self, ndim: int, /) -> Format:
+        pass
+
+
+@dataclass
+class Dense(FlexibleFormat):
+    order: typing.Literal["C", "F"] | tuple[int, ...] = "C"
+    shape: tuple[int | None, ...] | None = None
+
+    def __post_init__(self) -> None:
+        if isinstance(self.order, tuple):
+            utils.check_valid_order(self.order)
+            if self.shape is not None and len(self.order) != len(self.shape):
+                raise ValueError(
+                    f"len(self.order) != len(self.shape), {self.order}, {self.shape}"
+                )
+
+    def _get_format(self, ndim: int) -> Format:
+        match self.order:
+            case "C":
+                order = tuple(range(ndim))
+            case "F":
+                order = tuple(reversed(range(ndim)))
+            case _:
+                order = self.order
+
+        utils.check_valid_order(order, ndim=ndim)
+
+        shape = self.shape
+        if shape is None:
+            shape = (None,) * ndim
+
+        if len(shape) != ndim:
+            raise ValueError(f"len(self.shape) != ndim, {shape=}, {ndim=}")
+
+        topological_shape = utils.get_topological_shape(shape, order=order)
+        lvls = tuple(levels_module.Dense(dim=dim) for dim in topological_shape)
+
+        return Format(levels=lvls, order=order, leaf=levels_module.Element())
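
Since the PR ships no caller for the new API, here is a hypothetical end-to-end sketch of how Format and FlexibleFormat compose; the call pattern, including passing finch dtypes for dtype, is an assumption based on the signatures above:

import numpy as np

from finch import dtypes
from finch.formats import CSR, Dense

data = np.array([[0.0, 1.0], [2.0, 0.0]])

# Fixed 2-D format: Dense rows over SparseList columns, i.e. CSR.
csr = CSR._construct(fill_value=0.0, dtype=dtypes.float64, data=data)

# FlexibleFormat defers rank: Dense resolves to a concrete Format once ndim is known.
fmt = Dense(order="F")._get_format(data.ndim)
col_major = fmt._construct(fill_value=0.0, dtype=dtypes.float64, data=data)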
112 changes: 41 additions & 71 deletions src/finch/levels.py
@@ -1,93 +1,63 @@
-import numpy as np
+import abc
 
-from .julia import jl
-from .typing import OrderType, DType
-
+from .julia import jl
+from . import dtypes
+from dataclasses import dataclass
+from .typing import DType
 
-class _Display:
-    def __repr__(self):
-        return jl.sprint(jl.show, self._obj)
 
-    def __str__(self):
-        return jl.sprint(jl.show, jl.MIME("text/plain"), self._obj)
+class AbstractLeafLevel(abc.ABC):
+    @abc.abstractmethod
+    def _construct(self, *, dtype, fill_value):
+        ...
 
 
 # LEVEL
 
-class AbstractLevel(_Display):
-    pass
+class AbstractLevel(abc.ABC):
+    @abc.abstractmethod
+    def _construct(self, *, inner_level):
+        ...
 
 
 # core levels
 
+@dataclass
 class Dense(AbstractLevel):
-    def __init__(self, lvl, shape=None):
-        args = [lvl._obj]
-        if shape is not None:
-            args.append(shape)
-        self._obj = jl.Dense(*args)
-
-
-class Element(AbstractLevel):
-    def __init__(self, fill_value, data=None):
-        args = [fill_value]
-        if data is not None:
-            args.append(data)
-        self._obj = jl.Element(*args)
+    dim: int | None = None
+    index_type: DType = dtypes.int64
+
+    def _construct(self, *, inner_level) -> jl.Dense:
+        if self.dim is None:
+            return jl.Dense[self.index_type](inner_level)
 
-class Pattern(AbstractLevel):
-    def __init__(self):
-        self._obj = jl.Pattern()
-
-
-# advanced levels
-
-class SparseList(AbstractLevel):
-    def __init__(self, lvl):
-        self._obj = jl.SparseList(lvl._obj)
+        return jl.Dense[self.index_type](inner_level, self.dim)
 
 
-class SparseByteMap(AbstractLevel):
-    def __init__(self, lvl):
-        self._obj = jl.SparseByteMap(lvl._obj)
+@dataclass
+class Element(AbstractLeafLevel):
+    def _construct(self, *, dtype: DType, fill_value) -> jl.Element:
+        return jl.Element[fill_value, dtype]()
 
 
-class RepeatRLE(AbstractLevel):
-    def __init__(self, lvl):
-        self._obj = jl.RepeatRLE(lvl._obj)
+@dataclass
+class Pattern(AbstractLeafLevel):
+    def _construct(self, *, dtype, fill_value) -> jl.Pattern:
+        from .dtypes import bool
 
+        if dtype != bool:
+            raise TypeError("`Pattern` can only have `dtype=bool`.")
+        if dtype(fill_value) != dtype(False):
+            raise TypeError("`Pattern` can only have `fill_value=False`.")
 
-class SparseVBL(AbstractLevel):
-    def __init__(self, lvl):
-        self._obj = jl.SparseVBL(lvl._obj)
+        return jl.Pattern()
 
 
-class SparseCOO(AbstractLevel):
-    def __init__(self, ndim, lvl):
-        self._obj = jl.SparseCOO[ndim](lvl._obj)
-
-
-class SparseHash(AbstractLevel):
-    def __init__(self, ndim, lvl):
-        self._obj = jl.SparseHash[ndim](lvl._obj)
-
-
-# STORAGE
-
-class Storage:
-    def __init__(self, levels_descr: AbstractLevel, order: OrderType = None):
-        self.levels_descr = levels_descr
-        self.order = order if order is not None else "C"
-
-    def __str__(self) -> str:
-        return f"Storage(lvl={str(self.levels_descr)}, order={self.order})"
-
-
-class DenseStorage(Storage):
-    def __init__(self, ndim: int, dtype: DType, order: OrderType = None):
-        lvl = Element(dtype(0))
-        for _ in range(ndim):
-            lvl = Dense(lvl)
+# advanced levels
+@dataclass
+class SparseList(AbstractLevel):
+    index_type: DType = dtypes.int64
+    pos_type: DType = dtypes.uint64
+    crd_type: DType = dtypes.uint64
 
-        super().__init__(levels_descr=lvl, order=order)
+    def _construct(self, *, inner_level) -> jl.SparseList:
+        return jl.SparseList[self.index_type, self.pos_type, self.crd_type](inner_level)
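
Levels now describe structure only and are composed innermost-first by Format._construct. A minimal sketch of that composition (hypothetical direct caller; the PR only exercises _construct through formats.py):

from finch import dtypes
from finch.levels import Dense, Element, SparseList

# Innermost first: an Element leaf, wrapped in SparseList, wrapped in Dense.
leaf = Element()._construct(dtype=dtypes.float64, fill_value=0.0)
lvl = SparseList()._construct(inner_level=leaf)
lvl = Dense(dim=3)._construct(inner_level=lvl)  # CSR-like two-level structure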