[SPARK-6263][MLLIB] Python MLlib API missing items: Utils #5707

Closed · wants to merge 24 commits
@@ -75,6 +75,15 @@ private[python] class PythonMLLibAPI extends Serializable {
minPartitions: Int): JavaRDD[LabeledPoint] =
MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions)

/**
* Loads and serializes vectors saved with `RDD#saveAsTextFile`.
* @param jsc Java SparkContext
* @param path file or directory path in any Hadoop-supported file system URI
* @return serialized vectors in an RDD
*/
def loadVectors(jsc: JavaSparkContext, path: String): RDD[Vector] =
MLUtils.loadVectors(jsc.sc, path)

private def trainRegressionModel(
learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
data: JavaRDD[LabeledPoint],
43 changes: 43 additions & 0 deletions python/pyspark/mllib/tests.py
@@ -54,6 +54,7 @@
from pyspark.mllib.feature import IDF
from pyspark.mllib.feature import StandardScaler, ElementwiseProduct
from pyspark.mllib.util import LinearDataGenerator
from pyspark.mllib.util import MLUtils
from pyspark.serializers import PickleSerializer
from pyspark.streaming import StreamingContext
from pyspark.sql import SQLContext
@@ -1290,6 +1291,48 @@ def func(rdd):
self.assertTrue(mean_absolute_errors[1] - mean_absolute_errors[-1] > 2)


class MLUtilsTests(MLlibTestCase):
def test_append_bias(self):
data = [2.0, 2.0, 2.0]
ret = MLUtils.appendBias(data)
self.assertEqual(ret[3], 1.0)
self.assertEqual(type(ret), DenseVector)

def test_append_bias_with_vector(self):
data = Vectors.dense([2.0, 2.0, 2.0])
ret = MLUtils.appendBias(data)
self.assertEqual(ret[3], 1.0)
self.assertEqual(type(ret), DenseVector)

def test_append_bias_with_sp_vector(self):
data = Vectors.sparse(3, {0: 2.0, 2: 2.0})
expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0})
# Returned value must be SparseVector
ret = MLUtils.appendBias(data)
self.assertEqual(ret, expected)
self.assertEqual(type(ret), SparseVector)

def test_load_vectors(self):
import shutil
data = [
[1.0, 2.0, 3.0],
[1.0, 2.0, 3.0]
]
temp_dir = tempfile.mkdtemp()
load_vectors_path = os.path.join(temp_dir, "test_load_vectors")
try:
self.sc.parallelize(data).saveAsTextFile(load_vectors_path)
ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path)
ret = ret_rdd.collect()
Member: Order of collect() is not guaranteed, so please sort "ret" and "data" and then compare to make the test robust.
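A minimal sketch of that order-independent comparison, reusing the test's ret_rdd and data and sorting on plain lists (a DenseVector itself is not reliably orderable):

    # collect() gives no ordering guarantee, so sort both sides before comparing
    ret = sorted(ret_rdd.collect(), key=lambda v: v.toArray().tolist())
    self.assertEqual([v.toArray().tolist() for v in ret], sorted(data))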

Member: Oops, I guess this didn't matter; I didn't notice the vectors were identical. Fine to keep it sorted though.

self.assertEqual(len(ret), 2)
self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0]))
self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0]))
except:
self.fail()
Contributor Author: The saveAsTextFile output needs to be removed on each run, so that every run of this test verifies freshly exported vectors.

Member: Sorry, I misunderstood; please ignore my comment.

finally:
shutil.rmtree(load_vectors_path)
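One further hedged tidy-up, not part of this PR: since tempfile.mkdtemp creates temp_dir itself, removing the parent rather than just the subdirectory would avoid leaving an empty directory behind:

    finally:
        shutil.rmtree(temp_dir)  # also removes load_vectors_path beneath it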


if __name__ == "__main__":
if not _have_scipy:
print("NOTE: Skipping SciPy tests as it does not seem to be installed")
22 changes: 22 additions & 0 deletions python/pyspark/mllib/util.py
@@ -169,6 +169,28 @@ def loadLabeledPoints(sc, path, minPartitions=None):
minPartitions = minPartitions or min(sc.defaultParallelism, 2)
return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)

@staticmethod
def appendBias(data):
Member: Since the Scala version only operates on individual vectors, this one should not be a wrapper; it should do everything in Python. The reason is that callMLlibFunc requires the SparkContext and needs to operate on the driver. But since appendBias operates per-row, it needs to be called on workers.

Also, please add doc. Feel free to copy from the Scala doc.
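The driver/worker distinction matters because a pure-Python appendBias can be shipped to executors inside a closure. A minimal sketch of per-record use, assuming rdd is an existing RDD of feature vectors:

    from pyspark.mllib.util import MLUtils

    # Runs per record on the workers; no SparkContext is needed inside
    # the closure, which a callMLlibFunc-based wrapper would require.
    with_bias = rdd.map(MLUtils.appendBias)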

"""
Returns a new vector with `1.0` (bias) appended to
the end of the input vector.
"""
vec = _convert_to_vector(data)
Member: Will this work if "data" is a Vector type? _convert_to_vector will leave it as a Vector type.
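As written, both branches do accept Vector inputs: SparseVector is matched explicitly, and DenseVector supports toArray(), so the else branch covers it. A quick sketch mirroring the two appendBias tests in tests.py above:

    from pyspark.mllib.linalg import Vectors

    MLUtils.appendBias(Vectors.dense([2.0, 2.0, 2.0]))
    # -> DenseVector([2.0, 2.0, 2.0, 1.0])
    MLUtils.appendBias(Vectors.sparse(3, {0: 2.0, 2: 2.0}))
    # -> SparseVector(4, {0: 2.0, 2: 2.0, 3: 1.0})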

if isinstance(vec, SparseVector):
newIndices = np.append(vec.indices, len(vec))
newValues = np.append(vec.values, 1.0)
return SparseVector(len(vec) + 1, newIndices, newValues)
else:
return _convert_to_vector(np.append(vec.toArray(), 1.0))

@staticmethod
def loadVectors(sc, path):
Member: Please add doc. Feel free to copy from the Scala doc.

"""
Loads vectors saved using `RDD[Vector].saveAsTextFile`
with the default number of partitions.
"""
return callMLlibFunc("loadVectors", sc, path)
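A round-trip usage sketch, assuming an active SparkContext sc and a hypothetical output path that does not already exist (saveAsTextFile fails otherwise):

    from pyspark.mllib.linalg import Vectors
    from pyspark.mllib.util import MLUtils

    vecs = sc.parallelize([Vectors.dense([1.0, 2.0, 3.0]),
                           Vectors.dense([4.0, 5.0, 6.0])])
    vecs.saveAsTextFile("/tmp/vecs")              # one parseable vector per line
    loaded = MLUtils.loadVectors(sc, "/tmp/vecs")
    print(loaded.collect())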


class Saveable(object):
"""