Skip to content

Commit f2022fa

Browse files
MechCoder authored and mengxr committed
[SPARK-8265] [MLLIB] [PYSPARK] Add LinearDataGenerator to pyspark.mllib.utils
It is useful to generate linear data for easy testing of linear models and in general. Scala already has it. This is just a wrapper around the Scala code. Author: MechCoder <[email protected]> Closes apache#6715 from MechCoder/generate_linear_input and squashes the following commits: 6182884 [MechCoder] Minor changes 8bda047 [MechCoder] Minor style fixes 0f1053c [MechCoder] [SPARK-8265] Add LinearDataGenerator to pyspark.mllib.utils
1 parent 2b1111d commit f2022fa

File tree

3 files changed

+86
-3
lines changed

3 files changed

+86
-3
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ import org.apache.spark.mllib.tree.loss.Losses
5151
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel}
5252
import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
5353
import org.apache.spark.mllib.util.MLUtils
54+
import org.apache.spark.mllib.util.LinearDataGenerator
5455
import org.apache.spark.rdd.RDD
5556
import org.apache.spark.sql.DataFrame
5657
import org.apache.spark.storage.StorageLevel
@@ -972,7 +973,7 @@ private[python] class PythonMLLibAPI extends Serializable {
972973
/**
 * Java-friendly wrapper for kernel density estimation: fits a KernelDensity
 * on `sample` with the given bandwidth and evaluates it at `points`.
 */
def estimateKernelDensity(
    sample: JavaRDD[Double],
    bandwidth: Double, points: java.util.ArrayList[Double]): Array[Double] = {
  val density = new KernelDensity().setSample(sample).setBandwidth(bandwidth)
  density.estimate(points.asScala.toArray)
}
978979

@@ -991,6 +992,35 @@ private[python] class PythonMLLibAPI extends Serializable {
991992
List[AnyRef](model.clusterCenters, Vectors.dense(model.clusterWeights)).asJava
992993
}
993994

/**
 * Wrapper around the generateLinearInput method of LinearDataGenerator,
 * converting the Java lists supplied by Py4J into Scala arrays.
 */
def generateLinearInputWrapper(
    intercept: Double,
    weights: JList[Double],
    xMean: JList[Double],
    xVariance: JList[Double],
    nPoints: Int,
    seed: Int,
    eps: Double): Array[LabeledPoint] = {
  val weightsArr = weights.asScala.toArray
  val xMeanArr = xMean.asScala.toArray
  val xVarianceArr = xVariance.asScala.toArray
  LinearDataGenerator.generateLinearInput(
    intercept, weightsArr, xMeanArr, xVarianceArr, nPoints, seed, eps).toArray
}
1010+
1011+
/**
 * Wrapper around the generateLinearRDD method of LinearDataGenerator,
 * exposed for the Python API.
 */
def generateLinearRDDWrapper(
    sc: JavaSparkContext,
    nexamples: Int,
    nfeatures: Int,
    eps: Double,
    nparts: Int,
    intercept: Double): JavaRDD[LabeledPoint] = {
  LinearDataGenerator
    .generateLinearRDD(sc, nexamples, nfeatures, eps, nparts, intercept)
}
9941024
}
9951025

9961026
/**

python/pyspark/mllib/tests.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@
4949
from pyspark.mllib.stat import Statistics
5050
from pyspark.mllib.feature import Word2Vec
5151
from pyspark.mllib.feature import IDF
52-
from pyspark.mllib.feature import StandardScaler
53-
from pyspark.mllib.feature import ElementwiseProduct
52+
from pyspark.mllib.feature import StandardScaler, ElementwiseProduct
53+
from pyspark.mllib.util import LinearDataGenerator
5454
from pyspark.serializers import PickleSerializer
5555
from pyspark.streaming import StreamingContext
5656
from pyspark.sql import SQLContext
@@ -1019,6 +1019,24 @@ def collect(rdd):
10191019
self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]])
10201020

10211021

1022+
class LinearDataGeneratorTests(MLlibTestCase):
    """Tests for pyspark.mllib.util.LinearDataGenerator."""

    def test_dim(self):
        """Generated data should have the requested count and feature width."""
        local_points = LinearDataGenerator.generateLinearInput(
            intercept=0.0, weights=[0.0, 0.0, 0.0],
            xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33],
            nPoints=4, seed=0, eps=0.1)
        self.assertEqual(len(local_points), 4)
        for labeled_point in local_points:
            self.assertEqual(len(labeled_point.features), 3)

        rdd_points = LinearDataGenerator.generateLinearRDD(
            sc=sc, nexamples=6, nfeatures=2, eps=0.1,
            nParts=2, intercept=0.0).collect()
        self.assertEqual(len(rdd_points), 6)
        for labeled_point in rdd_points:
            self.assertEqual(len(labeled_point.features), 2)
1038+
1039+
10221040
if __name__ == "__main__":
10231041
if not _have_scipy:
10241042
print("NOTE: Skipping SciPy tests as it does not seem to be installed")

python/pyspark/mllib/util.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,41 @@ def load(cls, sc, path):
257257
return cls(java_model)
258258

259259

260+
class LinearDataGenerator(object):
    """Utils for generating linear data."""

    @staticmethod
    def generateLinearInput(intercept, weights, xMean, xVariance,
                            nPoints, seed, eps):
        """
        Generate labeled points for a linear model with gaussian noise.

        :param intercept: bias factor, the term c in X'w + c
        :param weights: feature vector, the term w in X'w + c
        :param xMean: Point around which the data X is centered.
        :param xVariance: Variance of the given data
        :param nPoints: Number of points to be generated
        :param seed: Random Seed
        :param eps: Used to scale the noise. If eps is set high,
                    the amount of gaussian noise added is more.
        :return: list of LabeledPoints of length nPoints
        """
        # Coerce elements to float so the JVM side receives proper Doubles.
        weights = [float(weight) for weight in weights]
        xMean = [float(mean) for mean in xMean]
        xVariance = [float(var) for var in xVariance]
        return list(callMLlibFunc(
            "generateLinearInputWrapper", float(intercept), weights, xMean,
            xVariance, int(nPoints), int(seed), float(eps)))

    @staticmethod
    def generateLinearRDD(sc, nexamples, nfeatures, eps,
                          nParts=2, intercept=0.0):
        """
        Generate an RDD of LabeledPoints.

        :param sc: SparkContext used to create the RDD.
        :param nexamples: Number of examples to generate.
        :param nfeatures: Number of features per example.
        :param eps: Used to scale the noise.
        :param nParts: Number of partitions of the RDD (default: 2).
        :param intercept: Bias factor (default: 0.0).
        :return: RDD of LabeledPoint.
        """
        return callMLlibFunc(
            "generateLinearRDDWrapper", sc, int(nexamples), int(nfeatures),
            float(eps), int(nParts), float(intercept))
293+
294+
260295
def _test():
261296
import doctest
262297
from pyspark.context import SparkContext

0 commit comments

Comments
 (0)