Skip to content

Commit 3e5894a

Browse files
committed
Initial draft
1 parent ada7871 commit 3e5894a

11 files changed

+764
-0
lines changed

sparse/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import warnings
33
from enum import Enum
44

5+
from . import scheduler # noqa: F401
56
from ._version import __version__, __version_tuple__ # noqa: F401
67

78
__array_api_version__ = "2022.12"

sparse/scheduler/__init__.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from .finch_logic import (
2+
Aggregate,
3+
Alias,
4+
Deferred,
5+
Field,
6+
Immediate,
7+
MapJoin,
8+
Plan,
9+
Produces,
10+
Query,
11+
Reformat,
12+
Relabel,
13+
Reorder,
14+
Subquery,
15+
Table,
16+
)
17+
from .optimize import optimize, propagate_map_queries
18+
from .rewrite_tools import PostOrderDFS, PostWalk, PreWalk
19+
20+
__all__ = [
21+
"Aggregate",
22+
"Alias",
23+
"Deferred",
24+
"Field",
25+
"Immediate",
26+
"MapJoin",
27+
"Plan",
28+
"Produces",
29+
"Query",
30+
"Reformat",
31+
"Relabel",
32+
"Reorder",
33+
"Subquery",
34+
"Table",
35+
"optimize",
36+
"propagate_map_queries",
37+
"PostOrderDFS",
38+
"PostWalk",
39+
"PreWalk",
40+
]

sparse/scheduler/compiler.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from textwrap import dedent
2+
3+
from .finch_logic import Alias, Deferred, Field, Immediate, LogicNode, MapJoin, Query, Reformat, Relabel, Reorder, Table
4+
5+
6+
class PointwiseLowerer:
7+
def __init__(self):
8+
self.bound_idxs = []
9+
10+
def __call__(self, ex):
11+
match ex:
12+
case MapJoin(op, args) if isinstance(op, Immediate):
13+
return f":({op.val}({','.join([self(arg) for arg in args])}))"
14+
case Reorder(Relabel(arg, idxs_1), idxs_2) if isinstance(arg, Alias):
15+
self.bound_idxs.append(idxs_1)
16+
return f":({arg.name}[{','.join([idx.name if idx in idxs_2 else 1 for idx in idxs_1])}])"
17+
case Reorder(arg, _) if isinstance(arg, Immediate):
18+
return arg.val
19+
case Immediate(val):
20+
return val
21+
case _:
22+
raise Exception(f"Unrecognized logic: {ex}")
23+
24+
25+
def compile_pointwise_logic(ex: LogicNode) -> tuple:
26+
ctx = PointwiseLowerer()
27+
code = ctx(ex)
28+
return (code, ctx.bound_idxs)
29+
30+
31+
def compile_logic_constant(ex):
32+
match ex:
33+
case Immediate(val):
34+
return val
35+
case Deferred(ex, type_):
36+
return f":({ex}::{type_})"
37+
case _:
38+
raise Exception(f"Invalid constant: {ex}")
39+
40+
41+
class LogicLowerer:
42+
def __init__(self, mode: str = "fast"):
43+
self.mode = mode
44+
45+
def __call__(self, ex):
46+
match ex:
47+
case Query(lhs, Table(tns, _)) if isinstance(lhs, Alias):
48+
return f":({lhs.name} = {compile_logic_constant(tns)})"
49+
50+
case Query(lhs, Reformat(tns, Reorder(Relabel(arg, idxs_1), idxs_2))) if isinstance(
51+
lhs, Alias
52+
) and isinstance(arg, Alias):
53+
loop_idxs = [idx.name for idx in withsubsequence(intersect(idxs_1, idxs_2), idxs_2)] # noqa: F821
54+
lhs_idxs = [idx.name for idx in idxs_2]
55+
(rhs, rhs_idxs) = compile_pointwise_logic(Reorder(Relabel(arg, idxs_1), idxs_2))
56+
body = f":({lhs.name}[{','.join(lhs_idxs)}] = {rhs})"
57+
for idx in loop_idxs:
58+
if Field(idx) in rhs_idxs:
59+
body = f":(for {idx} = _ \n {body} end)"
60+
elif idx in lhs_idxs:
61+
body = f":(for {idx} = 1:1 \n {body} end)"
62+
63+
result = f"""\
64+
quote
65+
{lhs.name} = {compile_logic_constant(tns)}
66+
@finch mode = {self.mode} begin
67+
{lhs.name} .= {tns.fill_value}
68+
{body}
69+
return {lhs.name}
70+
end
71+
end
72+
"""
73+
return dedent(result)
74+
75+
76+
class LogicCompiler:
77+
def __call__(self, prgm):
78+
prgm = format_queries(prgm, True) # noqa: F821
79+
return LogicLowerer()(prgm)

sparse/scheduler/executor.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from .compiler import LogicCompiler
2+
3+
4+
class LogicExecutor:
5+
def __init__(self, ctx, verbose=False):
6+
self.ctx: LogicCompiler = ctx
7+
self.codes = {}
8+
self.verbose = verbose
9+
10+
def __call__(self, prgm):
11+
prgm_structure = prgm
12+
if prgm_structure not in self.codes:
13+
thunk = logic_executor_code(self.ctx, prgm)
14+
self.codes[prgm_structure] = eval(thunk), thunk
15+
16+
f, code = self.codes[prgm_structure]
17+
if self.verbose:
18+
print(code)
19+
return f(prgm)
20+
21+
22+
def get_structure():
23+
pass
24+
25+
26+
def logic_executor_code(ctx, prgm):
27+
pass

0 commit comments

Comments
 (0)