Skip to content

Commit 629a1ce

Browse files
SPARK-3278: added isotonic regression for weighted data; added tests for the Java API.
1 parent 05d9048 commit 629a1ce

File tree

5 files changed

+345
-63
lines changed

5 files changed

+345
-63
lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala

Lines changed: 56 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,33 @@
1717

1818
package org.apache.spark.mllib.regression
1919

20-
import org.apache.spark.mllib.linalg.Vector
20+
import org.apache.spark.mllib.linalg.{Vectors, Vector}
21+
import org.apache.spark.mllib.regression.MonotonicityConstraint.Enum.MonotonicityConstraint
2122
import org.apache.spark.rdd.RDD
2223

23-
sealed trait MonotonicityConstraint {
24-
def holds(current: LabeledPoint, next: LabeledPoint): Boolean
25-
}
24+
object MonotonicityConstraint {
2625

27-
case object Isotonic extends MonotonicityConstraint {
28-
override def holds(current: LabeledPoint, next: LabeledPoint): Boolean = {
29-
current.label <= next.label
30-
}
31-
}
32-
case object Antitonic extends MonotonicityConstraint {
33-
override def holds(current: LabeledPoint, next: LabeledPoint): Boolean = {
34-
current.label >= next.label
26+
object Enum {
27+
28+
sealed trait MonotonicityConstraint {
29+
private[regression] def holds(current: WeightedLabeledPoint, next: WeightedLabeledPoint): Boolean
30+
}
31+
32+
case object Isotonic extends MonotonicityConstraint {
33+
override def holds(current: WeightedLabeledPoint, next: WeightedLabeledPoint): Boolean = {
34+
current.label <= next.label
35+
}
36+
}
37+
38+
case object Antitonic extends MonotonicityConstraint {
39+
override def holds(current: WeightedLabeledPoint, next: WeightedLabeledPoint): Boolean = {
40+
current.label >= next.label
41+
}
42+
}
3543
}
44+
45+
val Isotonic = Enum.Isotonic
46+
val Antitonic = Enum.Antitonic
3647
}
3748

3849
/**
@@ -41,9 +52,10 @@ case object Antitonic extends MonotonicityConstraint {
4152
* @param predictions Predictions computed for every input feature value.
4253
*/
4354
class IsotonicRegressionModel(
44-
val predictions: Seq[LabeledPoint],
55+
val predictions: Seq[WeightedLabeledPoint],
4556
val monotonicityConstraint: MonotonicityConstraint)
4657
extends RegressionModel {
58+
4759
override def predict(testData: RDD[Vector]): RDD[Double] =
4860
testData.map(predict)
4961

@@ -60,7 +72,7 @@ trait IsotonicRegressionAlgorithm
6072
extends Serializable {
6173

6274
protected def createModel(
63-
weights: Seq[LabeledPoint],
75+
weights: Seq[WeightedLabeledPoint],
6476
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel
6577

6678
/**
@@ -70,47 +82,47 @@ trait IsotonicRegressionAlgorithm
7082
* @return model
7183
*/
7284
def run(
73-
input: RDD[LabeledPoint],
85+
input: RDD[WeightedLabeledPoint],
7486
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel
7587

7688
/**
7789
* Run algorithm to obtain isotonic regression model
7890
* @param input data
79-
* @param initialWeights weights
8091
* @param monotonicityConstraint specifies whether the fitted sequence must be increasing (isotonic) or decreasing (antitonic)
92+
* @param weights initial per-point weights used by the algorithm
8193
* @return isotonic regression model
8294
*/
8395
def run(
84-
input: RDD[LabeledPoint],
85-
initialWeights: Vector,
86-
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel
96+
input: RDD[WeightedLabeledPoint],
97+
monotonicityConstraint: MonotonicityConstraint,
98+
weights: Vector): IsotonicRegressionModel
8799
}
88100

89101
class PoolAdjacentViolators extends IsotonicRegressionAlgorithm {
90102

91103
override def run(
92-
input: RDD[LabeledPoint],
104+
input: RDD[WeightedLabeledPoint],
93105
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel = {
94106
createModel(
95-
parallelPoolAdjacentViolators(input, monotonicityConstraint),
107+
parallelPoolAdjacentViolators(input, monotonicityConstraint, Vectors.dense(Array(0d))),
96108
monotonicityConstraint)
97109
}
98110

99111
override def run(
100-
input: RDD[LabeledPoint],
101-
initialWeights: Vector,
102-
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel = {
103-
???
112+
input: RDD[WeightedLabeledPoint],
113+
monotonicityConstraint: MonotonicityConstraint,
114+
weights: Vector): IsotonicRegressionModel = {
115+
createModel(
116+
parallelPoolAdjacentViolators(input, monotonicityConstraint, weights),
117+
monotonicityConstraint)
104118
}
105119

106120
override protected def createModel(
107-
weights: Seq[LabeledPoint],
121+
predictions: Seq[WeightedLabeledPoint],
108122
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel = {
109-
new IsotonicRegressionModel(weights, monotonicityConstraint)
123+
new IsotonicRegressionModel(predictions, monotonicityConstraint)
110124
}
111125

112-
113-
114126
/**
115127
* Performs a pool adjacent violators algorithm (PAVA)
116128
* Uses approach with single processing of data where violators in previously processed
@@ -123,18 +135,18 @@ class PoolAdjacentViolators extends IsotonicRegressionAlgorithm {
123135
* @return result
124136
*/
125137
private def poolAdjacentViolators(
126-
in: Array[LabeledPoint],
127-
monotonicityConstraint: MonotonicityConstraint): Array[LabeledPoint] = {
138+
in: Array[WeightedLabeledPoint],
139+
monotonicityConstraint: MonotonicityConstraint): Array[WeightedLabeledPoint] = {
128140

129141
//Pools sub array within given bounds assigning weighted average value to all elements
130-
def pool(in: Array[LabeledPoint], start: Int, end: Int): Unit = {
142+
def pool(in: Array[WeightedLabeledPoint], start: Int, end: Int): Unit = {
131143
val poolSubArray = in.slice(start, end + 1)
132144

133-
val weightedSum = poolSubArray.map(_.label).sum
134-
val weight = poolSubArray.length
145+
val weightedSum = poolSubArray.map(lp => lp.label * lp.weight).sum
146+
val weight = poolSubArray.map(_.weight).sum
135147

136148
for(i <- start to end) {
137-
in(i) = LabeledPoint(weightedSum / weight, in(i).features)
149+
in(i) = WeightedLabeledPoint(weightedSum / weight, in(i).features, in(i).weight)
138150
}
139151
}
140152

@@ -175,8 +187,9 @@ class PoolAdjacentViolators extends IsotonicRegressionAlgorithm {
175187
* @return result
176188
*/
177189
private def parallelPoolAdjacentViolators(
178-
testData: RDD[LabeledPoint],
179-
monotonicityConstraint: MonotonicityConstraint): Seq[LabeledPoint] = {
190+
testData: RDD[WeightedLabeledPoint],
191+
monotonicityConstraint: MonotonicityConstraint,
192+
weights: Vector): Seq[WeightedLabeledPoint] = {
180193

181194
poolAdjacentViolators(
182195
testData
@@ -200,14 +213,14 @@ object IsotonicRegression {
200213
*
201214
* @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
202215
* matrix A as well as the corresponding right hand side label y
203-
* @param initialWeights Initial set of weights to be used. Array should be equal in size to
216+
* @param weights Initial set of weights to be used. Array should be equal in size to
204217
* the number of features in the data.
205218
*/
206219
def train(
207-
input: RDD[LabeledPoint],
208-
initialWeights: Vector,
209-
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel = {
210-
new PoolAdjacentViolators().run(input, initialWeights, monotonicityConstraint)
220+
input: RDD[WeightedLabeledPoint],
221+
monotonicityConstraint: MonotonicityConstraint,
222+
weights: Vector): IsotonicRegressionModel = {
223+
new PoolAdjacentViolators().run(input, monotonicityConstraint, weights)
211224
}
212225

213226
/**
@@ -219,7 +232,7 @@ object IsotonicRegression {
219232
* matrix A as well as the corresponding right hand side label y
220233
*/
221234
def train(
222-
input: RDD[LabeledPoint],
235+
input: RDD[WeightedLabeledPoint],
223236
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel = {
224237
new PoolAdjacentViolators().run(input, monotonicityConstraint)
225238
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.regression
19+
20+
import org.apache.spark.mllib.linalg.Vector
21+
import org.apache.spark.rdd.RDD
22+
23+
import scala.beans.BeanInfo
24+
25+
object WeightedLabeledPointConversions {
26+
implicit def labeledPointToWeightedLabeledPoint(
27+
labeledPoint: LabeledPoint): WeightedLabeledPoint = {
28+
WeightedLabeledPoint(labeledPoint.label, labeledPoint.features, 1)
29+
}
30+
31+
implicit def labeledPointRDDToWeightedLabeledPointRDD(
32+
rdd: RDD[LabeledPoint]): RDD[WeightedLabeledPoint] = {
33+
rdd.map(lp => WeightedLabeledPoint(lp.label, lp.features, 1))
34+
}
35+
}
36+
37+
/**
38+
* Labeled point with weight
39+
*/
40+
@BeanInfo
41+
case class WeightedLabeledPoint(label: Double, features: Vector, weight: Double)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.util
19+
20+
import org.apache.spark.mllib.linalg.Vectors
21+
import org.apache.spark.mllib.regression.WeightedLabeledPointConversions._
22+
import org.apache.spark.mllib.regression.{LabeledPoint, WeightedLabeledPoint}
23+
24+
import scala.collection.JavaConversions._
25+
26+
object IsotonicDataGenerator {
27+
28+
/**
29+
* Return a Java List of ordered labeled points
30+
* @param labels list of labels for the data points
31+
* @return Java List of input.
32+
*/
33+
def generateIsotonicInputAsList(labels: Array[Double]): java.util.List[WeightedLabeledPoint] = {
34+
seqAsJavaList(generateIsotonicInput(wrapDoubleArray(labels):_*))
35+
}
36+
37+
/**
38+
* Return an ordered sequence of labeled data points with default weights
39+
* @param labels list of labels for the data points
40+
* @return sequence of data points
41+
*/
42+
def generateIsotonicInput(labels: Double*): Seq[WeightedLabeledPoint] = {
43+
labels.zip(1 to labels.size)
44+
.map(point => labeledPointToWeightedLabeledPoint(LabeledPoint(point._1, Vectors.dense(point._2))))
45+
}
46+
47+
/**
48+
* Return an ordered sequence of labeled weighted data points
49+
* @param labels list of labels for the data points
50+
* @param weights list of weights for the data points
51+
* @return sequence of data points
52+
*/
53+
def generateWeightedIsotonicInput(labels: Seq[Double], weights: Seq[Double]): Seq[WeightedLabeledPoint] = {
54+
labels.zip(1 to labels.size).zip(weights)
55+
.map(point => WeightedLabeledPoint(point._1._1, Vectors.dense(point._1._2), point._2))
56+
}
57+
}

0 commit comments

Comments
 (0)