Skip to content

Commit a068771

Browse files
committed
first round of unit tests for node transformer.
1 parent 848efa2 commit a068771

File tree

2 files changed

+58
-4
lines changed

2 files changed

+58
-4
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#
2+
#
3+
# type:ignore
4+
import ast
5+
6+
import pytest
7+
from mock import MagicMock
8+
9+
from lineapy.transformer.node_transformer import NodeTransformer
10+
11+
12+
class TestNodeTransformer:
13+
nt: NodeTransformer
14+
15+
@pytest.fixture(autouse=True)
16+
def before_everything(self):
17+
nt = NodeTransformer(
18+
"", MagicMock(), MagicMock() # SourceCodeLocation(0, 0, 0, 0)
19+
)
20+
assert nt is not None
21+
self.nt = nt
22+
23+
def test_lambda_executes_as_expression(self):
24+
self.nt._exec_expression = MagicMock()
25+
self.nt._exec_statement = MagicMock()
26+
27+
# this inits an ast.Module containing one expression whose value is a ast.lambda
28+
test_node = ast.parse("lambda x: x + 10")
29+
lambda_node = test_node.body[0].value
30+
self.nt.generic_visit(lambda_node)
31+
self.nt._exec_statement.assert_not_called()
32+
self.nt._exec_expression.assert_called_once()
33+
34+
def test_assign_executes(self):
35+
test_node = ast.parse("a = 10")
36+
self.nt.visit_Assign = MagicMock()
37+
self.nt.visit(test_node.body[0])
38+
self.nt.visit_Assign.assert_called_once_with(test_node.body[0])
39+
40+
def test_assign_calls_tracer_assign(self):
41+
self.nt.get_source = MagicMock()
42+
test_node = ast.parse("a = 10")
43+
tracer = self.nt.tracer
44+
self.nt.visit(test_node.body[0])
45+
# self.nt.visit_Assign.assert_called_once_with(test_node.body[0])
46+
tracer.assign.assert_called_once_with(
47+
"a", self.nt.tracer.literal.return_value
48+
)

tests/unit/test_source_giver.py renamed to tests/unit/transformer/test_source_giver.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66

77
import asttokens
88
import pytest
9-
from astpretty import pprint
109

1110
from lineapy.transformer.source_giver import SourceGiver
1211

12+
# from astpretty import pprint
13+
1314

1415
@pytest.mark.parametrize(
1516
"code,lineno",
@@ -29,14 +30,19 @@ def test_source_giver_adds_end_lineno(code, lineno):
2930
if sys.version_info >= (3, 8):
3031
pytest.skip("SourceGiver not invoked for Python 3.8+")
3132
tree = ast.parse(code)
33+
# ensure that the end_lineno is not available and fetching it raises exceptions
3234
with pytest.raises(AttributeError):
3335
print(tree.body[0].end_lineno)
3436

37+
# now we invoke the SourceGiver and add end_linenos in 2 steps - first we run the tree thr asttokens
3538
asttokens.ASTTokens(code, parse=False, tree=tree)
36-
pprint(tree)
37-
print(tree.body[0].last_token)
39+
# pprint(tree)
40+
# print(tree.body[0].last_token)
41+
# double check that the line numbers cooked up by asttokens are correct
3842
assert tree.body[0].last_token.end[0] == lineno
3943

44+
# and in step 2, run the tree thr SourceGiver and copy the asttokens's token values
45+
# so that the tree looks like 3.8+ tree with all the end_linenos etc
4046
SourceGiver().transform(tree)
4147
assert tree.body[0].end_lineno == lineno
42-
print(tree)
48+
# print(tree)

0 commit comments

Comments
 (0)