Skip to content

Commit 4815bc2

Browse files
committed
[SPARK-6660][MLLIB] pythonToJava doesn't recognize object arrays
davies

Author: Xiangrui Meng <[email protected]>

Closes apache#5318 from mengxr/SPARK-6660 and squashes the following commits:

0f66ec2 [Xiangrui Meng] recognize object arrays
ad8c42f [Xiangrui Meng] add a test for SPARK-6660
1 parent 757b2e9 commit 4815bc2

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1113,7 +1113,10 @@ private[spark] object SerDe extends Serializable {
11131113
iter.flatMap { row =>
11141114
val obj = unpickle.loads(row)
11151115
if (batched) {
1116-
obj.asInstanceOf[JArrayList[_]].asScala
1116+
obj match {
1117+
case list: JArrayList[_] => list.asScala
1118+
case arr: Array[_] => arr
1119+
}
11171120
} else {
11181121
Seq(obj)
11191122
}

python/pyspark/mllib/tests.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
else:
3737
import unittest
3838

39+
from pyspark.mllib.common import _to_java_object_rdd
3940
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
4041
DenseMatrix, Vectors, Matrices
4142
from pyspark.mllib.regression import LabeledPoint
@@ -641,6 +642,13 @@ def test_idf_model(self):
641642
idf = model.idf()
642643
self.assertEqual(len(idf), 11)
643644

645+
646+
class SerDeTest(PySparkTestCase):
647+
def test_to_java_object_rdd(self): # SPARK-6660
648+
data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L)
649+
self.assertEqual(_to_java_object_rdd(data).count(), 10)
650+
651+
644652
if __name__ == "__main__":
645653
if not _have_scipy:
646654
print "NOTE: Skipping SciPy tests as it does not seem to be installed"

0 commit comments

Comments
 (0)