Skip to content

Commit 97e7f33

Browse files
committed
WIP: Modify level construction.
1 parent 198b934 commit 97e7f33

File tree

8 files changed

+296
-282
lines changed

8 files changed

+296
-282
lines changed

.github/workflows/publish.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Publish
1+
name: Publish
22
on:
33
workflow_dispatch:
44
jobs:

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ coverage.xml
5050
.hypothesis/
5151
.pytest_cache/
5252
cover/
53+
junit/
5354

5455
# Translations
5556
*.mo

poetry.lock

+157-160
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/finch/__init__.py

-13
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,3 @@
1-
from .levels import (
2-
Dense,
3-
Element,
4-
Pattern,
5-
SparseList,
6-
SparseByteMap,
7-
RepeatRLE,
8-
SparseVBL,
9-
SparseCOO,
10-
SparseHash,
11-
Storage,
12-
DenseStorage,
13-
)
141
from .tensor import (
152
Tensor,
163
astype,

src/finch/levels.py

+43-66
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,70 @@
1-
import numpy as np
1+
import abc
2+
23

34
from .julia import jl
4-
from .typing import OrderType
5+
from . import dtypes
6+
from dataclasses import dataclass
57

68

7-
class _Display:
9+
class _Display(abc.ABC):
810
def __repr__(self):
911
return jl.sprint(jl.show, self._obj)
1012

1113
def __str__(self):
1214
return jl.sprint(jl.show, jl.MIME("text/plain"), self._obj)
1315

1416

15-
# LEVEL
17+
class AbstractLeafLevel(abc.ABC):
18+
@abc.abstractmethod
19+
def _construct(self, *, dtype, fill_value):
20+
...
21+
1622

17-
class AbstractLevel(_Display):
18-
pass
23+
# LEVEL
24+
class AbstractLevel(abc.ABC):
25+
@abc.abstractmethod
26+
def _construct(self, *, inner_level):
27+
...
1928

2029

2130
# core levels
22-
31+
@dataclass
2332
class Dense(AbstractLevel):
24-
def __init__(self, lvl, shape=None):
25-
args = [lvl._obj]
26-
if shape is not None:
27-
args.append(shape)
28-
self._obj = jl.Dense(*args)
29-
30-
31-
class Element(AbstractLevel):
32-
def __init__(self, fill_value, data=None):
33-
args = [fill_value]
34-
if data is not None:
35-
args.append(data)
36-
self._obj = jl.Element(*args)
37-
38-
39-
class Pattern(AbstractLevel):
40-
def __init__(self):
41-
self._obj = jl.Pattern()
42-
43-
44-
# advanced levels
45-
46-
class SparseList(AbstractLevel):
47-
def __init__(self, lvl):
48-
self._obj = jl.SparseList(lvl._obj)
49-
50-
51-
class SparseByteMap(AbstractLevel):
52-
def __init__(self, lvl):
53-
self._obj = jl.SparseByteMap(lvl._obj)
54-
33+
dim: int | None = None
34+
index_type: jl.DataType = dtypes.int64
5535

56-
class RepeatRLE(AbstractLevel):
57-
def __init__(self, lvl):
58-
self._obj = jl.RepeatRLE(lvl._obj)
36+
def _construct(self, *, inner_level) -> jl.Dense:
37+
if self.dim is None:
38+
return jl.Dense[self.index_type](inner_level)
5939

40+
return jl.Dense[self.index_type](inner_level, self.dim)
6041

61-
class SparseVBL(AbstractLevel):
62-
def __init__(self, lvl):
63-
self._obj = jl.SparseVBL(lvl._obj)
6442

43+
@dataclass
44+
class Element(AbstractLeafLevel):
45+
def _construct(self, *, dtype: jl.DataType, fill_value) -> jl.Element:
46+
return jl.Element[fill_value, dtype]()
6547

66-
class SparseCOO(AbstractLevel):
67-
def __init__(self, ndim, lvl):
68-
self._obj = jl.SparseCOO[ndim](lvl._obj)
6948

49+
@dataclass
50+
class Pattern(AbstractLeafLevel):
51+
def _construct(self, *, dtype, fill_value) -> jl.Pattern:
52+
from .dtypes import bool
7053

71-
class SparseHash(AbstractLevel):
72-
def __init__(self, ndim, lvl):
73-
self._obj = jl.SparseHash[ndim](lvl._obj)
54+
if dtype != bool:
55+
raise TypeError("`Pattern` can only have `dtype=bool`.")
56+
if dtype(fill_value) != dtype(False):
57+
raise TypeError("`Pattern` can only have `fill_value=False`.")
7458

59+
return jl.Pattern()
7560

76-
# STORAGE
7761

78-
class Storage:
79-
def __init__(self, levels_descr: AbstractLevel, order: OrderType = None):
80-
self.levels_descr = levels_descr
81-
self.order = order if order is not None else "C"
82-
83-
def __str__(self) -> str:
84-
return f"Storage(lvl={str(self.levels_descr)}, order={self.order})"
85-
86-
87-
class DenseStorage(Storage):
88-
def __init__(self, ndim: int, dtype: np.dtype, order: OrderType = None):
89-
lvl = Element(np.int_(0).astype(dtype))
90-
for _ in range(ndim):
91-
lvl = Dense(lvl)
62+
# advanced levels
63+
@dataclass
64+
class SparseList(AbstractLevel):
65+
index_type: jl.DataType = dtypes.int64
66+
pos_type: jl.DataType = dtypes.uint64
67+
crd_type: jl.DataType = dtypes.uint64
9268

93-
super().__init__(levels_descr=lvl, order=order)
69+
def _construct(self, *, inner_level) -> jl.SparseList:
70+
return jl.SparseList[self.index_type, self.pos_type, self.crd_type](inner_level)

src/finch/tensor.py

+48-27
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from typing import Any, Optional, Union
1+
from typing import Union
22

33
import numpy as np
44
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
55

66
from .julia import jl
7-
from .levels import _Display, Dense, Element, Storage
7+
from .levels import _Display, Dense, Element
88
from .typing import OrderType, JuliaObj, spmatrix, TupleOf3Arrays
99

1010

@@ -57,12 +57,13 @@ class Tensor(_Display):
5757
array([[0, 1, 2],
5858
[3, 4, 5]])
5959
"""
60+
6061
row_major = "C"
6162
column_major = "F"
6263

6364
def __init__(
6465
self,
65-
obj: Union[np.ndarray, spmatrix, Storage, JuliaObj],
66+
obj: Union[np.ndarray, spmatrix, JuliaObj],
6667
/,
6768
*,
6869
fill_value: np.number = 0.0,
@@ -74,7 +75,9 @@ def __init__(
7475
jl_data = self._from_numpy(obj, fill_value=fill_value)
7576
self._obj = jl_data
7677
elif isinstance(obj, Storage): # from-storage constructor
77-
order = self.preprocess_order(obj.order, self.get_lvl_ndim(obj.levels_descr._obj))
78+
order = self.preprocess_order(
79+
obj.order, self.get_lvl_ndim(obj.levels_descr._obj)
80+
)
7881
self._obj = jl.swizzle(jl.Tensor(obj.levels_descr._obj), *order)
7982
elif jl.isa(obj, jl.Finch.SwizzleArray): # raw-Julia-object constructor
8083
self._obj = obj
@@ -143,25 +146,22 @@ def _order(self) -> tuple[int, ...]:
143146
return jl.typeof(self._obj).parameters[1]
144147

145148
@classmethod
146-
def preprocess_order(
147-
cls, order: OrderType, ndim: int
148-
) -> tuple[int, ...]:
149+
def preprocess_order(cls, order: OrderType, ndim: int) -> tuple[int, ...]:
149150
if order == cls.column_major:
150151
permutation = tuple(range(1, ndim + 1))
151152
elif order == cls.row_major or order is None:
152153
permutation = tuple(range(1, ndim + 1)[::-1])
153154
elif isinstance(order, tuple):
154155
if min(order) == 0:
155156
order = tuple(i + 1 for i in order)
156-
if (
157-
len(order) == ndim and
158-
all([i in order for i in range(1, ndim + 1)])
159-
):
157+
if len(order) == ndim and all([i in order for i in range(1, ndim + 1)]):
160158
permutation = order
161159
else:
162160
raise ValueError(f"Custom order is not a permutation: {order}.")
163161
else:
164-
raise ValueError(f"order must be 'C', 'F' or a tuple, but is: {type(order)}.")
162+
raise ValueError(
163+
f"order must be 'C', 'F' or a tuple, but is: {type(order)}."
164+
)
165165

166166
return permutation
167167

@@ -214,11 +214,11 @@ def permute_dims(self, axes: tuple[int, ...]) -> "Tensor":
214214
new_tensor = Tensor(new_obj)
215215
return new_tensor
216216

217-
def to_device(self, device: Storage) -> "Tensor":
217+
def to_device(self, device) -> "Tensor":
218218
return Tensor(self._from_other_tensor(self, storage=device))
219219

220220
@classmethod
221-
def _from_other_tensor(cls, tensor: "Tensor", storage: Optional[Storage]) -> JuliaObj:
221+
def _from_other_tensor(cls, tensor: "Tensor", storage) -> JuliaObj:
222222
order = cls.preprocess_order(storage.order, tensor.ndim)
223223
return jl.swizzle(
224224
jl.Tensor(storage.levels_descr._obj, tensor._obj.body), *order
@@ -239,7 +239,10 @@ def _from_numpy(cls, arr: np.ndarray, fill_value: np.number) -> JuliaObj:
239239
def _from_scipy_sparse(cls, x) -> JuliaObj:
240240
if x.format == "coo":
241241
return cls.construct_coo_jl_object(
242-
coords=(x.col, x.row), data=x.data, shape=x.shape[::-1], order=Tensor.row_major
242+
coords=(x.col, x.row),
243+
data=x.data,
244+
shape=x.shape[::-1],
245+
order=Tensor.row_major,
243246
)
244247
elif x.format == "csc":
245248
return cls.construct_csc_jl_object(
@@ -255,7 +258,9 @@ def _from_scipy_sparse(cls, x) -> JuliaObj:
255258
raise ValueError(f"Unsupported SciPy format: {type(x)}")
256259

257260
@classmethod
258-
def construct_coo_jl_object(cls, coords, data, shape, order, fill_value=0.0) -> JuliaObj:
261+
def construct_coo_jl_object(
262+
cls, coords, data, shape, order, fill_value=0.0
263+
) -> JuliaObj:
259264
assert len(coords) == 2
260265
ndim = len(shape)
261266
order = cls.preprocess_order(order, ndim)
@@ -264,12 +269,18 @@ def construct_coo_jl_object(cls, coords, data, shape, order, fill_value=0.0) ->
264269
ptr = jl.Vector[jl.Int]([1, len(data) + 1])
265270
tbl = tuple(jl.PlusOneVector(arr) for arr in coords)
266271

267-
jl_data = jl.swizzle(jl.Tensor(jl.SparseCOO[ndim](lvl, shape, ptr, tbl)), *order)
272+
jl_data = jl.swizzle(
273+
jl.Tensor(jl.SparseCOO[ndim](lvl, shape, ptr, tbl)), *order
274+
)
268275
return jl_data
269276

270277
@classmethod
271-
def construct_coo(cls, coords, data, shape, order=row_major, fill_value=0.0) -> "Tensor":
272-
return Tensor(cls.construct_coo_jl_object(coords, data, shape, order, fill_value))
278+
def construct_coo(
279+
cls, coords, data, shape, order=row_major, fill_value=0.0
280+
) -> "Tensor":
281+
return Tensor(
282+
cls.construct_coo_jl_object(coords, data, shape, order, fill_value)
283+
)
273284

274285
@staticmethod
275286
def _construct_compressed2d_jl_object(
@@ -288,22 +299,27 @@ def _construct_compressed2d_jl_object(
288299

289300
lvl = jl.Element(dtype(fill_value), data)
290301
jl_data = jl.swizzle(
291-
jl.Tensor(jl.Dense(jl.SparseList(lvl, shape[0], indptr, indices), shape[1])), *order
302+
jl.Tensor(
303+
jl.Dense(jl.SparseList(lvl, shape[0], indptr, indices), shape[1])
304+
),
305+
*order,
292306
)
293307
return jl_data
294308

295309
@classmethod
296-
def construct_csc_jl_object(cls, arg: TupleOf3Arrays, shape: tuple[int, ...]) -> JuliaObj:
297-
return cls._construct_compressed2d_jl_object(
298-
arg=arg, shape=shape, order=(1, 2)
299-
)
310+
def construct_csc_jl_object(
311+
cls, arg: TupleOf3Arrays, shape: tuple[int, ...]
312+
) -> JuliaObj:
313+
return cls._construct_compressed2d_jl_object(arg=arg, shape=shape, order=(1, 2))
300314

301315
@classmethod
302316
def construct_csc(cls, arg: TupleOf3Arrays, shape: tuple[int, ...]) -> "Tensor":
303317
return Tensor(cls.construct_csc_jl_object(arg, shape))
304318

305319
@classmethod
306-
def construct_csr_jl_object(cls, arg: TupleOf3Arrays, shape: tuple[int, ...]) -> JuliaObj:
320+
def construct_csr_jl_object(
321+
cls, arg: TupleOf3Arrays, shape: tuple[int, ...]
322+
) -> JuliaObj:
307323
return cls._construct_compressed2d_jl_object(
308324
arg=arg, shape=shape[::-1], order=(2, 1)
309325
)
@@ -331,7 +347,9 @@ def construct_csf_jl_object(
331347
for size, indices, indptr in zip(shape[:-1], indices_list, indptr_list):
332348
lvl = jl.SparseList(lvl, size, indptr, indices)
333349

334-
jl_data = jl.swizzle(jl.Tensor(jl.Dense(lvl, shape[-1])), *range(1, len(shape) + 1))
350+
jl_data = jl.swizzle(
351+
jl.Tensor(jl.Dense(lvl, shape[-1])), *range(1, len(shape) + 1)
352+
)
335353
return jl_data
336354

337355
@classmethod
@@ -377,7 +395,9 @@ def _slice_plus_one(s: slice, size: int) -> range:
377395

378396
if s.stop is not None:
379397
stop_offset = 2 if step < 0 else 0
380-
stop = normalize_axis_index(s.stop, size) + stop_offset if s.stop < size else size
398+
stop = (
399+
normalize_axis_index(s.stop, size) + stop_offset if s.stop < size else size
400+
)
381401
else:
382402
stop = stop_default
383403

@@ -429,6 +449,7 @@ def _expand_ellipsis(key: tuple, shape: tuple[int, ...]) -> tuple:
429449
key = new_key
430450
return key
431451

452+
432453
def _add_missing_dims(key: tuple, shape: tuple[int, ...]) -> tuple:
433454
for i in range(len(key), len(shape)):
434455
key = key + (jl.range(start=1, stop=shape[i]),)

0 commit comments

Comments
 (0)