Skip to content

Commit eac1897

Browse files
authored
[mypyc] Fix using package imported inside a function (#9782)
This fixes an issue where this code resulted in an unbound local `p` error: ``` def f() -> None: import p.submodule print(p.x) # Runtime error here ``` We now look up `p` from the global modules dictionary instead of trying to use an undefined local variable.
1 parent 98eee40 commit eac1897

File tree

4 files changed

+114
-11
lines changed

4 files changed

+114
-11
lines changed

mypyc/irbuild/expression.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
from mypyc.ir.func_ir import FUNC_CLASSMETHOD, FUNC_STATICMETHOD
2727
from mypyc.primitives.registry import CFunctionDescription, builtin_names
2828
from mypyc.primitives.generic_ops import iter_op
29-
from mypyc.primitives.misc_ops import new_slice_op, ellipsis_op, type_op
29+
from mypyc.primitives.misc_ops import new_slice_op, ellipsis_op, type_op, get_module_dict_op
3030
from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op
3131
from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op
32-
from mypyc.primitives.dict_ops import dict_new_op, dict_set_item_op
32+
from mypyc.primitives.dict_ops import dict_new_op, dict_set_item_op, dict_get_item_op
3333
from mypyc.primitives.set_ops import new_set_op, set_add_op, set_update_op
3434
from mypyc.primitives.str_ops import str_slice_op
3535
from mypyc.primitives.int_ops import int_comparison_op_mapping
@@ -85,8 +85,21 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value:
8585
expr.node.name),
8686
expr.node.line)
8787

88-
# TODO: Behavior currently only defined for Var and FuncDef node types.
89-
return builder.read(builder.get_assignment_target(expr), expr.line)
88+
# TODO: Behavior currently only defined for Var, FuncDef and MypyFile node types.
89+
if isinstance(expr.node, MypyFile):
90+
# Load reference to a module imported inside function from
91+
# the modules dictionary. It would be closer to Python
92+
# semantics to access modules imported inside functions
93+
# via local variables, but this is tricky since the mypy
94+
# AST doesn't include a Var node for the module. We
95+
# instead load the module separately on each access.
96+
mod_dict = builder.call_c(get_module_dict_op, [], expr.line)
97+
obj = builder.call_c(dict_get_item_op,
98+
[mod_dict, builder.load_static_unicode(expr.node.fullname)],
99+
expr.line)
100+
return obj
101+
else:
102+
return builder.read(builder.get_assignment_target(expr), expr.line)
90103

91104
return builder.load_global(expr)
92105

mypyc/irbuild/statement.py

+4
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ def transform_import(builder: IRBuilder, node: Import) -> None:
131131
# that mypy couldn't find, since it doesn't analyze module references
132132
# from those properly.
133133

134+
# TODO: Don't add local imports to the global namespace
135+
134136
# Miscompiling imports inside of functions, like below in import from.
135137
if as_name:
136138
name = as_name
@@ -140,8 +142,10 @@ def transform_import(builder: IRBuilder, node: Import) -> None:
140142

141143
# Python 3.7 has a nice 'PyImport_GetModule' function that we can't use :(
142144
mod_dict = builder.call_c(get_module_dict_op, [], node.line)
145+
# Get top-level module/package object.
143146
obj = builder.call_c(dict_get_item_op,
144147
[mod_dict, builder.load_static_unicode(base)], node.line)
148+
145149
builder.gen_method_call(
146150
globals, '__setitem__', [builder.load_static_unicode(name), obj],
147151
result_type=None, line=node.line)

mypyc/test-data/irbuild-basic.test

+51
Original file line numberDiff line numberDiff line change
@@ -3637,3 +3637,54 @@ L0:
36373637
c = r2
36383638
r3 = (c, b, a)
36393639
return r3
3640+
3641+
[case testLocalImportSubmodule]
3642+
def f() -> int:
3643+
import p.m
3644+
return p.x
3645+
[file p/__init__.py]
3646+
x = 1
3647+
[file p/m.py]
3648+
[out]
3649+
def f():
3650+
r0 :: dict
3651+
r1, r2 :: object
3652+
r3 :: bit
3653+
r4 :: str
3654+
r5 :: object
3655+
r6 :: dict
3656+
r7 :: str
3657+
r8 :: object
3658+
r9 :: str
3659+
r10 :: int32
3660+
r11 :: bit
3661+
r12 :: dict
3662+
r13 :: str
3663+
r14 :: object
3664+
r15 :: str
3665+
r16 :: object
3666+
r17 :: int
3667+
L0:
3668+
r0 = __main__.globals :: static
3669+
r1 = p.m :: module
3670+
r2 = load_address _Py_NoneStruct
3671+
r3 = r1 != r2
3672+
if r3 goto L2 else goto L1 :: bool
3673+
L1:
3674+
r4 = load_global CPyStatic_unicode_1 :: static ('p.m')
3675+
r5 = PyImport_Import(r4)
3676+
p.m = r5 :: module
3677+
L2:
3678+
r6 = PyImport_GetModuleDict()
3679+
r7 = load_global CPyStatic_unicode_2 :: static ('p')
3680+
r8 = CPyDict_GetItem(r6, r7)
3681+
r9 = load_global CPyStatic_unicode_2 :: static ('p')
3682+
r10 = CPyDict_SetItem(r0, r9, r8)
3683+
r11 = r10 >= 0 :: signed
3684+
r12 = PyImport_GetModuleDict()
3685+
r13 = load_global CPyStatic_unicode_2 :: static ('p')
3686+
r14 = CPyDict_GetItem(r12, r13)
3687+
r15 = load_global CPyStatic_unicode_3 :: static ('x')
3688+
r16 = CPyObject_GetAttr(r14, r15)
3689+
r17 = unbox(int, r16)
3690+
return r17

mypyc/test-data/run-imports.test

+42-7
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,45 @@ import testmodule
55

66
def f(x: int) -> int:
77
return testmodule.factorial(5)
8+
89
def g(x: int) -> int:
910
from welp import foo
1011
return foo(x)
12+
13+
def test_import_basics() -> None:
14+
assert f(5) == 120
15+
assert g(5) == 5
16+
17+
def test_import_submodule_within_function() -> None:
18+
import pkg.mod
19+
assert pkg.x == 1
20+
assert pkg.mod.y == 2
21+
22+
def test_import_as_submodule_within_function() -> None:
23+
import pkg.mod as mm
24+
assert mm.y == 2
25+
26+
# TODO: Don't add local imports to globals()
27+
#
28+
# def test_local_import_not_in_globals() -> None:
29+
# import nob
30+
# assert 'nob' not in globals()
31+
32+
def test_import_module_without_stub_in_function() -> None:
33+
# 'virtualenv' must not have a stub in typeshed for this test case
34+
import virtualenv # type: ignore
35+
# TODO: We shouldn't add local imports to globals()
36+
# assert 'virtualenv' not in globals()
37+
assert isinstance(virtualenv.__name__, str)
38+
39+
def test_import_as_module_without_stub_in_function() -> None:
40+
# 'virtualenv' must not have a stub in typeshed for this test case
41+
import virtualenv as vv # type: ignore
42+
assert 'virtualenv' not in globals()
43+
# TODO: We shouldn't add local imports to globals()
44+
# assert 'vv' not in globals()
45+
assert isinstance(vv.__name__, str)
46+
1147
[file testmodule.py]
1248
def factorial(x: int) -> int:
1349
if x == 0:
@@ -17,13 +53,12 @@ def factorial(x: int) -> int:
1753
[file welp.py]
1854
def foo(x: int) -> int:
1955
return x
20-
[file driver.py]
21-
from native import f, g
22-
print(f(5))
23-
print(g(5))
24-
[out]
25-
120
26-
5
56+
[file pkg/__init__.py]
57+
x = 1
58+
[file pkg/mod.py]
59+
y = 2
60+
[file nob.py]
61+
z = 3
2762

2863
[case testImportMissing]
2964
# The unchecked module is configured by the test harness to not be

0 commit comments

Comments
 (0)