Skip to content

Commit

Permalink
put simple node check back
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Oct 13, 2022
1 parent b679644 commit a33401e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 11 deletions.
23 changes: 20 additions & 3 deletions auto_walrus.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
SEP_SYMBOLS = frozenset(('(', ')', ',', ':'))
# name, lineno, col_offset, end_lineno, end_col_offset
Token = Tuple[str, int, int, int, int]

SIMPLE_NODE = (ast.Name, ast.Constant)
ENDS_WITH_COMMENT = re.compile(r'#.*$')


Expand All @@ -25,6 +25,22 @@ def name_lineno_coloffset(tokens: Token) -> tuple[str, int, int]:
return (tokens[0], tokens[1], tokens[2])


def is_simple_test(node: ast.AST) -> bool:
return (
isinstance(node, SIMPLE_NODE)
or (
isinstance(node, ast.Compare)
and isinstance(node.left, SIMPLE_NODE)
and (
all(
isinstance(_node, SIMPLE_NODE)
for _node in node.comparators
)
)
)
)


def record_name_lineno_coloffset(
node: ast.Name,
end_lineno: int | None = None,
Expand Down Expand Up @@ -186,9 +202,10 @@ def visit_function_def(
if isinstance(_node, ast.Assign):
process_assign(_node, assignments, related_vars)
elif isinstance(_node, ast.If):
ifs.update(process_if(_node, in_body_vars))
if is_simple_test(_node.test):
ifs.update(process_if(_node, in_body_vars))
for __node in _node.orelse:
if isinstance(__node, ast.If):
if isinstance(__node, ast.If) and is_simple_test(__node.test):
ifs.update(process_if(__node, in_body_vars))

sorted_names = sorted(names, key=lambda x: (x[1], x[2]))
Expand Down
22 changes: 14 additions & 8 deletions tests/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,14 +179,20 @@ def test_rewrite(src: str, expected: str) -> None:
' if a:\n'
' print(a)\n'
' a = 2\n',
'n = 10\n'
'if True:\n'
' pass\n'
'elif foo(a := n+1):\n'
' print(n)\n',
'n = 10\n'
'if n > np.sin(foo.bar.quox):\n'
' print(n)\n',
'def foo():\n'
' n = 10\n'
' if True:\n'
' pass\n'
' elif foo(a := n+1):\n'
' print(n)\n',
'def foo():\n'
' n = 10\n'
' if n > np.sin(foo.bar.quox):\n'
' print(n)\n',
'def foo():\n'
' n = 10\n'
' if True or n > 3:\n'
' print(n)\n',
],
)
def test_noop(src: str) -> None:
Expand Down

0 comments on commit a33401e

Please sign in to comment.