Skip to content

Commit 3b49c64

Browse files
committed
Initial draft
1 parent ada7871 commit 3b49c64

11 files changed

+803
-0
lines changed

sparse/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import warnings
33
from enum import Enum
44

5+
from . import scheduler # noqa: F401
56
from ._version import __version__, __version_tuple__ # noqa: F401
67

78
__array_api_version__ = "2022.12"

sparse/scheduler/__init__.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from .finch_logic import (
2+
Aggregate,
3+
Alias,
4+
Deferred,
5+
Field,
6+
Immediate,
7+
MapJoin,
8+
Plan,
9+
Produces,
10+
Query,
11+
Reformat,
12+
Relabel,
13+
Reorder,
14+
Subquery,
15+
Table,
16+
)
17+
from .optimize import optimize, propagate_map_queries
18+
from .rewrite_tools import PostOrderDFS, PostWalk, PreWalk
19+
20+
__all__ = [
21+
"Aggregate",
22+
"Alias",
23+
"Deferred",
24+
"Field",
25+
"Immediate",
26+
"MapJoin",
27+
"Plan",
28+
"Produces",
29+
"Query",
30+
"Reformat",
31+
"Relabel",
32+
"Reorder",
33+
"Subquery",
34+
"Table",
35+
"optimize",
36+
"propagate_map_queries",
37+
"PostOrderDFS",
38+
"PostWalk",
39+
"PreWalk",
40+
]

sparse/scheduler/compiler.py

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
from collections.abc import Hashable
2+
from textwrap import dedent
3+
from typing import Any
4+
5+
from .finch_logic import (
6+
Alias,
7+
Deferred,
8+
Field,
9+
Immediate,
10+
LogicNode,
11+
MapJoin,
12+
Query,
13+
Reformat,
14+
Relabel,
15+
Reorder,
16+
Subquery,
17+
Table,
18+
)
19+
20+
21+
def get_or_insert(dictionary: dict[Hashable, Any], key: Hashable, default: Any) -> Any:
22+
if key in dictionary:
23+
return dictionary[key]
24+
dictionary[key] = default
25+
return default
26+
27+
28+
def get_structure(node: LogicNode, fields: dict[str, LogicNode], aliases: dict[str, LogicNode]) -> LogicNode:
29+
match node:
30+
case Field(name):
31+
return get_or_insert(fields, name, Immediate(len(fields) + len(aliases)))
32+
case Alias(name):
33+
return get_or_insert(aliases, name, Immediate(len(fields) + len(aliases)))
34+
case Subquery(Alias(name) as lhs, arg):
35+
if name in aliases:
36+
return aliases[name]
37+
return Subquery(get_structure(lhs, fields, aliases), get_structure(arg, fields, aliases))
38+
case Table(tns, idxs):
39+
return Table(Immediate(type(tns.val)), tuple(get_structure(idx, fields, aliases) for idx in idxs))
40+
case any if any.is_tree():
41+
return any.from_arguments(*[get_structure(arg, fields, aliases) for arg in any.get_arguments()])
42+
case _:
43+
return node
44+
45+
46+
class PointwiseLowerer:
47+
def __init__(self):
48+
self.bound_idxs = []
49+
50+
def __call__(self, ex):
51+
match ex:
52+
case MapJoin(Immediate(val), args):
53+
return f":({val}({','.join([self(arg) for arg in args])}))"
54+
case Reorder(Relabel(Alias(name), idxs_1), idxs_2):
55+
self.bound_idxs.append(idxs_1)
56+
return f":({name}[{','.join([idx.name if idx in idxs_2 else 1 for idx in idxs_1])}])"
57+
case Reorder(Immediate(val), _):
58+
return val
59+
case Immediate(val):
60+
return val
61+
case _:
62+
raise Exception(f"Unrecognized logic: {ex}")
63+
64+
65+
def compile_pointwise_logic(ex: LogicNode) -> tuple:
66+
ctx = PointwiseLowerer()
67+
code = ctx(ex)
68+
return (code, ctx.bound_idxs)
69+
70+
71+
def compile_logic_constant(ex: LogicNode) -> str:
72+
match ex:
73+
case Immediate(val):
74+
return val
75+
case Deferred(ex, type_):
76+
return f":({ex}::{type_})"
77+
case _:
78+
raise Exception(f"Invalid constant: {ex}")
79+
80+
81+
def intersect(x1: tuple, x2: tuple) -> tuple:
82+
return tuple(x for x in x1 if x in x2)
83+
84+
85+
def with_subsequence(x1: tuple, x2: tuple) -> tuple:
86+
res = list(x2)
87+
indices = [idx for idx, val in enumerate(x2) if val in x1]
88+
for idx, i in enumerate(indices):
89+
res[i] = x1[idx]
90+
return tuple(res)
91+
92+
93+
class LogicLowerer:
94+
def __init__(self, mode: str = "fast"):
95+
self.mode = mode
96+
97+
def __call__(self, ex: LogicNode):
98+
match ex:
99+
case Query(Alias(name), Table(tns, _)):
100+
return f":({name} = {compile_logic_constant(tns)})"
101+
102+
case Query(Alias(_) as lhs, Reformat(tns, Reorder(Relabel(Alias(_) as arg, idxs_1), idxs_2))):
103+
loop_idxs = [idx.name for idx in with_subsequence(intersect(idxs_1, idxs_2), idxs_2)]
104+
lhs_idxs = [idx.name for idx in idxs_2]
105+
(rhs, rhs_idxs) = compile_pointwise_logic(Reorder(Relabel(arg, idxs_1), idxs_2))
106+
body = f":({lhs.name}[{','.join(lhs_idxs)}] = {rhs})"
107+
for idx in loop_idxs:
108+
if Field(idx) in rhs_idxs:
109+
body = f":(for {idx} = _ \n {body} end)"
110+
elif idx in lhs_idxs:
111+
body = f":(for {idx} = 1:1 \n {body} end)"
112+
113+
result = f"""\
114+
quote
115+
{lhs.name} = {compile_logic_constant(tns)}
116+
@finch mode = {self.mode} begin
117+
{lhs.name} .= {tns.fill_value}
118+
{body}
119+
return {lhs.name}
120+
end
121+
end
122+
"""
123+
return dedent(result)
124+
125+
# TODO: ...
126+
127+
case _:
128+
raise Exception(f"Unrecognized logic: {ex}")
129+
130+
131+
class LogicCompiler:
132+
def __init__(self):
133+
self.ll = LogicLowerer()
134+
135+
def __call__(self, prgm):
136+
# prgm = format_queries(prgm, True) # noqa: F821
137+
return self.ll(prgm)

sparse/scheduler/executor.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from .compiler import LogicCompiler
2+
from .rewrite_tools import gensym
3+
4+
5+
class LogicExecutor:
6+
def __init__(self, ctx, verbose=False):
7+
self.ctx: LogicCompiler = ctx
8+
self.codes = {}
9+
self.verbose = verbose
10+
11+
def __call__(self, prgm):
12+
prgm_structure = prgm
13+
if prgm_structure not in self.codes:
14+
thunk = logic_executor_code(self.ctx, prgm)
15+
self.codes[prgm_structure] = eval(thunk), thunk
16+
17+
f, code = self.codes[prgm_structure]
18+
if self.verbose:
19+
print(code)
20+
return f(prgm)
21+
22+
23+
def logic_executor_code(ctx, prgm):
24+
# jc = JuliaContext()
25+
code = ctx(prgm)
26+
fname = gensym("compute")
27+
return f""":(function {fname}(prgm) \n {code} \n end)"""

0 commit comments

Comments
 (0)