diff --git a/lineapy/transformer/node_transformer.py b/lineapy/transformer/node_transformer.py index 691b5253d..843704657 100644 --- a/lineapy/transformer/node_transformer.py +++ b/lineapy/transformer/node_transformer.py @@ -569,21 +569,11 @@ def visit_Subscript(self, node: ast.Subscript) -> CallNode: args = [self.visit(node.value)] index = node.slice args.append(self.visit(index)) - if isinstance(node.ctx, ast.Load): - return self.tracer.call( - self.tracer.lookup_node(GET_ITEM), - self.get_source(node), - *args, - ) - elif isinstance(node.ctx, ast.Del): - raise NotImplementedError( - "Subscript with ctx=ast.Del() not supported." - ) - else: - raise ValueError( - "Subscript with ctx=ast.Store() should have been handled by" - " visit_Assign." - ) + return self.tracer.call( + self.tracer.lookup_node(GET_ITEM), + self.get_source(node), + *args, + ) def visit_Attribute(self, node: ast.Attribute) -> CallNode: diff --git a/lineapy/transformer/source_giver.py b/lineapy/transformer/source_giver.py index de13459b3..a7286cf29 100644 --- a/lineapy/transformer/source_giver.py +++ b/lineapy/transformer/source_giver.py @@ -2,9 +2,6 @@ class SourceGiver: - def __init__(self): - pass - def transform(self, nodes: ast.Module) -> None: """ This call should only happen once asttoken has run its magic @@ -26,31 +23,3 @@ def transform(self, nodes: ast.Module) -> None: node.end_col_offset = node.last_token.end[1] # type: ignore # if isinstance(node, ast.ListComp): node.col_offset = node.first_token.start[1] # type: ignore - - def transform_inhouse(self, nodes: ast.Module) -> None: - curr_line_start: int = -1 - curr_line_offset: int = -1 - prev_line_start: int = -1 - prev_line_offset: int = -1 - prev_node = None - node: ast.AST - # TODO check if the ast type is a Module instead of simply relying on mypy - for node in ast.walk(nodes): - if not hasattr(node, "lineno"): - continue - - curr_line_start = node.lineno - curr_line_offset = node.col_offset - if prev_node is not None: - prev_node.end_lineno = max( # type: ignore - prev_line_start, curr_line_start - 1 - ) - if prev_line_start == curr_line_start: - prev_node.end_col_offset = max( - prev_line_offset, curr_line_offset - 1 - ) - else: - prev_node.end_col_offset = -1 - prev_node = node - prev_line_start = curr_line_start - prev_line_offset = curr_line_offset diff --git a/requirements.txt b/requirements.txt index 0a3bac3eb..4495a6a1b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,7 @@ alabaster==0.7.12 altair==4.2.0 -anyio==3.4.0 +anyio==3.5.0 appnope==0.1.2 -argcomplete==2.0.0 argon2-cffi==21.3.0 argon2-cffi-bindings==21.2.0 astor==0.8.1 @@ -27,7 +26,7 @@ coveralls==3.3.1 cramjam==2.5.0 cycler==0.11.0 debugpy==1.5.1 -decorator==5.1.0 +decorator==5.1.1 defusedxml==0.7.1 distlib==0.3.4 docopt==0.6.2 @@ -39,16 +38,16 @@ fastparquet==0.7.2 filelock==3.4.2 flake8==4.0.1 fonttools==4.28.5 -fsspec==2021.11.1 +fsspec==2022.1.0 graphviz==0.19.1 greenlet==1.1.2 -identify==2.4.1 +identify==2.4.4 idna==3.3 imagesize==1.3.0 importlib-metadata==4.2.0 importlib-resources==5.4.0 iniconfig==1.1.1 -ipykernel==6.6.1 +ipykernel==6.7.0 ipython==7.31.0 ipython-genutils==0.2.0 isort==5.10.1 @@ -56,13 +55,13 @@ jedi==0.18.1 Jinja2==3.0.3 joblib==1.1.0 json5==0.9.6 -jsonschema==4.3.3 +jsonschema==4.4.0 jupyter-client==7.1.0 jupyter-core==4.9.1 -jupyter-server==1.13.1 -jupyterlab==3.2.5 +jupyter-server==1.13.3 +jupyterlab==3.2.8 jupyterlab-pygments==0.1.2 -jupyterlab-server==2.10.2 +jupyterlab-server==2.10.3 kiwisolver==1.3.2 MarkupSafe==2.0.1 matplotlib==3.5.1 @@ -70,23 +69,23 @@ matplotlib-inline==0.1.3 mccabe==0.6.1 mistune==0.8.4 mock==4.0.3 -mypy==0.910 +mypy==0.931 mypy-extensions==0.4.3 -nbclassic==0.3.4 -nbclient==0.5.9 +nbclassic==0.3.5 +nbclient==0.5.10 nbconvert==6.4.0 nbformat==5.1.3 nbval==0.9.6 nest-asyncio==1.5.4 networkx==2.6.3 nodeenv==1.6.0 -notebook==6.4.6 +notebook==6.4.7 numpy==1.21.5 packaging==21.3 pandas==1.3.5 pandocfilters==1.5.0 parso==0.8.3 -path==16.2.0 +path==16.3.0 path.py==12.5.0 pathspec==0.9.0 pdbpp==0.10.3 @@ -106,7 +105,7 @@ pycodestyle==2.8.0 pycparser==2.21 pydantic==1.9.0 pyflakes==2.4.0 -Pygments==2.11.1 +Pygments==2.11.2 PyOpenGL==3.1.5 pyparsing==3.0.6 pyrepl==0.9.0 @@ -123,7 +122,7 @@ pytz==2021.3 PyYAML==6.0 pyzmq==22.3.0 requests==2.27.1 -rich==10.16.2 +rich==11.0.0 scikit-learn==1.0.2 scipy==1.7.3 scour==0.38.2 @@ -154,10 +153,10 @@ tomli==1.2.3 toolz==0.11.2 tornado==6.1 traitlets==5.1.1 -typed-ast==1.4.3 -types-PyYAML==6.0.1 +typed-ast==1.5.1 +types-PyYAML==6.0.3 typing_extensions==4.0.1 -urllib3==1.26.7 +urllib3==1.26.8 virtualenv==20.13.0 wcwidth==0.2.5 webencodings==0.5.1 diff --git a/setup.py b/setup.py index 70950630c..4d3b6949a 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,7 @@ def version(path): "flake8", "fastparquet", "syrupy==1.4.5", - "mypy<0.920", + "mypy", "isort", "pytest", "matplotlib", diff --git a/tests/unit/transformer/test_node_transformer.py b/tests/unit/transformer/test_node_transformer.py new file mode 100644 index 000000000..14502940d --- /dev/null +++ b/tests/unit/transformer/test_node_transformer.py @@ -0,0 +1,211 @@ +# Setting up for mypy to ignore this file +# This file uses mocks which have dynamically defined functions. +# this does not sit well with mypy who needs to know what functions a class has +# type:ignore +import ast +import sys + +import asttokens +import pytest +from mock import MagicMock, patch + +from lineapy.transformer.node_transformer import NodeTransformer, transform +from lineapy.transformer.source_giver import SourceGiver + + +def _get_ast_node(code): + node = ast.parse(code) + if sys.version_info < (3, 8): # give me endlines! + asttokens.ASTTokens(code, parse=False, tree=node) + SourceGiver().transform(node) + + return node + + +@patch( + "lineapy.transformer.node_transformer.NodeTransformer", +) +def test_transform_fn(nt_mock: MagicMock): + """ + Test that the transform function calls the NodeTransformer + """ + mocked_tracer = MagicMock() + source_location = MagicMock() + transform("x = 1", source_location, mocked_tracer) + nt_mock.assert_called_once_with("x = 1", source_location, mocked_tracer) + mocked_tracer.db.commit.assert_called_once() + # TODO - test that source giver is called only for 3.7 and below + + +class TestNodeTransformer: + nt: NodeTransformer + + basic_tests_list = [ + ("a[2:3]", "Slice", 2), + ("a[2:3:2]", "Slice", 2), + ("a[2:3]", "Subscript", 2), + ("a.x", "Attribute", 1), + ("{'a': 1}", "Dict", 1), + ("a < b", "Compare", 1), + ("a not in []", "Compare", 3), + ("a + b", "BinOp", 1), + ("a or b", "BoolOp", 1), + ("[1,2]", "List", 1), + # set is xfailing right now + # ("{1,2}", "Set", 1), + # tuple eventually calls tracer.call but we've mocked out the whole thing + ("(1,2)", "Tuple", 0), + ("not True", "UnaryOp", 1), + ("assert True", "Assert", 1), + ("fn(*args)", "Expr", 1), + ("fn(*args)", "Call", 1), + ("fn(*args)", "Starred", 1), + ("fn(*args)", "Name", 1), + ("print(*'mystring')", "Starred", 1), + # calls tracer.trace_import but this list checks for tracer.call calls + ("import math", "Import", 0), + # calls tracer.trace_import but this list checks for tracer.call calls + # TODO - importfrom calls transform_utils which should really be mocked out and + # tested on their own, in true spirit of unit testing + ("from math import sqrt", "ImportFrom", 0), + ("a, b = (1,2)", "Assign", 2), + ("lambda x: x + 10", "Lambda", 1), + ] + + # TODO handle visit_Name as its own test + # TODO module + basic_test_ids = [ + "slice", + "slice_with_step", + "subscript", + "attribute", + "dict", + "compare", + "compare_notin", + "binop", + "boolop", + "list", + # "set", + "tuple", + "unaryop", + "assert", + "expr", + "call", + "starred", + "name", + "starred_str", + "import", + "import_from", + "assign_tuple", + "lambda", + ] + + if sys.version_info < (3, 8): + basic_tests_list += (("10", "Num", 0),) + # extslice does not call tracer.call but it contains a slice node. + # that along with subscript will result in two calls + basic_tests_list += (("a[:,3]", "ExtSlice", 2),) + basic_test_ids += ["num"] + basic_test_ids += ["extslice"] + else: + # this will break with 3.7 + basic_tests_list += (("10", "Constant", 0),) + basic_tests_list += (("a[:,3]", "Slice", 2),) + basic_test_ids += ["constant"] + basic_test_ids += ["slice_with_ext"] + + basic_tests = ("code, visitor, call_count", basic_tests_list) + + @pytest.fixture(autouse=True) + def before_everything(self): + nt = NodeTransformer( + "", MagicMock(), MagicMock() # SourceCodeLocation(0, 0, 0, 0) + ) + assert nt is not None + self.nt = nt + + def test_lambda_executes_as_expression(self): + self.nt._exec_expression = MagicMock() + self.nt._exec_statement = MagicMock() + + # this inits an ast.Module containing one expression whose value is a ast.lambda + test_node = _get_ast_node("lambda x: x + 10") + lambda_node = test_node.body[0].value + self.nt.generic_visit(lambda_node) + self.nt._exec_statement.assert_not_called() + self.nt._exec_expression.assert_called_once() + + def test_assign_executes(self): + test_node = _get_ast_node("a = 10") + self.nt.visit_Assign = MagicMock() + self.nt.visit(test_node.body[0]) + self.nt.visit_Assign.assert_called_once_with(test_node.body[0]) + + def test_assign_calls_tracer_assign(self): + self.nt.get_source = MagicMock() + test_node = _get_ast_node("a = 10") + tracer = self.nt.tracer + self.nt.visit(test_node.body[0]) + tracer.assign.assert_called_once_with("a", tracer.literal.return_value) + + @pytest.mark.parametrize( + "code", + ["a[3]=10", "a.x =10"], + ids=["assign_subscript", "assign_attribute"], + ) + def test_assign_subscript_attribute_calls_tracer_assign(self, code): + self.nt.get_source = MagicMock() + test_node = _get_ast_node(code) + tracer = self.nt.tracer + self.nt.visit(test_node.body[0]) + tracer.call.assert_called_once_with( + tracer.lookup_node.return_value, + self.nt.get_source.return_value, + tracer.lookup_node.return_value, + tracer.literal.return_value, + tracer.literal.return_value, + ) + + def test_visit_delete_executes(self): + test_node = _get_ast_node("del a") + with pytest.raises(NotImplementedError): + self.nt.visit_Delete(test_node.body[0]) + + self.nt.visit_Delete = MagicMock() + self.nt.visit(test_node.body[0]) + self.nt.visit_Delete.assert_called_once_with(test_node.body[0]) + + @pytest.mark.parametrize( + "code", ["del a[3]", "del a.x"], ids=["delitem", "delattr"] + ) + def test_visit_delete_subscript_attribute_calls_tracer_call(self, code): + self.nt.get_source = MagicMock() + test_node = _get_ast_node(code) + tracer = self.nt.tracer + self.nt.visit(test_node.body[0]) + tracer.call.assert_called_once_with( + tracer.lookup_node.return_value, + self.nt.get_source.return_value, + tracer.lookup_node.return_value, + tracer.literal.return_value, + ) + + # catch all for any ast nodes that do not have if conditions and/or very little logic + + @pytest.mark.parametrize(*basic_tests, ids=basic_test_ids) + def test_code_visited_calls_tracer_call(self, code, visitor, call_count): + self.nt._get_code_from_node = MagicMock() + test_node = _get_ast_node(code) + self.nt.visit(test_node) + # doing this so that we can select which function in tracer gets called. + # might be overkill though so leaving it at this + tracer_fn = getattr(self.nt.tracer, "call") + assert tracer_fn.call_count == call_count + + @pytest.mark.parametrize(*basic_tests, ids=basic_test_ids) + def test_code_visits_right_visitor(self, code, visitor, call_count): + test_node = _get_ast_node(code) + self.nt.__setattr__("visit_" + visitor, MagicMock()) + nt_visitor = self.nt.__getattribute__("visit_" + visitor) + self.nt.visit(test_node) + nt_visitor.assert_called_once() diff --git a/tests/unit/test_source_giver.py b/tests/unit/transformer/test_source_giver.py similarity index 63% rename from tests/unit/test_source_giver.py rename to tests/unit/transformer/test_source_giver.py index 28d9a31cf..6e2d16500 100644 --- a/tests/unit/test_source_giver.py +++ b/tests/unit/transformer/test_source_giver.py @@ -6,6 +6,8 @@ from lineapy.transformer.source_giver import SourceGiver +# from astpretty import pprint + @pytest.mark.parametrize( "code,lineno", @@ -25,17 +27,18 @@ def test_source_giver_adds_end_lineno(code, lineno): if sys.version_info >= (3, 8): pytest.skip("SourceGiver not invoked for Python 3.8+") import asttokens - from astpretty import pprint tree = ast.parse(code) + # ensure that the end_lineno is not available and fetching it raises exceptions with pytest.raises(AttributeError): print(tree.body[0].end_lineno) + # now we invoke the SourceGiver and add end_linenos in 2 steps - first we run the tree thr asttokens asttokens.ASTTokens(code, parse=False, tree=tree) - pprint(tree) - print(tree.body[0].last_token) + # double check that the line numbers cooked up by asttokens are correct assert tree.body[0].last_token.end[0] == lineno + # and in step 2, run the tree thr SourceGiver and copy the asttokens's token values + # so that the tree looks like 3.8+ tree with all the end_linenos etc SourceGiver().transform(tree) assert tree.body[0].end_lineno == lineno - print(tree)