Skip to content

Commit

Permalink
自变量自动注册
Browse files Browse the repository at this point in the history
  • Loading branch information
wukan1986 committed Feb 2, 2024
1 parent a644e55 commit 5d41d88
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 58 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ https://exprcodegen.streamlit.app
│ alpha101.txt # WorldQuant Alpha101示例,可复制到`streamlit`应用
│ demo_cn.py # 中文注释示例。演示如何将表达式转换成代码
│ demo_exec_pl.py # 演示调用转换后代码并绘图
│ demo_transformer.py # 演示将第三方表达式转成内部表达式
│ output.py # 结果输出。可不修改代码,直接被其它项目导入
│ show_tree.py # 画表达式树形图。可用于分析对比优化结果
│ sympy_define.py # 符号定义,由于太多地方重复使用到,所以统一提取到此处
Expand Down
8 changes: 3 additions & 5 deletions examples/demo_cn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
# ====================
import inspect

from expr_codegen.codes import sources_to_asts
from expr_codegen.expr import dict_to_exprs
from expr_codegen.codes import sources_to_exprs
from expr_codegen.tool import ExprTool

# 导入OPEN等特征
Expand Down Expand Up @@ -58,12 +57,11 @@ def _code_block_():

# 读取源代码,转成字符串
source = inspect.getsource(_code_block_)
raw, assigns = sources_to_asts(source)
assigns_dict = dict_to_exprs(assigns, globals().copy())
raw, exprs_dict = sources_to_exprs(globals().copy(), source)

# 生成代码
tool = ExprTool()
codes, G = tool.all(assigns_dict, style='polars', template_file='template.py.j2',
codes, G = tool.all(exprs_dict, style='polars', template_file='template.py.j2',
replace=True, regroup=True, format=True,
date='date', asset='asset',
# 复制了需要使用的函数,还复制了最原始的表达式
Expand Down
2 changes: 1 addition & 1 deletion expr_codegen/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.8"
__version__ = "0.4.9"
35 changes: 31 additions & 4 deletions expr_codegen/codes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import ast
import re

from expr_codegen.expr import register_symbols, dict_to_exprs


class SympyTransformer(ast.NodeTransformer):
"""将ast转换成Sympy要求的格式"""
Expand Down Expand Up @@ -100,6 +102,16 @@ def visit_BinOp(self, node):
)
# 这种情况要处理吗?
# (OPEN < CLOSE)*(OPEN < CLOSE)

if isinstance(node.left, ast.Name):
self.args_old.add(node.left.id)
node.left.id = self.args_map.get(node.left.id, node.left.id)
self.args_new.add(node.left.id)
if isinstance(node.right, ast.Name):
self.args_old.add(node.right.id)
node.right.id = self.args_map.get(node.right.id, node.right.id)
self.args_new.add(node.right.id)

self.generic_visit(node)
return node

Expand All @@ -108,11 +120,15 @@ def sources_to_asts(*sources):
"""输入多份源代码"""
raw = []
assigns = {}
funcs_new, args_new, targets_new = set(), set(), set()
for arg in sources:
r, a = _source_to_asts(arg)
r, a, funcs_, args_, targets_ = _source_to_asts(arg)
raw.append(r)
assigns.update(a)
return '\n'.join(raw), assigns
funcs_new.update(funcs_)
args_new.update(args_)
targets_new.update(targets_)
return '\n'.join(raw), assigns, funcs_new, args_new, targets_new


def source_replace(source):
Expand All @@ -126,7 +142,8 @@ def source_replace(source):
def _source_to_asts(source):
"""源代码"""
tree = ast.parse(source_replace(source))
SympyTransformer().visit(tree)
t = SympyTransformer()
t.visit(tree)

raw = []
assigns = []
Expand All @@ -145,7 +162,7 @@ def _source_to_asts(source):
if isinstance(node, (ast.Import, ast.ImportFrom)):
raw.append(node)
continue
return raw_to_code(raw), assigns_to_dict(assigns)
return raw_to_code(raw), assigns_to_dict(assigns), t.funcs_new, t.args_new, t.targets_new


def assigns_to_dict(assigns):
Expand All @@ -156,3 +173,13 @@ def assigns_to_dict(assigns):
def raw_to_code(raw):
"""导入语句转字符列表"""
return '\n'.join([ast.unparse(a) for a in raw])


def sources_to_exprs(globals_, *sources):
"""将源代码转换成表达式"""
raw, assigns, funcs_new, args_new, targets_new = sources_to_asts(*sources)
register_symbols(funcs_new, globals_, is_function=True)
register_symbols(args_new, globals_, is_function=False)
register_symbols(targets_new, globals_, is_function=False)
exprs_dict = dict_to_exprs(assigns, globals_)
return raw, exprs_dict
44 changes: 25 additions & 19 deletions expr_codegen/expr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import inspect
import re
from functools import reduce

Expand All @@ -15,30 +14,37 @@
CL_SET = {CL_TUP} # 整列集合


def keys_to_Symbol(keys):
"""中间变量自注册。有了它可以将中间变量注册,方便实现多步计算"""
syms = symbols(','.join(keys.keys()), cls=Symbol, seq=True)
syms = {s.name: s for s in syms}
return syms
def is_symbol(x, globals_):
s = globals_.get(x, None)
if s is None:
return False
if isinstance(s, Symbol):
# OPEN
return True
if type(s) is type:
# Eq
return issubclass(s, Basic)
return False


def function_to_Function(globals_):
"""函数自注册
!!! 非常重要,有几百个函数,如果每个都要写symbols()注册过于麻烦,这里按要求将导入的函数自动注册
"""
funcs = {k: v for k, v in globals_.items() if inspect.isfunction(v)}
funcs = {k: v for k, v in funcs.items() if v.__module__ not in ('inspect', 'sympy.core.symbol')}
syms = symbols(','.join(funcs.keys()), cls=Function, seq=True)
def register_symbols(syms, globals_, is_function: bool):
"""注册sympy中需要使用的符号"""
# Eq等已经是sympy的符号不需注册
syms = [s for s in syms if not is_symbol(s, globals_)]
if len(syms) == 0:
return globals_

if is_function:
# 函数被注册后不能再调用,所以一定要用globals().copy()
syms = symbols(','.join(syms), cls=Function, seq=True)
else:
syms = symbols(','.join(syms), cls=Symbol, seq=True)
syms = {s.name: s for s in syms}
return syms
globals_.update(syms)
return globals_


def dict_to_exprs(exprs_src, globals_):
# 注册中间符号
globals_.update(keys_to_Symbol(exprs_src))
# !!! 函数自动注册
globals_.update(function_to_Function(globals_))

exprs_src = {k: safe_eval(v, globals_) for k, v in exprs_src.items()}
return exprs_src

Expand Down
2 changes: 1 addition & 1 deletion expr_codegen/polars/template.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def main(df: pl.DataFrame):
{% endfor %}

# drop intermediate columns
df = df.drop(columns=list(filter(lambda x: re.search(r"^_x_\d+", x), df.columns)))
df = df.select(pl.exclude(r'^_x_\d+$'))

# shrink
df = df.select(cs.all().shrink_dtype())
Expand Down
39 changes: 11 additions & 28 deletions streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
from itertools import islice

import streamlit as st
import sympy
from black import format_str, Mode
from loguru import logger
from streamlit_ace import st_ace
from sympy import numbered_symbols, Symbol, FunctionClass

import expr_codegen
from expr_codegen.codes import sources_to_asts
from expr_codegen.expr import replace_exprs, dict_to_exprs
from expr_codegen.codes import sources_to_exprs
from expr_codegen.expr import replace_exprs
from expr_codegen.tool import ExprTool


Expand Down Expand Up @@ -59,12 +58,6 @@ def list_to_string(items, n):
date_name = st.text_input('日期字段名', 'date')
asset_name = st.text_input('资产字段名', 'asset')

factors_text_area = st.text_area(label='新增预定义因子', value="""# Alpha101基础因子
OPEN, HIGH, LOW, CLOSE, VOLUME, AMOUNT,
RETURNS, VWAP, CAP,
ADV5, ADV10, ADV15, ADV20, ADV30, ADV40, ADV50, ADV60, ADV81, ADV120, ADV150, ADV180,
SECTOR, INDUSTRY, SUBINDUSTRY,""")

# 生成代码
style = st.radio('代码风格', ('polars', 'pandas/cudf.pandas'))
if style == 'polars':
Expand Down Expand Up @@ -92,16 +85,10 @@ def list_to_string(items, n):
version: {expr_codegen.__version__}
""")

with st.expander(label="预定义**算子**"):
st.write('如缺算子,可以在issue中申请添加,或下载代码进行二次开发')

# 本可以不用写这么复杂,但为了证明可以动态加载和执行,所以演示一下
module = __import__('examples.sympy_define', fromlist=['*'])

source = inspect.getsource(module)
st.code(source)
# 执行
exec(source, globals())
# 本可以不用写这么复杂,但为了证明可以动态加载和执行,所以演示一下
module = __import__('examples.sympy_define', fromlist=['*'])
source = inspect.getsource(module)
exec(source, globals())

st.subheader('自定义表达式')
all_symbols, all_functions = get_symbols_functions(module)
Expand All @@ -128,24 +115,20 @@ def list_to_string(items, n):

if st.button('生成代码'):
with st.spinner('生成中,请等待...'):
# 自定义注册到全局变量
sympy.var(factors_text_area)

# eval处理,转成字典
raw, assigns = sources_to_asts(exprs_src)
assigns_dict = dict_to_exprs(assigns, globals().copy())
raw, exprs_dict = sources_to_exprs(globals().copy(), exprs_src)

if is_pre_opt:
logger.info('事前 表达式 替换')
assigns_dict = replace_exprs(assigns_dict)
exprs_dict = replace_exprs(exprs_dict)

tool = ExprTool()

logger.info('表达式 抽取 合并')
exprs_dst, syms_dst = tool.merge(**assigns_dict)
exprs_dst, syms_dst = tool.merge(**exprs_dict)

logger.info('提取公共表达式')
tool.cse(exprs_dst, symbols_repl=numbered_symbols('_x_'), symbols_redu=assigns_dict.keys())
tool.cse(exprs_dst, symbols_repl=numbered_symbols('_x_'), symbols_redu=exprs_dict.keys())

logger.info('生成有向无环图')
exprs_ldl, G = tool.dag(merge=True)
Expand All @@ -154,7 +137,7 @@ def list_to_string(items, n):
exprs_ldl.optimize(back_opt=is_back_opt, chain_opt=is_chain_opt)

logger.info('代码生成')
source = codegen(exprs_ldl, assigns_dict, syms_dst,
source = codegen(exprs_ldl, exprs_dict, syms_dst,
filename='template.py.j2',
date=date_name, asset=asset_name,
extra_codes=(raw,))
Expand Down

0 comments on commit 5d41d88

Please sign in to comment.