Skip to content

Commit 219b11f

Browse files
committed
Initial draft
1 parent ada7871 commit 219b11f

11 files changed

+794
-0
lines changed

sparse/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import warnings
33
from enum import Enum
44

5+
from . import scheduler # noqa: F401
56
from ._version import __version__, __version_tuple__ # noqa: F401
67

78
__array_api_version__ = "2022.12"

sparse/scheduler/__init__.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from .finch_logic import (
2+
Aggregate,
3+
Alias,
4+
Deferred,
5+
Field,
6+
Immediate,
7+
MapJoin,
8+
Plan,
9+
Produces,
10+
Query,
11+
Reformat,
12+
Relabel,
13+
Reorder,
14+
Subquery,
15+
Table,
16+
)
17+
from .optimize import optimize, propagate_map_queries
18+
from .rewrite_tools import PostOrderDFS, PostWalk, PreWalk
19+
20+
__all__ = [
21+
"Aggregate",
22+
"Alias",
23+
"Deferred",
24+
"Field",
25+
"Immediate",
26+
"MapJoin",
27+
"Plan",
28+
"Produces",
29+
"Query",
30+
"Reformat",
31+
"Relabel",
32+
"Reorder",
33+
"Subquery",
34+
"Table",
35+
"optimize",
36+
"propagate_map_queries",
37+
"PostOrderDFS",
38+
"PostWalk",
39+
"PreWalk",
40+
]

sparse/scheduler/compiler.py

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from collections.abc import Hashable
2+
from textwrap import dedent
3+
from typing import Any
4+
5+
from .finch_logic import (
6+
Alias,
7+
Deferred,
8+
Field,
9+
Immediate,
10+
LogicNode,
11+
MapJoin,
12+
Query,
13+
Reformat,
14+
Relabel,
15+
Reorder,
16+
Subquery,
17+
Table,
18+
)
19+
20+
21+
def get_or_insert(dictionary: dict[Hashable, Any], key: Hashable, default: Any) -> Any:
22+
if key in dictionary:
23+
return dictionary[key]
24+
dictionary[key] = default
25+
return default
26+
27+
28+
def get_structure(node: LogicNode, fields: dict[str, LogicNode], aliases: dict[str, LogicNode]) -> LogicNode:
29+
match node:
30+
case Field(name) | Alias(name):
31+
return get_or_insert(fields, name, Immediate(len(fields) + len(aliases)))
32+
case Subquery(Alias(name) as lhs, arg):
33+
if name in aliases:
34+
return aliases[name]
35+
return Subquery(get_structure(lhs, fields, aliases), get_structure(arg, fields, aliases))
36+
case Table(tns, idxs):
37+
return Table(Immediate(type(tns.val)), tuple(get_structure(idx, fields, aliases) for idx in idxs))
38+
case any if any.is_tree():
39+
return any.from_arguments(*[get_structure(arg, fields, aliases) for arg in any.get_arguments()])
40+
case _:
41+
return node
42+
43+
44+
class PointwiseLowerer:
45+
def __init__(self):
46+
self.bound_idxs = []
47+
48+
def __call__(self, ex):
49+
match ex:
50+
case MapJoin(Immediate(val), args):
51+
return f":({val}({','.join([self(arg) for arg in args])}))"
52+
case Reorder(Relabel(Alias(name), idxs_1), idxs_2):
53+
self.bound_idxs.append(idxs_1)
54+
return f":({name}[{','.join([idx.name if idx in idxs_2 else 1 for idx in idxs_1])}])"
55+
case Reorder(Immediate(val), _):
56+
return val
57+
case Immediate(val):
58+
return val
59+
case _:
60+
raise Exception(f"Unrecognized logic: {ex}")
61+
62+
63+
def compile_pointwise_logic(ex: LogicNode) -> tuple:
64+
ctx = PointwiseLowerer()
65+
code = ctx(ex)
66+
return (code, ctx.bound_idxs)
67+
68+
69+
def compile_logic_constant(ex):
70+
match ex:
71+
case Immediate(val):
72+
return val
73+
case Deferred(ex, type_):
74+
return f":({ex}::{type_})"
75+
case _:
76+
raise Exception(f"Invalid constant: {ex}")
77+
78+
79+
def intersect(x1: tuple, x2: tuple) -> tuple:
80+
return tuple(x for x in x1 if x in x2)
81+
82+
83+
def with_subsequence(x1: tuple, x2: tuple) -> tuple:
84+
res = list(x2)
85+
indices = [idx for idx, val in enumerate(x2) if val in x1]
86+
for idx, i in enumerate(indices):
87+
res[i] = x1[idx]
88+
return tuple(res)
89+
90+
91+
class LogicLowerer:
92+
def __init__(self, mode: str = "fast"):
93+
self.mode = mode
94+
95+
def __call__(self, ex):
96+
match ex:
97+
case Query(Alias(name), Table(tns, _)):
98+
return f":({name} = {compile_logic_constant(tns)})"
99+
100+
case Query(Alias(_) as lhs, Reformat(tns, Reorder(Relabel(Alias(_) as arg, idxs_1), idxs_2))):
101+
loop_idxs = [idx.name for idx in with_subsequence(intersect(idxs_1, idxs_2), idxs_2)]
102+
lhs_idxs = [idx.name for idx in idxs_2]
103+
(rhs, rhs_idxs) = compile_pointwise_logic(Reorder(Relabel(arg, idxs_1), idxs_2))
104+
body = f":({lhs.name}[{','.join(lhs_idxs)}] = {rhs})"
105+
for idx in loop_idxs:
106+
if Field(idx) in rhs_idxs:
107+
body = f":(for {idx} = _ \n {body} end)"
108+
elif idx in lhs_idxs:
109+
body = f":(for {idx} = 1:1 \n {body} end)"
110+
111+
result = f"""\
112+
quote
113+
{lhs.name} = {compile_logic_constant(tns)}
114+
@finch mode = {self.mode} begin
115+
{lhs.name} .= {tns.fill_value}
116+
{body}
117+
return {lhs.name}
118+
end
119+
end
120+
"""
121+
return dedent(result)
122+
123+
# TODO: ...
124+
125+
case _:
126+
raise Exception(f"Unrecognized logic: {ex}")
127+
128+
129+
class LogicCompiler:
130+
def __call__(self, prgm):
131+
prgm = format_queries(prgm, True) # noqa: F821
132+
return LogicLowerer()(prgm)

sparse/scheduler/executor.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from .compiler import LogicCompiler
2+
3+
4+
class LogicExecutor:
5+
def __init__(self, ctx, verbose=False):
6+
self.ctx: LogicCompiler = ctx
7+
self.codes = {}
8+
self.verbose = verbose
9+
10+
def __call__(self, prgm):
11+
prgm_structure = prgm
12+
if prgm_structure not in self.codes:
13+
thunk = logic_executor_code(self.ctx, prgm)
14+
self.codes[prgm_structure] = eval(thunk), thunk
15+
16+
f, code = self.codes[prgm_structure]
17+
if self.verbose:
18+
print(code)
19+
return f(prgm)
20+
21+
22+
def logic_executor_code(ctx, prgm):
23+
pass

0 commit comments

Comments
 (0)