Skip to content

[WIP] Stan backend #347

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 51 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
0ee4580
Initial implementation of stan translator
Dashadower Jul 13, 2022
73db7e8
Update testing.py
hyunjimoon Jul 21, 2022
6246c21
change model input and output signature
hyunjimoon Jul 21, 2022
b651530
Add LOOKUP implementation. Refactor AST walkers to classes
Dashadower Jul 24, 2022
64eb140
Automatically identify stock variables and initial values
Dashadower Jul 24, 2022
4f1e37e
WIP static RNG function
Dashadower Jul 24, 2022
e912553
Update test models and initial value generation
Dashadower Jul 31, 2022
e7fa259
Allow non-default outcome variable passage
Dashadower Jul 31, 2022
8c9f767
Add missing semicolon
Dashadower Jul 31, 2022
ce5bc65
Add float to auxillary name walker
Dashadower Jul 31, 2022
f37c720
Update explanation on outcome signature
hyunjimoon Jul 31, 2022
2903470
Readme for PR preparation
hyunjimoon Aug 1, 2022
02971a3
Update readme.md
hyunjimoon Aug 1, 2022
46df7c6
Update readme.md
hyunjimoon Aug 1, 2022
3569efd
Merge branch 'master' of https://github.com/Dashadower/pysd into stan…
Aug 1, 2022
12039a9
Extend to Hierarchical models
Aug 7, 2022
1903312
Substitute match with isinstance.
Dashadower Aug 11, 2022
af542c4
Autoformat with black
Dashadower Aug 11, 2022
c744e17
Minor update to logic
Dashadower Aug 11, 2022
b6305e5
Add demand_supply model workflow in notebook file
hyunjimoon Aug 18, 2022
76eea02
Relocate notebook file
hyunjimoon Aug 18, 2022
ae761d3
Update return signature of ode function and change deprecated stan fu…
Dashadower Aug 18, 2022
dc961b6
Change output signature of ode call
Dashadower Aug 18, 2022
8cd38f2
Update notebook to include full flow
hyunjimoon Aug 18, 2022
7f9e56c
Merge branch 'stan-backend' of github.com:hyunjimoon/pysd into stan-b…
hyunjimoon Aug 18, 2022
76f5485
Add tables to workflow
hyunjimoon Aug 18, 2022
c921d21
Update notebook for relation input (issue #16)
hyunjimoon Aug 19, 2022
ad76f2a
WIP stan model interface
Dashadower Aug 22, 2022
b3c5a1f
Two not compile tested stan files
hyunjimoon Aug 27, 2022
1d9545e
data2draws implementation
Dashadower Aug 29, 2022
a009a47
Updates to data2draws
Dashadower Aug 29, 2022
161e55e
Check if variables with priors exist in the ODE function declaration
Aug 31, 2022
9d8e1f2
Update codegen to match stan 2.30
Dashadower Aug 31, 2022
5584248
Update lookup codegen logic and create cmdstan model
Dashadower Aug 31, 2022
09c6086
Change filename
Dashadower Aug 31, 2022
e2d4df8
Six manual comments each for draws2data and data2draws
hyunjimoon Aug 31, 2022
3227876
Add files via upload
hyunjimoon Aug 31, 2022
fd5e4c1
Six comments each for interface upgrade
hyunjimoon Aug 31, 2022
2ff05de
implement draws2data
hyunjimoon Sep 1, 2022
4ba4e81
Merge branch 'stan-backend' into stan-backend-d2d
hyunjimoon Sep 1, 2022
ad4b855
Merge pull request #1 from hyunjimoon/stan-backend-d2d
hyunjimoon Sep 1, 2022
5f1b79a
Add constraints to parameters and better input data integration
Dashadower Sep 1, 2022
d57897c
Change stan model directory
Dashadower Sep 1, 2022
c0f88e1
Update data block
Dashadower Sep 1, 2022
53febea
Codegen cleanup and draws2data implementation
Dashadower Sep 6, 2022
bd75d72
Rearrange vensim file structure
hyunjimoon Sep 11, 2022
83440b0
Update draws2data
Dashadower Sep 11, 2022
156b6ac
Merge branch 'stan-backend' of https://github.com/hyunjimoon/pysd int…
Dashadower Sep 11, 2022
f638ac5
Three checks with draws2data, data2draws function
hyunjimoon Sep 16, 2022
53a2576
Bug fix and analysis notebook for preypredator
hyunjimoon Sep 18, 2022
c022c72
Merge remote-tracking branch 'upstream/master' into stan-backend
Dashadower Sep 20, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ tests/htmlcov/
.idea/*
docs/_build/*
docs/tables/*.csv
venv/
tests/.coverage*
300 changes: 300 additions & 0 deletions pysd/builders/stan/ast_walker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
from typing import Union, List, Iterable, Dict, Tuple, Any
from itertools import chain
from dataclasses import dataclass, field
from .utilities import IndentedString
from pysd.translators.structures.abstract_model import (
AbstractComponent,
AbstractElement,
AbstractModel,
AbstractSection,
)

from pysd.translators.structures.abstract_expressions import *


class BaseNodeWaler:
def walk(self, ast_node):
raise NotImplementedError


class AuxNameWalker(BaseNodeWaler):
def walk(self, ast_node) -> List[str]:
if isinstance(ast_node, int):
return []
elif isinstance(ast_node, float):
return []
elif isinstance(ast_node, ArithmeticStructure):
return list(
chain.from_iterable(
[self.walk(argument) for argument in ast_node.arguments]
)
)
elif isinstance(ast_node, ReferenceStructure):
return [ast_node.reference]
elif isinstance(ast_node, CallStructure):
return list(
chain.from_iterable(
[self.walk(argument) for argument in ast_node.arguments]
)
)
elif isinstance(ast_node, IntegStructure):
return self.walk(ast_node.flow) + self.walk(ast_node.initial)
elif isinstance(ast_node, InlineLookupsStructure):
return self.walk(ast_node.lookups)


@dataclass
class LookupCodegenWalker(BaseNodeWaler):
generated_lookup_function_names: Dict[str, str] = field(
default_factory=dict
)
# This dict holds the generated function names of each individual lookup function.
# Key is x + y + x_limits + y_limits, value is function name
n_lookups = 0
code = IndentedString()

@staticmethod
def get_lookup_keyname(lookup_node: LookupsStructure):
return (
lookup_node.x
+ lookup_node.y
+ lookup_node.x_limits
+ lookup_node.y_limits
)

def walk(self, ast_node, node_name: str) -> None:
if isinstance(ast_node, InlineLookupsStructure):
self.walk(ast_node.lookups, node_name)
elif isinstance(ast_node, LookupsStructure):
assert (
ast_node.type == "interpolate"
), "Type of Lookup must be 'interpolate'"
function_name = f"lookupFunc__{node_name}"
self.generated_lookup_function_names[
node_name
] = function_name
self.n_lookups += 1
self.code += f"real {function_name}(real x){{\n"
self.code.indent_level += 1
# Enter function body
self.code += f"// x {ast_node.x_limits} = {ast_node.x}\n"
self.code += f"// y {ast_node.y_limits} = {ast_node.y}\n"
self.code += "real slope;\n"
self.code += "real intercept;\n\n"
n_intervals = len(ast_node.x)
for lookup_index in range(n_intervals):
if lookup_index == 0:
continue
if lookup_index == 1:
self.code += f"if(x <= {ast_node.x[lookup_index]}){{\n"
else:
self.code += f"else if(x <= {ast_node.x[lookup_index]}){{\n"

self.code.indent_level += 1
# enter conditional body
self.code += f"intercept = {ast_node.y[lookup_index - 1]};\n"
self.code += f"slope = ({ast_node.y[lookup_index]} - {ast_node.y[lookup_index - 1]}) / ({ast_node.x[lookup_index]} - {ast_node.x[lookup_index - 1]});\n"
self.code += f"return intercept + slope * (x - {ast_node.x[lookup_index - 1]});\n"
self.code.indent_level -= 1
# exit conditional body
self.code += "}\n"

# Handle out-of-bounds input
self.code += f"return {ast_node.y[-1]};\n"

self.code.indent_level -= 1
# exit function body
self.code += "}\n\n"
else:
return None


@dataclass
class BlockCodegenWalker(BaseNodeWaler):
lookup_function_names: Dict[str, str]

def walk(self, ast_node) -> str:

if isinstance(ast_node, int):
return f"{ast_node}"
elif isinstance(ast_node, float):
return f"{ast_node}"
elif isinstance(ast_node, str):
return ast_node
elif isinstance(ast_node, ArithmeticStructure):
# ArithmeticStructure consists of chained arithmetic expressions.
# We parse them one by one into a single expression
output_string = ""
last_argument_index = len(ast_node.arguments) - 1
for index, argument in enumerate(ast_node.arguments):
output_string += self.walk(argument)
if index < last_argument_index:
output_string += " "
output_string += ast_node.operators[index]
output_string += " "
return output_string

elif isinstance(ast_node, ReferenceStructure):
# ReferenceStructure denotes invoking the value of another variable
# Subscripts are ignored for now
if ast_node.reference in self.lookup_function_names:
return self.lookup_function_names[ast_node.reference]
return ast_node.reference

elif isinstance(ast_node, CallStructure):
output_string = ""
function_name = self.walk(ast_node.function)
if function_name == "min":
function_name = "fmin"
elif function_name == "max":
function_name = "fmax"
elif function_name == "xidz":
assert (
len(ast_node.arguments) == 3
), "number of arguments for xidz must be 3"
arg1 = self.walk(ast_node.arguments[0])
arg2 = self.walk(ast_node.arguments[1])
arg3 = self.walk(ast_node.arguments[2])
output_string += (
f" (abs({arg2}) <= 1e-6) ? {arg3} : ({arg1}) / ({arg2})"
)
return output_string
elif function_name == "zidz":
assert (
len(ast_node.arguments) == 2
), "number of arguments for zidz must be 2"
arg1 = self.walk(ast_node.arguments[0])
arg2 = self.walk(ast_node.arguments[1])
output_string += (
f" (abs({arg2}) <= 1e-6) ? 0 : ({arg1}) / ({arg2})"
)
return output_string
elif function_name == "ln":
# natural log in stan is just log
function_name = "log"

output_string += function_name
output_string += "("
output_string += ", ".join(
[self.walk(argument) for argument in ast_node.arguments]
)
output_string += ")"

return output_string

elif isinstance(ast_node, IntegStructure):
return self.walk(ast_node.flow)

elif isinstance(ast_node, InlineLookupsStructure):
lookup_func_name = self.lookup_function_names[
LookupCodegenWalker.get_lookup_keyname(ast_node.lookups)
]
return f"{lookup_func_name}({self.walk(ast_node.argument)})"

else:
raise Exception("Got unknown node", ast_node)


@dataclass
class InitialValueCodegenWalker(BlockCodegenWalker):
variable_ast_dict: Dict[str, AbstractSyntax]
lookup_function_names: Dict[Union[str, Tuple], str]

def walk(self, ast_node):
if isinstance(ast_node, IntegStructure):
return self.walk(ast_node.initial)

elif isinstance(ast_node, SmoothStructure):
return self.walk(ast_node.initial)

elif isinstance(ast_node, ReferenceStructure):
if ast_node.reference in self.variable_ast_dict:
return self.walk(self.variable_ast_dict[ast_node.reference])
else:
return super().walk(ast_node)

elif isinstance(ast_node, ArithmeticStructure):
# ArithmeticStructure consists of chained arithmetic expressions.
# We parse them one by one into a single expression
output_string = ""
last_argument_index = len(ast_node.arguments) - 1
for index, argument in enumerate(ast_node.arguments):
output_string += self.walk(argument)
if index < last_argument_index:
output_string += " "
output_string += ast_node.operators[index]
output_string += " "
return output_string
else:
return super().walk(ast_node)


@dataclass
class RNGCodegenWalker(InitialValueCodegenWalker):
variable_ast_dict: Dict[str, AbstractSyntax]
lookup_function_names: Dict[Tuple, str]
total_timestep: int

def walk(self, ast_node) -> str:
if isinstance(ast_node, CallStructure):
function_name = self.walk(ast_node.function)
if function_name in (
"random_beta",
"random_binomial",
"random_binomial",
"random_exponential",
"random_gamma",
"random_normal",
"random_poisson",
):
argument_codegen = [
self.walk(argument) for argument in ast_node.arguments
]
return self.rng_codegen(function_name, argument_codegen)
else:
return super().walk(ast_node)

elif isinstance(ast_node, IntegStructure):
raise Exception(
"RNG function arguments cannot contain stock variables which change with time and thus must be constant!"
)

elif isinstance(ast_node, SmoothStructure):
raise Exception(
"RNG function arguments cannot contain stock variables which change with time and thus must be constant!"
)

elif isinstance(ast_node, ReferenceStructure):
if ast_node.reference in self.variable_ast_dict:
return self.walk(ast_node.reference)
else:
return super().walk(ast_node)

elif isinstance(ast_node, ArithmeticStructure):
# ArithmeticStructure consists of chained arithmetic expressions.
# We parse them one by one into a single expression
output_string = ""
last_argument_index = len(ast_node.arguments) - 1
for index, argument in enumerate(ast_node.arguments):
output_string += self.walk(argument)
if index < last_argument_index:
output_string += " "
output_string += ast_node.operators[index]
output_string += " "
return output_string

else:
return super().walk(ast_node)

def rng_codegen(self, rng_type: str, arguments: List[Any]):
if rng_type == "random_normal":
lower, upper, mean, std, _ = arguments
return f"fmin(fmax(normal_rng({mean}, {std}), {lower}), {upper})"
elif rng_type == "random_uniform":
lower, upper, _ = arguments
return f"uniform_rng({lower}, {upper})"
elif rng_type == "random_poisson":
lower, upper, _lambda, offset, multiply, _ = arguments
return f"fmin(fmax(fma(poisson_rng({_lambda}), {multiply}, {offset}), {lower}), {upper})"
else:
raise Exception(f"RNG function {rng_type} not implemented")
15 changes: 15 additions & 0 deletions pysd/builders/stan/iteration_counter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
static int iteration_counter = 0;

// https://discourse.mc-stan.org/t/generating-random-numbers-in-the-model/3608
// https://discourse.mc-stan.org/t/is-it-possible-to-access-the-iteration-step-number-inside-a-stan-program/1871/6
// https://mc-stan.org/docs/cmdstan-guide/using-external-cpp-code.html
// https://discourse.mc-stan.org/t/hoping-for-some-guidance-help-with-implementing-custom-log-likelihood-and-gradient-for-research-project-details-below/24598/14
namespace vensim_ode_model_namespace {
inline int get_current_iteration(std::ostream* pstream__) {
return iteration_counter;
}

inline void increment_iteration(std::ostream* pstream__) {
++iteration_counter;
}
}
Loading