Skip to content

Commit

Permalink
对None等常量进行更名处理
Browse files Browse the repository at this point in the history
  • Loading branch information
wukan1986 committed Sep 7, 2024
1 parent e706670 commit 1cd8f8f
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 16 deletions.
2 changes: 1 addition & 1 deletion expr_codegen/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.8.0"
__version__ = "0.8.1"
68 changes: 61 additions & 7 deletions expr_codegen/codes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ class SympyTransformer(ast.NodeTransformer):

# 映射
funcs_map = {}
args_map = {}
# 由于None等常量无法在sympy中正确处理,只能改成Symbol变量
# !!!一定要在drop_symbols时排除
args_map = {'True': "_TRUE_", 'False': "_FALSE_", 'None': "_NONE_"}
targets_map = {} # 只对非下划线开头的生效

def config_map(self, funcs_map, args_map, targets_map):
Expand All @@ -35,11 +37,18 @@ def visit_Call(self, node):
node.func.id = self.funcs_map.get(node.func.id, node.func.id)
self.funcs_new.add(node.func.id)
# 提取参数名
for arg in node.args:
for i, arg in enumerate(node.args):
if isinstance(arg, ast.Name):
self.args_old.add(arg.id)
arg.id = self.args_map.get(arg.id, arg.id)
self.args_new.add(arg.id)
if isinstance(arg, ast.Constant):
old_arg_value = str(arg.value)
if old_arg_value in self.args_map:
new_arg_value = self.args_map.get(old_arg_value, old_arg_value)
self.args_old.add(old_arg_value)
node.args[i] = ast.Name(new_arg_value, ctx=ast.Load())
self.args_new.add(new_arg_value)

self.generic_visit(node)
return node
Expand All @@ -60,24 +69,41 @@ def __visit_Assign(self, target: expr):
# 记录修改的变量名,之后会使用到
self.args_map[old_target_id] = new_target_id

if isinstance(target, ast.Constant):
old_target_value = str(target.value)
if old_target_value in self.args_map:
new_target_value = self.args_map.get(old_target_value, old_target_value)
self.args_old.add(old_target_value)
target = ast.Name(new_target_value, ctx=ast.Load())
self.args_new.add(new_target_value)

return target

def visit_Assign(self, node):
# 调整位置,支持循环赋值
# _A = _A+1 调整成 _A_001 = _A_000 + 1
self.generic_visit(node)

# 提取输出变量名
for target in node.targets:
for i, target in enumerate(node.targets):
if isinstance(target, ast.Tuple):
for t in target.elts:
self.__visit_Assign(t)
for j, t in enumerate(target.elts):
target.elts[j] = self.__visit_Assign(t)
else:
self.__visit_Assign(target)
node.targets[i] = self.__visit_Assign(target)

# 处理 alpha=close 这种情况
if isinstance(node.value, ast.Name):
self.args_old.add(node.value.id)
node.value.id = self.args_map.get(node.value.id, node.value.id)
self.args_new.add(node.value.id)
if isinstance(node.value, ast.Constant):
old_node_value = str(node.value.value)
if old_node_value in self.args_map:
new_node_value = self.args_map.get(old_node_value, old_node_value)
self.args_old.add(old_node_value)
node.value = ast.Name(new_node_value, ctx=ast.Load())
self.args_new.add(new_node_value)

return node

Expand All @@ -87,11 +113,18 @@ def visit_Compare(self, node):
self.args_old.add(node.left.id)
node.left.id = self.args_map.get(node.left.id, node.left.id)
self.args_new.add(node.left.id)
for com in node.comparators:
for i, com in enumerate(node.comparators):
if isinstance(com, ast.Name):
self.args_old.add(com.id)
com.id = self.args_map.get(com.id, com.id)
self.args_new.add(com.id)
if isinstance(com, ast.Constant):
old_com_value = str(com.value)
if old_com_value in self.args_map:
new_com_value = self.args_map.get(old_com_value, old_com_value)
self.args_old.add(old_com_value)
node.comparators[i] = ast.Name(new_com_value, ctx=ast.Load())
self.args_new.add(new_com_value)

# OPEN==CLOSE,要转成Eq
if isinstance(node.ops[0], ast.Eq):
Expand Down Expand Up @@ -146,6 +179,20 @@ def visit_BinOp(self, node):
self.args_old.add(node.right.id)
node.right.id = self.args_map.get(node.right.id, node.right.id)
self.args_new.add(node.right.id)
if isinstance(node.left, ast.Constant):
old_node_value = str(node.left.value)
if old_node_value in self.args_map:
new_node_value = self.args_map.get(old_node_value, old_node_value)
self.args_old.add(old_node_value)
node.left = ast.Name(new_node_value, ctx=ast.Load())
self.args_new.add(new_node_value)
if isinstance(node.right, ast.Constant):
old_node_value = str(node.right.value)
if old_node_value in self.args_map:
new_node_value = self.args_map.get(old_node_value, old_node_value)
self.args_old.add(old_node_value)
node.right = ast.Name(new_node_value, ctx=ast.Load())
self.args_new.add(new_node_value)

self.generic_visit(node)
return node
Expand All @@ -156,6 +203,13 @@ def visit_UnaryOp(self, node):
self.args_old.add(node.operand.id)
node.operand.id = self.args_map.get(node.operand.id, node.operand.id)
self.args_new.add(node.operand.id)
if isinstance(node.operand, ast.Constant):
old_operand_value = str(node.operand.value)
if old_operand_value in self.args_map:
new_operand_value = self.args_map.get(old_operand_value, old_operand_value)
self.args_old.add(old_operand_value)
node.operand = ast.Name(new_operand_value, ctx=ast.Load())
self.args_new.add(new_operand_value)

self.generic_visit(node)
return node
Expand Down
3 changes: 2 additions & 1 deletion expr_codegen/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def drop_symbols(self):
l2 = [set()]
s = set()
for i in reversed(l1):
s = s | i
# 这三变量需要排除
s = s | i - {'_NONE_', '_TRUE_', '_FALSE_'}
l2.append(s)
l2 = list(reversed(l2))

Expand Down
3 changes: 3 additions & 0 deletions expr_codegen/pandas/template.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ from loguru import logger # noqa

_DATE_ = '{{ date }}'
_ASSET_ = '{{ asset }}'
_NONE_ = None
_TRUE_ = True
_FALSE_ = False

{%-for row in extra_codes %}
{{ row-}}
Expand Down
3 changes: 3 additions & 0 deletions expr_codegen/polars/template.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ from polars_ta.prefix.cdl import * # noqa

_DATE_ = '{{ date }}'
_ASSET_ = '{{ asset }}'
_NONE_ = None
_TRUE_ = True
_FALSE_ = False

{%-for row in extra_codes %}
{{ row-}}
Expand Down
18 changes: 11 additions & 7 deletions tests/formula_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@

# 再也不怕出现有环了
source = """
_A = 1+2
_A = _A+_A
_B = _A+2
_C = _A+_B
_A = _C+4
D = _A
_A = True+None
# F = 1>None
# _A = Add(1,True)
#
# _A = _A+_A
# _B = _A+2
# _C = _A+_B
#
# D = _A
"""

tree = ast.parse(source_replace(source))
Expand All @@ -42,7 +46,7 @@
'delta': 'ts_delta',
'delay': 'ts_delay',
}
args_map = {}
args_map = {'True': "TRUE", 'False': "FALSE", 'None': "NONE"}
targets_map = {'_A': '_12'}

t.config_map(funcs_map, args_map, targets_map)
Expand Down

0 comments on commit 1cd8f8f

Please sign in to comment.