Skip to content

Commit 3b8b90c

Browse files
Add custom logic for pickling dynamic imports.
Add test cases, special case Ellipsis and NotImplemented. Use custom logic in lieu of imp.find_module to properly follow subimports. For example sklearn.tree was spuriously treated as a dynamic module.
1 parent e47f29f commit 3b8b90c

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed

cloudpickle/cloudpickle.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
import operator
4646
import io
47+
import imp
4748
import pickle
4849
import struct
4950
import sys
@@ -134,8 +135,19 @@ def save_module(self, obj):
134135
"""
135136
Save a module as an import
136137
"""
138+
mod_name = obj.__name__
139+
# If module is successfully found then it is not a dynamically created module
140+
try:
141+
_find_module(mod_name)
142+
is_dynamic = False
143+
except ImportError:
144+
is_dynamic = True
145+
137146
self.modules.add(obj)
138-
self.save_reduce(subimport, (obj.__name__,), obj=obj)
147+
if is_dynamic:
148+
self.save_reduce(dynamic_subimport, (obj.__name__, vars(obj)), obj=obj)
149+
else:
150+
self.save_reduce(subimport, (obj.__name__,), obj=obj)
139151
dispatch[types.ModuleType] = save_module
140152

141153
def save_codeobject(self, obj):
@@ -313,7 +325,7 @@ def extract_func_data(self, func):
313325
return (code, f_globals, defaults, closure, dct, base_globals)
314326

315327
def save_builtin_function(self, obj):
316-
if obj.__module__ is "__builtin__":
328+
if obj.__module__ == "__builtin__":
317329
return self.save_global(obj)
318330
return self.save_function(obj)
319331
dispatch[types.BuiltinFunctionType] = save_builtin_function
@@ -584,11 +596,20 @@ def save_file(self, obj):
584596
self.save(retval)
585597
self.memoize(obj)
586598

599+
def save_ellipsis(self, obj):
600+
self.save_reduce(_gen_ellipsis, ())
601+
602+
def save_not_implemented(self, obj):
603+
self.save_reduce(_gen_not_implemented, ())
604+
587605
if PY3:
588606
dispatch[io.TextIOWrapper] = save_file
589607
else:
590608
dispatch[file] = save_file
591609

610+
dispatch[type(Ellipsis)] = save_ellipsis
611+
dispatch[type(NotImplemented)] = save_not_implemented
612+
592613
"""Special functions for Add-on libraries"""
593614
def inject_addons(self):
594615
"""Plug in system. Register additional pickling functions if modules already loaded"""
@@ -620,6 +641,12 @@ def subimport(name):
620641
return sys.modules[name]
621642

622643

644+
def dynamic_subimport(name, vars):
645+
mod = imp.new_module(name)
646+
mod.__dict__.update(vars)
647+
sys.modules[name] = mod
648+
return mod
649+
623650
# restores function attributes
624651
def _restore_attr(obj, attr):
625652
for key, val in attr.items():
@@ -663,6 +690,11 @@ def _genpartial(func, args, kwds):
663690
kwds = {}
664691
return partial(func, *args, **kwds)
665692

693+
def _gen_ellipsis():
694+
return Ellipsis
695+
696+
def _gen_not_implemented():
697+
return NotImplemented
666698

667699
def _fill_function(func, globals, defaults, dict):
668700
""" Fills in the rest of function data into the skeleton function object
@@ -698,6 +730,18 @@ def _make_skel_func(code, closures, base_globals = None):
698730
None, None, closure)
699731

700732

733+
def _find_module(mod_name):
734+
"""
735+
Iterate over each part instead of calling imp.find_module directly.
736+
This function is able to find submodules (e.g. sickit.tree)
737+
"""
738+
path = None
739+
for part in mod_name.split('.'):
740+
if path is not None:
741+
path = [path]
742+
file, path, description = imp.find_module(part, path)
743+
return file, path, description
744+
701745
"""Constructors for 3rd party libraries
702746
Note: These can never be renamed due to client compatibility issues"""
703747

tests/cloudpickle_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import division
2+
import imp
23
import unittest
34
import pytest
45
import pickle
56
import sys
67
import functools
78
import platform
9+
import textwrap
810

911
try:
1012
# try importing numpy and scipy. These are not hard dependencies and
@@ -252,6 +254,26 @@ def f(self, x):
252254
self.assertEqual(g.im_class.__name__, F.f.im_class.__name__)
253255
# self.assertEqual(g(F(), 1), 2) # still fails
254256

257+
def test_module(self):
258+
self.assertEqual(pickle, pickle_depickle(pickle))
259+
260+
def test_dynamic_module(self):
261+
mod = imp.new_module('mod')
262+
code = '''
263+
x = 1
264+
def f(y):
265+
return x + y
266+
'''
267+
exec(textwrap.dedent(code), mod.__dict__)
268+
mod2 = pickle_depickle(mod)
269+
self.assertEqual(mod.x, mod2.x)
270+
self.assertEqual(mod.f(5), mod2.f(5))
271+
272+
def test_Ellipsis(self):
273+
self.assertEqual(Ellipsis, pickle_depickle(Ellipsis))
274+
275+
def test_NotImplemented(self):
276+
self.assertEqual(NotImplemented, pickle_depickle(NotImplemented))
255277

256278
if __name__ == '__main__':
257279
unittest.main()

0 commit comments

Comments
 (0)