|
| 1 | +# Copyright 2022 MIT Probabilistic Computing Project |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import abc |
| 16 | +import jax.tree_util as jtu |
| 17 | +from jax.util import safe_zip |
| 18 | +import typing |
| 19 | + |
| 20 | +import beartype.typing as btyping |
| 21 | +import jax |
| 22 | +import jax.numpy as jnp |
| 23 | +import jaxtyping as jtyping |
| 24 | +import numpy as np |
| 25 | +from beartype import BeartypeConf, beartype |
| 26 | +from plum import dispatch |
| 27 | +import copy |
| 28 | +import dataclasses |
| 29 | +import functools |
| 30 | +from contextlib import contextmanager |
| 31 | + |
| 32 | +from jax import api_util, lax |
| 33 | +from jax import core as jc |
| 34 | +from jax import tree_util as jtu |
| 35 | +from jax.extend import linear_util as lu |
| 36 | +from jax.interpreters import partial_eval as pe |
| 37 | +from jax.util import safe_map |
| 38 | + |
| 39 | + |
| 40 | +########## |
| 41 | +# Typing # |
| 42 | +########## |
| 43 | + |
| 44 | + |
# Opaque placeholders: both are plain aliases of `typing.Any`.
Dataclass = typing.Any
PrettyPrintable = typing.Any

# jaxtyping annotations; the "..." shape spec places no constraint on shape.
PRNGKey = jtyping.UInt[jtyping.Array, "..."]

# Each of the following accepts either the Python scalar type or a
# jaxtyping array annotation of the corresponding dtype family.
FloatArray = typing.Union[float, jtyping.Float[jtyping.Array, "..."]]
BoolArray = typing.Union[bool, jtyping.Bool[jtyping.Array, "..."]]
IntArray = typing.Union[int, jtyping.Int[jtyping.Array, "..."]]

# Short module-level names for common typing constructs, taken from
# `typing` or `beartype.typing` as written on each line.
Any = typing.Any
Union = typing.Union
Callable = btyping.Callable
Sequence = typing.Sequence
Tuple = btyping.Tuple
NamedTuple = btyping.NamedTuple
Dict = btyping.Dict
List = btyping.List
Iterable = btyping.Iterable
Generator = btyping.Generator
Hashable = btyping.Hashable
FrozenSet = btyping.FrozenSet
Optional = btyping.Optional
Type = btyping.Type

# Capitalised aliases of the builtin scalar types.
Int = int
Float = float
Bool = bool
String = str

# `Value` is an alias of `Any` (used below for interpreter environment cells).
Value = Any
Generic = btyping.Generic
TypeVar = btyping.TypeVar

# `typecheck` is the result of calling `beartype` with a `BeartypeConf`
# constructed without arguments.
conf = BeartypeConf()
typecheck = beartype(conf=conf)
| 75 | + |
| 76 | + |
| 77 | +################# |
| 78 | +# Hashable dict # |
| 79 | +################# |
| 80 | + |
| 81 | + |
class HashableDict(dict):
    """
    A hashable dictionary class - allowing the
    usage of `dict`-like instances as JAX JIT cache keys
    (and allowing their usage with JAX `static_argnums` in `jax.jit`).

    Hashing and `==` are both derived from the tuple of `(key, value)`
    pairs ordered by `hash(key)`. Hashing therefore only succeeds when
    every key and every value is itself hashable.
    """

    def __key(self):
        # Order the pairs by key hash so that two instances holding the same
        # items produce the same tuple regardless of insertion order.
        return tuple((k, self[k]) for k in sorted(self, key=hash))

    def __hash__(self):
        return hash(self.__key())

    def __eq__(self, other):
        # `__key` is name-mangled, so it only exists on `HashableDict`
        # instances. For any other operand, defer to Python's fallback
        # comparison instead of raising `AttributeError`.
        if not isinstance(other, HashableDict):
            return NotImplemented
        return self.__key() == other.__key()
| 98 | + |
# This ensures that static keys are always sorted
# in a pre-determined order - which is important
# for `Pytree` structure comparison.
def _flatten(x: HashableDict):
    """Flatten a `HashableDict` for JAX pytree registration.

    Items are ordered by `hash(key)`.

    Returns:
        A pair `(values, keys)`: the values as a list (the pytree children)
        and the keys as a tuple (the auxiliary data). The keys are returned
        as a tuple rather than a list because JAX requires auxiliary data to
        be hashable.
    """
    ordered = sorted(x.items(), key=lambda item: hash(item[0]))
    keys = tuple(k for k, _ in ordered)
    values = [v for _, v in ordered]
    return (values, keys)
| 105 | + |
| 106 | + |
# Register `HashableDict` as a JAX pytree node. `_flatten` returns
# `(values, keys)`; the unflatten callback is handed the second element
# (the keys) first and the values second, and zips them back into a
# `HashableDict`.
jtu.register_pytree_node(
    HashableDict,
    _flatten,
    lambda keys, values: HashableDict(safe_zip(keys, values)),
)
| 112 | + |
| 113 | + |
def hashable_dict():
    """Return a new, empty `HashableDict`."""
    empty = HashableDict()
    return empty
| 116 | + |
| 117 | + |
| 118 | +########### |
| 119 | +# Pytrees # |
| 120 | +########### |
| 121 | + |
| 122 | + |
class Pytree:
    """> Base class which registers every subclass with JAX's `Pytree`
    system.

    Users who mixin this class for class definitions are required to
    implement `flatten` below. In turn, instances of the class gain
    the ability to use `jax.tree_util` Pytree functionality.

    Registration happens at class-definition time, in `__init_subclass__`,
    using the subclass's `flatten` and `unflatten`.
    """

    def __init_subclass__(cls, **kwargs):
        # Chain to the next `__init_subclass__` so cooperative multiple
        # inheritance keeps working.
        super().__init_subclass__(**kwargs)
        jtu.register_pytree_node(
            cls,
            cls.flatten,
            cls.unflatten,
        )

    @abc.abstractmethod
    def flatten(self) -> Tuple[Tuple, Tuple]:
        """Return `(children, aux_data)` for this instance.

        Subclasses must override this. The first tuple holds the values JAX
        should traverse; the second holds the data stored in the tree
        structure.

        Raises:
            NotImplementedError: if the subclass did not override `flatten`.
        """
        # This class has no ABC metaclass, so `abstractmethod` alone does not
        # prevent instantiation; fail loudly instead of returning `None`.
        raise NotImplementedError(
            f"{type(self).__name__} must implement `flatten`."
        )

    @classmethod
    def unflatten(cls, data, xs):
        """Rebuild an instance from the two tuples produced by `flatten`.

        JAX calls this with the auxiliary data as `data` and the children as
        `xs`; both are splatted positionally into the constructor, `data`
        first.
        """
        return cls(*data, *xs)

    @classmethod
    def new(cls, *args, **kwargs):
        """Construct an instance by forwarding all arguments to `cls`."""
        return cls(*args, **kwargs)
| 152 | + |
| 153 | + |
| 154 | +############################################################ |
| 155 | +# Staging a function to a closed Jaxpr, for interpretation # |
| 156 | +############################################################ |
| 157 | + |
| 158 | + |
def get_shaped_aval(x):
    """Return the shaped abstract value JAX associates with `x`.

    Calls `jax.core.get_aval(x)` and, when `jax.core.raise_to_shaped` is
    available, passes the result through it.
    """
    aval = jc.get_aval(x)
    # `raise_to_shaped` is not exposed by every JAX release.
    # NOTE(review): on releases without it, `get_aval` is understood to
    # already return a shaped abstract value - confirm against the JAX
    # changelog for the pinned version.
    raise_to_shaped = getattr(jc, "raise_to_shaped", None)
    if raise_to_shaped is None:
        return aval
    return raise_to_shaped(aval)
| 161 | + |
| 162 | + |
# Staging results are cached (via `lu.cache`) for a function together with
# its abstract arguments (each has a dtype and shape).
# Caching pays off because the same function may be staged several times
# when multiple transformations are written and composed.
@lu.cache
def cached_stage_dynamic(flat_fun, in_avals):
    """Trace `flat_fun` at `in_avals` and return the result as a `ClosedJaxpr`."""
    # https://github.com/google/jax/blob/1c66ac532b4ef2ad64f9e0859ede329f4fbd0041/jax/_src/interpreters/partial_eval.py#L2268-L2305
    traced = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
    # The middle element of the trace result is not needed here.
    jaxpr, _, consts = traced
    return jc.ClosedJaxpr(jaxpr, consts)
| 174 | + |
| 175 | + |
def stage(f):
    """Returns a function that stages a function to a ClosedJaxpr.

    The returned callable takes the same arguments as `f` and returns
    `(closed_jaxpr, (flat_args, in_tree, out_tree))`: the staged program,
    the flattened positional arguments, their tree structure, and the
    output-tree handle produced by `flatten_fun_nokwargs`.
    """

    def wrapped(*args, **kwargs):
        # Keyword arguments are bound into the wrapped function; only the
        # positional arguments are flattened and traced.
        leaves, in_tree = jtu.tree_flatten(args)
        flat_fun, out_tree = api_util.flatten_fun_nokwargs(
            lu.wrap_init(f, kwargs), in_tree
        )
        avals = tuple(safe_map(get_shaped_aval, leaves))
        closed = cached_stage_dynamic(flat_fun, avals)
        return closed, (leaves, in_tree, out_tree)

    return wrapped
| 188 | + |
| 189 | + |
| 190 | +####################### |
| 191 | +# Forward interpreter # |
| 192 | +####################### |
| 193 | + |
# An equation operand in a Jaxpr: either a variable or an inline literal.
VarOrLiteral = Union[jc.Var, jc.Literal]
| 195 | + |
| 196 | + |
@dataclasses.dataclass
class Environment(Pytree):
    """Keeps track of variables and their values during interpretation.

    Values live in `env`, keyed by each variable's `count` attribute rather
    than by the variable object itself. Literals are never stored: reading
    one yields its own value and writing to one is a no-op.
    """

    env: HashableDict[jc.Var, Value]

    def flatten(self):
        return (self.env,), ()

    @classmethod
    def new(cls):
        """Create an environment backed by an empty `HashableDict`."""
        return cls(hashable_dict())

    def read(self, var: VarOrLiteral) -> Value:
        """Return the value bound to `var`, or `None` if nothing is bound."""
        if isinstance(var, jc.Literal):
            return var.val
        return self.env.get(var.count)

    def write(self, var: VarOrLiteral, cell: Value) -> Value:
        """Bind `cell` to `var` and return the stored value.

        Literals are not stored (`cell` is returned unchanged). For a
        `DropVar` nothing is written and whatever `read` yields is returned.
        """
        if isinstance(var, jc.Literal):
            return cell
        if isinstance(var, jc.DropVar):
            return self.read(var)
        self.env[var.count] = cell
        return self.env[var.count]

    def __getitem__(self, var: VarOrLiteral) -> Value:
        return self.read(var)

    def __setitem__(self, key, val):
        raise ValueError(
            "Environments do not support __setitem__. Please use the "
            "`write` method instead."
        )

    def __contains__(self, var: VarOrLiteral):
        if isinstance(var, jc.Literal):
            return True
        # Entries are keyed by `var.count` (see `read` / `write`), so the
        # membership test must use the same key.
        return var.count in self.env

    def copy(self):
        """Return a shallow copy made with `copy.copy`."""
        # NOTE(review): the copy is shallow, so the underlying `env` mapping
        # object is shared with the original - confirm this is intended.
        return copy.copy(self)
| 241 | + |
| 242 | + |
| 243 | +############################### |
| 244 | +# Forward masking interpreter # |
| 245 | +############################### |
| 246 | + |
| 247 | + |
@dataclasses.dataclass
class ForwardInterpreter(Pytree):
    """Evaluates a staged Jaxpr equation by equation, binding each primitive
    with its original parameters."""

    def flatten(self):
        return (), ()

    # This produces an instance of `ForwardInterpreter`
    # as a context manager - to allow us to control error stack traces,
    # if required.
    @classmethod
    @contextmanager
    def new(cls):
        try:
            yield cls()
        except Exception:
            # Hook point for adjusting tracebacks; currently re-raises as-is.
            raise

    def _eval_jaxpr_forward(
        self,
        jaxpr: jc.Jaxpr,
        consts: List[Value],
        flat_args: List[Value],
    ):
        """Evaluate `jaxpr` and return the values of its output variables.

        Args:
            jaxpr: the program to evaluate.
            consts: values for `jaxpr.constvars`, in order.
            flat_args: values for `jaxpr.invars`, in order.
        """
        env = Environment.new()
        safe_map(env.write, jaxpr.constvars, consts)
        safe_map(env.write, jaxpr.invars, flat_args)
        for eqn in jaxpr.eqns:
            invals = safe_map(env.read, eqn.invars)
            subfuns, params = eqn.primitive.get_bind_params(eqn.params)
            args = subfuns + invals
            outvals = eqn.primitive.bind(*args, **params)
            # Normalise single-result primitives so outputs are always a list.
            if not eqn.primitive.multiple_results:
                outvals = [outvals]
            safe_map(env.write, eqn.outvars, outvals)

        return safe_map(env.read, jaxpr.outvars)

    def run_interpreter(self, fn, args, **kwargs):
        """Stage `fn(*args, **kwargs)` to a Jaxpr, evaluate it, and return the
        result with the output tree structure of `fn` restored.

        Args:
            fn: the function to interpret.
            args: iterable of positional arguments for `fn`.
            **kwargs: keyword arguments closed over when calling `fn`.
        """

        def _inner(*args):
            return fn(*args, **kwargs)

        closed_jaxpr, (flat_args, _, out_tree) = stage(_inner)(*args)
        jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.literals
        flat_out = self._eval_jaxpr_forward(jaxpr, consts, flat_args)
        return jtu.tree_unflatten(out_tree(), flat_out)


def forward(f: Callable):
    """Wrap `f` so that calling it runs `f` through a `ForwardInterpreter`.

    The returned callable takes `args` (an iterable of positional arguments
    for `f`) and `mask_flags`. `mask_flags` is accepted so existing call
    sites keep working, but the forward interpreter does not consume it.
    """

    @functools.wraps(f)
    def wrapped(args, mask_flags=None):
        with ForwardInterpreter.new() as interpreter:
            return interpreter.run_interpreter(f, args)

    return wrapped
0 commit comments