|
| 1 | +import ast |
1 | 2 | import importlib |
2 | 3 | import os |
3 | 4 | import re |
4 | 5 | from types import ModuleType |
5 | | -from typing import Dict, Iterator |
6 | | - |
| 6 | +from typing import Any, Dict, Iterator, Set |
7 | 7 |
|
| 8 | +LINE_IDENTIFIER = "_" |
8 | 9 | _TYPE_ERROR_MSG = "The provided expression must be an str (editing) or a bool (filtering), but got {}." |
9 | 10 |
|
10 | 11 |
|
| 12 | +def iter_identifiers(expr: str) -> Iterator[str]: |
| 13 | + for node in iter_asts(ast.parse(expr, mode="eval").body): |
| 14 | + if isinstance(node, ast.Name): |
| 15 | + yield node.id |
| 16 | + |
| 17 | + |
| 18 | +def iter_asts(node: ast.AST) -> Iterator[ast.AST]: |
| 19 | + """ |
| 20 | + Depth-first traversal of nodes |
| 21 | + """ |
| 22 | + yield node |
| 23 | + yield from ( |
| 24 | + name for child in ast.iter_child_nodes(node) for name in iter_asts(child) |
| 25 | + ) |
| 26 | + |
| 27 | + |
| 28 | +def auto_import_eval(expression: str, globals: Dict[str, Any]) -> Any: |
| 29 | + globals = globals.copy() |
| 30 | + encountered_name_errors: Set[str] = set() |
| 31 | + while True: |
| 32 | + try: |
| 33 | + return eval(expression, globals) |
| 34 | + except NameError as name_error: |
| 35 | + if str(name_error) in encountered_name_errors: |
| 36 | + raise |
| 37 | + encountered_name_errors.add(str(name_error)) |
| 38 | + match = re.match(r"name '([A-Za-z]+)'.*", str(name_error)) |
| 39 | + if match: |
| 40 | + module = match.group(1) |
| 41 | + globals[module] = importlib.import_module(module) |
| 42 | + continue |
| 43 | + |
| 44 | + |
11 | 45 | def edit(lines: Iterator[str], expression) -> Iterator[str]: |
12 | 46 | modules: Dict[str, ModuleType] = {} |
| 47 | + |
13 | 48 | for line in lines: |
14 | 49 | linesep = "" |
15 | 50 | if line.endswith(os.linesep): |
16 | 51 | linesep, line = os.linesep, line[: -len(os.linesep)] |
17 | | - globals = {"_": line, **modules} |
18 | | - try: |
19 | | - value = eval(expression, globals) |
20 | | - except NameError as name_error: |
21 | | - match = re.match(r"name '([A-Za-z]+)'.*", str(name_error)) |
22 | | - if match: |
23 | | - module = match.group(1) |
24 | | - else: |
25 | | - raise name_error |
26 | | - try: |
27 | | - modules[module] = importlib.import_module(module) |
28 | | - globals = {"_": line, **modules} |
29 | | - except: |
30 | | - raise name_error |
31 | | - value = eval(expression, globals) |
| 52 | + globals = {LINE_IDENTIFIER: line, **modules} |
| 53 | + value = auto_import_eval(expression, globals) |
32 | 54 | if isinstance(value, str): |
33 | 55 | yield value + linesep |
34 | 56 | elif isinstance(value, bool): |
|
0 commit comments