Skip to content

Commit 9f2e105

Browse files
committed
Init.
0 parents  commit 9f2e105

File tree

4 files changed

+393
-0
lines changed

4 files changed

+393
-0
lines changed

README.md

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# splax
2+
3+
Writing a simple probabilistic language using interpreter transformations in JAX.

pyproject.toml

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
[tool.poetry]
2+
name = "splax"
3+
version = "0.1.0"
4+
description = ""
5+
authors = ["McCoy R. Becker <[email protected]>"]
6+
readme = "README.md"
7+
8+
[tool.poetry.dependencies]
9+
python = "<3.13,>=3.9"
10+
jax = "^0.4.19"
11+
rich = "^13.6.0"
12+
beartype = "^0.16.4"
13+
jaxtyping = "^0.2.23"
14+
plum-dispatch = "^2.2.2"
15+
16+
[tool.poetry.group.dev.dependencies]
17+
ruff = "^0.1.3"
18+
pytest = "^7.2.0"
19+
coverage = "^7.0.0"
20+
pytest-benchmark = "^4.0.0"
21+
pytest-xdist = {version = "^3.2.0", extras = ["psutil"] }
22+
xdoctest = "^1.1.0"
23+
safety = "^2.3.5"
24+
jupyterlab = "^3.5.1"
25+
matplotlib = "^3.6.2"
26+
seaborn = "^0.12.1"
27+
28+
[tool.ruff]
29+
# Exclude a variety of commonly ignored directories.
30+
exclude = [
31+
".bzr",
32+
".direnv",
33+
".eggs",
34+
".git",
35+
".git-rewrite",
36+
".hg",
37+
".mypy_cache",
38+
".nox",
39+
".pants.d",
40+
".pytype",
41+
".ruff_cache",
42+
".svn",
43+
".tox",
44+
".venv",
45+
"__pypackages__",
46+
"_build",
47+
"buck-out",
48+
"build",
49+
"dist",
50+
"node_modules",
51+
"venv",
52+
]
53+
54+
# Same as Black.
55+
line-length = 88
56+
indent-width = 4
57+
58+
[tool.ruff.lint]
59+
# Enable the isort rules.
60+
extend-select = ["I"]
61+
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
62+
select = ["E4", "E7", "E9", "F"]
63+
ignore = []
64+
65+
# Allow fix for all enabled rules (when `--fix`) is provided.
66+
fixable = ["ALL"]
67+
unfixable = []
68+
69+
# Allow unused variables when underscore-prefixed.
70+
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
71+
72+
[tool.ruff.format]
73+
# Like Black, use double quotes for strings.
74+
quote-style = "double"
75+
76+
# Like Black, indent with spaces, rather than tabs.
77+
indent-style = "space"
78+
79+
# Like Black, respect magic trailing commas.
80+
skip-magic-trailing-comma = false
81+
82+
# Like Black, automatically detect the appropriate line ending.
83+
line-ending = "auto"
84+
85+
[build-system]
86+
requires = ["poetry-core"]
87+
build-backend = "poetry.core.masonry.api"
88+

splax/__init__.py

+302
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
# Copyright 2022 MIT Probabilistic Computing Project
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import abc
16+
import jax.tree_util as jtu
17+
from jax.util import safe_zip
18+
import typing
19+
20+
import beartype.typing as btyping
21+
import jax
22+
import jax.numpy as jnp
23+
import jaxtyping as jtyping
24+
import numpy as np
25+
from beartype import BeartypeConf, beartype
26+
from plum import dispatch
27+
import copy
28+
import dataclasses
29+
import functools
30+
from contextlib import contextmanager
31+
32+
from jax import api_util, lax
33+
from jax import core as jc
34+
from jax import tree_util as jtu
35+
from jax.extend import linear_util as lu
36+
from jax.interpreters import partial_eval as pe
37+
from jax.util import safe_map
38+
39+
40+
##########
41+
# Typing #
42+
##########
43+
44+
45+
Dataclass = typing.Any
46+
PrettyPrintable = typing.Any
47+
PRNGKey = jtyping.UInt[jtyping.Array, "..."]
48+
FloatArray = typing.Union[float, jtyping.Float[jtyping.Array, "..."]]
49+
BoolArray = typing.Union[bool, jtyping.Bool[jtyping.Array, "..."]]
50+
IntArray = typing.Union[int, jtyping.Int[jtyping.Array, "..."]]
51+
Any = typing.Any
52+
Union = typing.Union
53+
Callable = btyping.Callable
54+
Sequence = typing.Sequence
55+
Tuple = btyping.Tuple
56+
NamedTuple = btyping.NamedTuple
57+
Dict = btyping.Dict
58+
List = btyping.List
59+
Iterable = btyping.Iterable
60+
Generator = btyping.Generator
61+
Hashable = btyping.Hashable
62+
FrozenSet = btyping.FrozenSet
63+
Optional = btyping.Optional
64+
Type = btyping.Type
65+
Int = int
66+
Float = float
67+
Bool = bool
68+
String = str
69+
Value = Any
70+
Generic = btyping.Generic
71+
TypeVar = btyping.TypeVar
72+
73+
conf = BeartypeConf()
74+
typecheck = beartype(conf=conf)
75+
76+
77+
#################
78+
# Hashable dict #
79+
#################
80+
81+
82+
class HashableDict(dict):
83+
"""
84+
A hashable dictionary class - allowing the
85+
usage of `dict`-like instances as JAX JIT cache keys
86+
(and allowing their usage with JAX `static_argnums` in `jax.jit`).
87+
"""
88+
89+
def __key(self):
90+
return tuple((k, self[k]) for k in sorted(self, key=hash))
91+
92+
def __hash__(self):
93+
return hash(self.__key())
94+
95+
def __eq__(self, other):
96+
return self.__key() == other.__key()
97+
98+
99+
# This ensures that static keys are always sorted
100+
# in a pre-determined order - which is important
101+
# for `Pytree` structure comparison.
102+
def _flatten(x: HashableDict):
103+
s = {k: v for (k, v) in sorted(x.items(), key=lambda v: hash(v[0]))}
104+
return (list(s.values()), list(s.keys()))
105+
106+
107+
jtu.register_pytree_node(
108+
HashableDict,
109+
_flatten,
110+
lambda keys, values: HashableDict(safe_zip(keys, values)),
111+
)
112+
113+
114+
def hashable_dict():
115+
return HashableDict({})
116+
117+
118+
###########
119+
# Pytrees #
120+
###########
121+
122+
123+
class Pytree:
124+
"""> Abstract base class which registers a class with JAX's `Pytree`
125+
system.
126+
127+
Users who mixin this ABC for class definitions are required to
128+
implement `flatten` below. In turn, instances of the class gain
129+
access to a large set of utility functions for working with `Pytree`
130+
data, as well as the ability to use `jax.tree_util` Pytree
131+
functionality.
132+
"""
133+
134+
def __init_subclass__(cls, **kwargs):
135+
jtu.register_pytree_node(
136+
cls,
137+
cls.flatten,
138+
cls.unflatten,
139+
)
140+
141+
@abc.abstractmethod
142+
def flatten(self) -> Tuple[Tuple, Tuple]:
143+
pass
144+
145+
@classmethod
146+
def unflatten(cls, data, xs):
147+
return cls(*data, *xs)
148+
149+
@classmethod
150+
def new(cls, *args, **kwargs):
151+
return cls(*args, **kwargs)
152+
153+
154+
############################################################
155+
# Staging a function to a closed Jaxpr, for interpretation #
156+
############################################################
157+
158+
159+
def get_shaped_aval(x):
160+
return jc.raise_to_shaped(jc.get_aval(x))
161+
162+
163+
# This will create a cache entry for a function with abstract
164+
# arguments (each has a dtype and shape).
165+
# It is useful to cache the result, because you may stage the same
166+
# function multiple times - if you write and compose multiple transformations.
167+
@lu.cache
168+
def cached_stage_dynamic(flat_fun, in_avals):
169+
# https://github.com/google/jax/blob/1c66ac532b4ef2ad64f9e0859ede329f4fbd0041/jax/_src/interpreters/partial_eval.py#L2268-L2305
170+
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
171+
172+
typed_jaxpr = jc.ClosedJaxpr(jaxpr, consts)
173+
return typed_jaxpr
174+
175+
176+
def stage(f):
177+
"""Returns a function that stages a function to a ClosedJaxpr."""
178+
179+
def wrapped(*args, **kwargs):
180+
fun = lu.wrap_init(f, kwargs)
181+
flat_args, in_tree = jtu.tree_flatten(args)
182+
flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
183+
flat_avals = safe_map(get_shaped_aval, flat_args)
184+
typed_jaxpr = cached_stage_dynamic(flat_fun, tuple(flat_avals))
185+
return typed_jaxpr, (flat_args, in_tree, out_tree)
186+
187+
return wrapped
188+
189+
190+
#######################
191+
# Forward interpreter #
192+
#######################
193+
194+
VarOrLiteral = Union[jc.Var, jc.Literal]
195+
196+
197+
@dataclasses.dataclass
198+
class Environment(Pytree):
199+
"""Keeps track of variables and their values during interpretation."""
200+
201+
env: HashableDict[jc.Var, Value]
202+
203+
def flatten(self):
204+
return (self.env,), ()
205+
206+
@classmethod
207+
def new(cls):
208+
return Environment(hashable_dict())
209+
210+
def read(self, var: VarOrLiteral) -> Value:
211+
if isinstance(var, jc.Literal):
212+
return var.val
213+
else:
214+
return self.env.get(var.count)
215+
216+
def write(self, var: VarOrLiteral, cell: Value) -> Value:
217+
if isinstance(var, jc.Literal):
218+
return cell
219+
cur_cell = self.read(var)
220+
if isinstance(var, jc.DropVar):
221+
return cur_cell
222+
self.env[var.count] = cell
223+
return self.env[var.count]
224+
225+
def __getitem__(self, var: VarOrLiteral) -> Value:
226+
return self.read(var)
227+
228+
def __setitem__(self, key, val):
229+
raise ValueError(
230+
"Environments do not support __setitem__. Please use the "
231+
"`write` method instead."
232+
)
233+
234+
def __contains__(self, var: VarOrLiteral):
235+
if isinstance(var, jc.Literal):
236+
return True
237+
return var in self.env
238+
239+
def copy(self):
240+
return copy.copy(self)
241+
242+
243+
###############################
244+
# Forward masking interpreter #
245+
###############################
246+
247+
248+
@dataclasses.dataclass
249+
class ForwardInterpreter(Pytree):
250+
def flatten(self):
251+
return (), ()
252+
253+
# This produces an instance of `Interpreter`
254+
# as a context manager - to allow us to control error stack traces,
255+
# if required.
256+
@classmethod
257+
@contextmanager
258+
def new(cls):
259+
try:
260+
yield ForwardInterpreter()
261+
except Exception as e:
262+
raise e
263+
264+
def _eval_jaxpr_forward(
265+
self,
266+
jaxpr: jc.Jaxpr,
267+
consts: List[Value],
268+
flat_args: List[Value],
269+
):
270+
env = Environment.new()
271+
safe_map(env.write, jaxpr.constvars, consts)
272+
safe_map(env.write, jaxpr.invars, flat_args)
273+
for eqn in jaxpr.eqns:
274+
invals = safe_map(env.read, eqn.invars)
275+
subfuns, params = eqn.primitive.get_bind_params(eqn.params)
276+
args = subfuns + invals
277+
custom_rule = masking_registry[eqn.primitive]
278+
outvals = eqn.primitive.bind(*args, **params)
279+
if not eqn.primitive.multiple_results:
280+
outvals = [outvals]
281+
safe_map(env.write, eqn.outvars, outvals)
282+
283+
return safe_map(env.read, jaxpr.outvars)
284+
285+
def run_interpreter(self, fn, args, **kwargs):
286+
def _inner(*args):
287+
return fn(*args, **kwargs)
288+
289+
closed_jaxpr, (flat_args, _, out_tree) = stage(_inner)(*args)
290+
flat_mask_flags = jtu.tree_leaves(mask_flags)
291+
jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.literals
292+
flat_out = self._eval_jaxpr_forward(jaxpr, consts, flat_args)
293+
return jtu.tree_unflatten(out_tree(), flat_out)
294+
295+
296+
def forward(f: Callable):
297+
@functools.wraps(f)
298+
def wrapped(args, mask_flags):
299+
with ForwardInterpreter.new() as interpreter:
300+
return interpreter.run_interpreter(f, args, mask_flags)
301+
302+
return wrapped

tests/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)