Skip to content

Commit 8e6ab49

Browse files
lasersonmarkhamstra
authored andcommitted
SPARK-1917: fix PySpark import of scipy.special functions
https://issues.apache.org/jira/browse/SPARK-1917 Author: Uri Laserson <[email protected]> Closes apache#866 from laserson/SPARK-1917 and squashes the following commits: d947e8c [Uri Laserson] Added test for scipy.special importing 1798bbd [Uri Laserson] SPARK-1917: fix PySpark import of scipy.special Conflicts: python/pyspark/tests.py
1 parent c69207f commit 8e6ab49

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

python/pyspark/cloudpickle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -933,7 +933,7 @@ def _change_cell_value(cell, newval):
933933
Note: These can never be renamed due to client compatibility issues"""
934934

935935
def _getobject(modname, attribute):
936-
mod = __import__(modname)
936+
mod = __import__(modname, fromlist=[attribute])
937937
return mod.__dict__[attribute]
938938

939939
def _generateImage(size, mode, str_rep):

python/pyspark/tests.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@
3333
from pyspark.java_gateway import SPARK_HOME
3434
from pyspark.serializers import read_int
3535

36+
_have_scipy = False
37+
try:
38+
import scipy.sparse
39+
_have_scipy = True
40+
except:
41+
# No SciPy, but that's okay, we'll skip those tests
42+
pass
43+
3644

3745
class PySparkTestCase(unittest.TestCase):
3846

@@ -234,5 +242,22 @@ def test_termination_sigterm(self):
234242
from signal import SIGTERM
235243
self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
236244

245+
246+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
247+
class SciPyTests(PySparkTestCase):
248+
"""General PySpark tests that depend on scipy """
249+
250+
def test_serialize(self):
251+
from scipy.special import gammaln
252+
x = range(1, 5)
253+
expected = map(gammaln, x)
254+
observed = self.sc.parallelize(x).map(gammaln).collect()
255+
self.assertEqual(expected, observed)
256+
257+
237258
if __name__ == "__main__":
259+
if not _have_scipy:
260+
print "NOTE: Skipping SciPy tests as it does not seem to be installed"
238261
unittest.main()
262+
if not _have_scipy:
263+
print "NOTE: SciPy tests were skipped as it does not seem to be installed"

0 commit comments

Comments
 (0)