Skip to content

Commit 6221041

Browse files
committed
Fix a minor bug
1 parent 32fdf11 commit 6221041

File tree

3 files changed

+44
-8
lines changed

3 files changed

+44
-8
lines changed

dill/_dill.py

+35-5
Original file line numberDiff line numberDiff line change
@@ -1184,12 +1184,13 @@ def _repr_dict(obj):
11841184

11851185
@register(dict)
11861186
def save_module_dict(pickler, obj):
1187-
if is_dill(pickler, child=False) and obj == pickler._main.__dict__ and \
1187+
_is_dill = is_dill(pickler, child=False)
1188+
if _is_dill and obj == pickler._main.__dict__ and \
11881189
not (pickler._session and pickler._first_pass):
11891190
logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
11901191
pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
11911192
logger.trace(pickler, "# D1")
1192-
elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
1193+
elif (not _is_dill) and (obj == _main_module.__dict__):
11931194
logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
11941195
pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
11951196
logger.trace(pickler, "# D3")
@@ -1199,12 +1200,37 @@ def save_module_dict(pickler, obj):
11991200
logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
12001201
pickler.write(bytes('c%s\n__dict__\n' % obj['__name__'], 'UTF-8'))
12011202
logger.trace(pickler, "# D4")
1203+
elif _is_dill and id(obj) in pickler._globals_cache:
1204+
logger.trace(pickler, "D5: %s", _repr_dict(obj)) # obj
1205+
# This is a globals dictionary that was partially copied, but not fully saved.
1206+
# Save the dictionary again to ensure that everything is there.
1207+
globs_copy = pickler._globals_cache[id(obj)]
1208+
pickler.write(pickler.get(pickler.memo[id(globs_copy)][0]))
1209+
pickler._batch_setitems(iter(obj.items()))
1210+
del pickler._globals_cache[id(obj)]
1211+
pickler.memo[id(obj)] = (pickler.memo.pop(id(globs_copy))[0], obj)
1212+
logger.trace(pickler, "# D5")
12021213
else:
12031214
logger.trace(pickler, "D2: %s", _repr_dict(obj)) # obj
1204-
if is_dill(pickler, child=False) and pickler._session:
1215+
if _is_dill and pickler._session:
12051216
# we only care about session the first pass thru
12061217
pickler._first_pass = False
1207-
StockPickler.save_dict(pickler, obj)
1218+
1219+
from pickle import EMPTY_DICT, MARK, DICT, SETITEM
1220+
if pickler.bin:
1221+
pickler.write(EMPTY_DICT)
1222+
else: # proto 0 -- can't use EMPTY_DICT
1223+
pickler.write(MARK + DICT)
1224+
1225+
# StockPickler.save_dict(pickler, obj)
1226+
pickler.memoize(obj)
1227+
# add __name__ first
1228+
if '__name__' in obj:
1229+
pickler.save('__name__')
1230+
pickler.save(obj['__name__'])
1231+
pickler.write(SETITEM)
1232+
pickler._batch_setitems(obj.items())
1233+
12081234
logger.trace(pickler, "# D2")
12091235
return
12101236

@@ -1797,7 +1823,11 @@ def save_function(pickler, obj):
17971823
postproc_list = []
17981824

17991825
globs = None
1800-
if _recurse:
1826+
if id(obj.__globals__) in pickler.memo:
1827+
# It is possible that the globals dictionary itself is also being
1828+
# pickled directly.
1829+
globs = globs_copy = obj.__globals__
1830+
elif _recurse:
18011831
# recurse to get all globals referred to by obj
18021832
from .detect import globalvars
18031833
globs_copy = globalvars(obj, recurse=True, builtin=True)

dill/tests/_globals_dummy.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
x = 3
44

55
def h():
6-
print(x)
6+
return x
77

88
def g():
9-
h()
9+
return h()
1010

dill/tests/test_functions.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def test_code_object():
132132
except Exception as error:
133133
raise Exception("failed to construct code object with format version {}".format(version)) from error
134134

135+
135136
def test_shared_globals():
136137
import dill, _globals_dummy as f, sys
137138

@@ -140,17 +141,22 @@ def test_shared_globals():
140141
assert f.g.__globals__ is f.h.__globals__
141142
assert g.__globals__ is h.__globals__
142143
assert f.g.__globals__ is g.__globals__
144+
assert g() == h() == 3
143145

144146
del sys.modules['_globals_dummy']
145147

146148
g, h = dill.copy((f.g, f.h), recurse=recurse)
147149
assert f.g.__globals__ is f.h.__globals__
148150
assert g.__globals__ is h.__globals__
149151
assert f.g.__globals__ is not g.__globals__
152+
assert g() == h() == 3
153+
g1, g, g2 = dill.copy((f.__dict__, f.g, f.g.__globals__), recurse=recurse)
154+
assert g1 is g.__globals__
155+
assert g1 is g2
150156

151157
sys.modules['_globals_dummy'] = f
152158

153-
159+
154160
if __name__ == '__main__':
155161
test_functions()
156162
test_issue_510()

0 commit comments

Comments
 (0)