Commit 1ef26e0

specialize methods/types for Java
1 parent df293ed commit 1ef26e0

9 files changed: +163 -11 lines changed

mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
Lines changed: 4 additions & 1 deletion

@@ -20,6 +20,8 @@ package org.apache.spark.ml
 import org.apache.spark.ml.param.{ParamMap, Params, ParamPair}
 import org.apache.spark.sql.SchemaRDD
 
+import scala.annotation.varargs
+
 /**
  * Abstract class for estimators that fits models to data.
  */
@@ -52,7 +54,8 @@ abstract class Estimator[M <: Model] extends Identifiable with Params with Pipel
    * @param otherParamPairs other parameters
    * @return fitted model
    */
-  def fit(
+  @varargs
+  def fit[T](
       dataset: SchemaRDD,
       firstParamPair: ParamPair[_],
       otherParamPairs: ParamPair[_]*): M = {
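
A note on the @varargs change: Scala compiles the ParamPair[_]* parameter to a single Seq argument, which is awkward to construct from Java, and the annotation makes the compiler also emit a Java-style varargs forwarder. A minimal sketch of the resulting Java call site, assuming the LogisticRegression estimator touched later in this commit and the param w(...) builders added in params.scala:

    // Hypothetical Java caller; assumes `dataset` is a SchemaRDD and that lr is the
    // LogisticRegression estimator from this commit. Each ParamPair is built with the
    // param's w(...) method and passed as Java varargs thanks to @varargs.
    LogisticRegression lr = new LogisticRegression();
    LogisticRegressionModel model = lr.fit(dataset,
        lr.regParam().w(0.1),
        lr.maxIter().w(50));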

mllib/src/main/scala/org/apache/spark/ml/example/BinaryClassificationEvaluator.scala
Lines changed: 0 additions & 1 deletion

@@ -18,7 +18,6 @@
 package org.apache.spark.ml.example
 
 import org.apache.spark.ml._
-import org.apache.spark.ml.api.param.HasMetricName
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
 import org.apache.spark.sql.SchemaRDD

mllib/src/main/scala/org/apache/spark/ml/example/CrossValidator.scala
Lines changed: 6 additions & 0 deletions

@@ -73,6 +73,12 @@ class CrossValidator extends Estimator[CrossValidatorModel] with Params
 
   private val f2jBLAS = new F2jBLAS
 
+  // Overwrite return type for Java users.
+  override def setEstimator(estimator: Estimator[_]): this.type = super.setEstimator(estimator)
+  override def setEstimatorParamMaps(estimatorParamMaps: Array[ParamMap]): this.type =
+    super.setEstimatorParamMaps(estimatorParamMaps)
+  override def setEvaluator(evaluator: Evaluator): this.type = super.setEvaluator(evaluator)
+
   val numFolds: Param[Int] = new Param(this, "numFolds", "number of folds for cross validation", 3)
 
   def setNumFolds(numFolds: Int): this.type = {
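
Why these overrides: a setter declared in a shared params trait with return type this.type compiles to a JVM method whose return type is the trait itself, so Java callers lose the concrete class and cannot keep chaining. Re-declaring the setters in CrossValidator narrows the return type that Java sees. A rough Java sketch under that assumption, with lr, eval, and lrParamMaps set up as in the Java suite added at the end of this commit:

    // With the overrides above, each setter returns CrossValidator rather than the
    // shared params trait, so the calls chain naturally from Java.
    CrossValidator cv = new CrossValidator()
        .setEstimator(lr)
        .setEstimatorParamMaps(lrParamMaps)
        .setEvaluator(eval)
        .setNumFolds(3);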

mllib/src/main/scala/org/apache/spark/ml/example/LogisticRegression.scala
Lines changed: 6 additions & 1 deletion

@@ -18,7 +18,6 @@
 package org.apache.spark.ml.example
 
 import org.apache.spark.ml._
-import org.apache.spark.ml.api.param._
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
 import org.apache.spark.mllib.linalg.{BLAS, Vector}
@@ -37,6 +36,12 @@ class LogisticRegression extends Estimator[LogisticRegressionModel]
   setRegParam(0.1)
   setMaxIter(100)
 
+  // Overwrite the return type of setters for Java users.
+  override def setRegParam(regParam: Double): this.type = super.setRegParam(regParam)
+  override def setMaxIter(maxIter: Int): this.type = super.setMaxIter(maxIter)
+  override def setLabelCol(labelCol: String): this.type = super.setLabelCol(labelCol)
+  override def setFeaturesCol(featuresCol: String): this.type = super.setFeaturesCol(featuresCol)
+
   override final val model: LogisticRegressionModelParams = new LogisticRegressionModelParams {}
 
   override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {

mllib/src/main/scala/org/apache/spark/ml/example/StandardScaler.scala
Lines changed: 4 additions & 2 deletions

@@ -18,8 +18,7 @@
 package org.apache.spark.ml.example
 
 import org.apache.spark.ml._
-import org.apache.spark.ml.api.param.HasOutputCol
-import org.apache.spark.ml.param.{ParamMap, Params, HasOutputCol, HasInputCol}
+import org.apache.spark.ml.param._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.sql.SchemaRDD
@@ -29,6 +28,9 @@ import org.apache.spark.sql.catalyst.expressions.Row
 
 class StandardScaler extends Transformer with Params with HasInputCol with HasOutputCol {
 
+  override def setInputCol(inputCol: String): this.type = super.setInputCol(inputCol)
+  override def setOutputCol(outputCol: String): this.type = super.setOutputCol(outputCol)
+
   override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
     import dataset.sqlContext._
     val map = this.paramMap ++ paramMap

mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Lines changed: 30 additions & 2 deletions

@@ -22,6 +22,8 @@ import java.lang.reflect.Modifier
 import scala.collection.mutable
 import scala.language.implicitConversions
 
+import org.apache.spark.ml.Identifiable
+
 /**
  * A param with self-contained documentation and optionally default value.
  *
@@ -30,7 +32,7 @@ import scala.language.implicitConversions
  * @param doc documentation
  * @tparam T param value type
 */
-class Param[T] private (
+class Param[T] private[param] (
    val parent: Params,
    val name: String,
    val doc: String,
@@ -75,6 +77,16 @@
   }
 }
 
+class DoubleParam(parent: Params, name: String, doc: String, default: Option[Double] = None)
+  extends Param[Double](parent, name, doc, default) {
+  override def w(value: Double): ParamPair[Double] = ParamPair(this, value)
+}
+
+class IntParam(parent: Params, name: String, doc: String, default: Option[Int] = None)
+  extends Param[Int](parent, name, doc, default) {
+  override def w(value: Int): ParamPair[Int] = ParamPair(this, value)
+}
+
 /**
  * A param amd its value.
 */
@@ -199,7 +211,7 @@ class ParamMap private[ml] (
   /**
    * Filter this param map for the given parent.
    */
-  def filter(parent: Identifiable): ParamMap = {
+  def filter(parent: Params): ParamMap = {
     val map = params.filterKeys(_.parent == parent)
     new ParamMap(map.asInstanceOf[mutable.Map[Param[Any], Any]])
   }
@@ -260,6 +272,22 @@ class ParamGridBuilder {
     this
   }
 
+  /**
+   * Specialize for Java users.
+   */
+  def addMulti(param: DoubleParam, values: Array[Double]): this.type = {
+    paramGrid.put(param, values)
+    this
+  }
+
+  /**
+   * Specialize for Java users.
+   */
+  def addMulti(param: IntParam, values: Array[Int]): this.type = {
+    paramGrid.put(param, values)
+    this
+  }
+
   def build(): Array[ParamMap] = {
     var paramSets = Array(new ParamMap)
     paramGrid.foreach { case (param, values) =>
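
The DoubleParam/IntParam subclasses and the addMulti overloads let Java code pass primitive arrays (double[], int[]) directly instead of going through the generic Array[T] signature, and their w(...) builders return ParamPair values typed on the unboxed primitives. A short sketch of the intended Java usage, assuming an estimator lr that exposes the regParam and maxIter params declared in shared.scala below:

    // Sketch only: the primitive arrays resolve to the specialized addMulti overloads.
    ParamMap[] grid = new ParamGridBuilder()
        .addMulti(lr.regParam(), new double[] {0.01, 0.1, 1.0})  // DoubleParam overload
        .addMulti(lr.maxIter(), new int[] {10, 50})              // IntParam overload
        .build();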

mllib/src/main/scala/org/apache/spark/ml/param/shared.scala
Lines changed: 2 additions & 4 deletions

@@ -17,11 +17,9 @@
 
 package org.apache.spark.ml.param
 
-import org.apache.spark.ml.Params
-
 trait HasRegParam extends Params {
 
-  val regParam: Param[Double] = new Param(this, "regParam", "regularization parameter")
+  val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")
 
   def setRegParam(regParam: Double): this.type = {
     set(this.regParam, regParam)
@@ -35,7 +33,7 @@ trait HasRegParam extends Params {
 
 trait HasMaxIter extends Params {
 
-  val maxIter: Param[Int] = new Param(this, "maxIter", "max number of iterations")
+  val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
 
   def setMaxIter(maxIter: Int): this.type = {
     set(this.maxIter, maxIter)

mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
Lines changed: 3 additions & 0 deletions

@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.regression
 
+import scala.beans.BeanInfo
+
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
 import org.apache.spark.mllib.util.NumericParser
 import org.apache.spark.SparkException
@@ -27,6 +29,7 @@ import org.apache.spark.SparkException
  * @param label Label for this data point.
  * @param features List of features for this data point.
  */
+@BeanInfo
 case class LabeledPoint(label: Double, features: Vector) {
   override def toString: String = {
     "(%s,%s)".format(label, features)
mllib/src/test/java/org/apache/spark/ml/example/JavaLogisticRegressionSuite.java (new file)
Lines changed: 108 additions & 0 deletions

@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.example;
+
+import java.io.Serializable;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.Pipeline;
+import org.apache.spark.ml.PipelineModel;
+import org.apache.spark.ml.PipelineStage;
+import org.apache.spark.ml.param.ParamGridBuilder;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import org.apache.spark.sql.api.java.Row;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+public class JavaLogisticRegressionSuite implements Serializable {
+
+  private transient JavaSparkContext jsc;
+  private transient JavaSQLContext jsql;
+  private transient JavaSchemaRDD dataset;
+
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
+    jsql = new JavaSQLContext(jsc);
+    JavaRDD<LabeledPoint> points =
+      MLUtils.loadLibSVMFile(jsc.sc(), "../data/mllib/sample_binary_classification_data.txt")
+        .toJavaRDD();
+    dataset = jsql.applySchema(points, LabeledPoint.class);
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void logisticRegression() {
+    LogisticRegression lr = new LogisticRegression()
+      .setMaxIter(10)
+      .setRegParam(1.0);
+    lr.model().setThreshold(0.8);
+    // In Java we can access baseSchemaRDD, while in Scala we cannot.
+    LogisticRegressionModel model = lr.fit(dataset.baseSchemaRDD());
+    model.transform(dataset.baseSchemaRDD()).registerTempTable("prediction");
+    JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+    for (Row r: predictions.collect()) {
+      System.out.println(r);
+    }
+  }
+
+  @Test
+  public void logisticRegressionWithCrossValidation() {
+    LogisticRegression lr = new LogisticRegression();
+    ParamMap[] lrParamMaps = new ParamGridBuilder()
+      .addMulti(lr.regParam(), new double[] {0.1, 100.0})
+      .addMulti(lr.maxIter(), new int[] {0, 5})
+      .build();
+    BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator();
+    CrossValidator cv = new CrossValidator()
+      .setEstimator(lr)
+      .setEstimatorParamMaps(lrParamMaps)
+      .setEvaluator(eval)
+      .setNumFolds(3);
+    CrossValidatorModel bestModel = cv.fit(dataset.baseSchemaRDD());
+  }
+
+  @Test
+  public void logisticRegressionWithPipeline() {
+    StandardScaler scaler = new StandardScaler()
+      .setInputCol("features")
+      .setOutputCol("scaledFeatures");
+    LogisticRegression lr = new LogisticRegression()
+      .setFeaturesCol("scaledFeatures");
+    Pipeline pipeline = new Pipeline()
+      .setStages(new PipelineStage[] {scaler, lr});
+    PipelineModel model = pipeline.fit(dataset.baseSchemaRDD());
+    model.transform(dataset.baseSchemaRDD()).registerTempTable("prediction");
+    JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+    for (Row r: predictions.collect()) {
+      System.out.println(r);
+    }
+  }
+}
