Skip to content

Commit 089bf86

Browse files
Removed MonotonicityConstraint, Isotonic and Antitonic constraints. Replaced by a simple boolean
1 parent c06f88c commit 089bf86

File tree

3 files changed

+48
-82
lines changed

3 files changed

+48
-82
lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala

Lines changed: 28 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -18,56 +18,17 @@
1818
package org.apache.spark.mllib.regression
1919

2020
import org.apache.spark.mllib.linalg.Vector
21-
import org.apache.spark.mllib.regression.MonotonicityConstraint.MonotonicityConstraint._
2221
import org.apache.spark.rdd.RDD
2322

24-
/**
25-
* Monotonicity constrains for monotone regression
26-
* Isotonic (increasing)
27-
* Antitonic (decreasing)
28-
*/
29-
object MonotonicityConstraint {
30-
31-
object MonotonicityConstraint {
32-
33-
sealed trait MonotonicityConstraint {
34-
private[regression] def holds(
35-
current: WeightedLabeledPoint,
36-
next: WeightedLabeledPoint): Boolean
37-
}
38-
39-
/**
40-
* Isotonic monotonicity constraint. Increasing sequence
41-
*/
42-
case object Isotonic extends MonotonicityConstraint {
43-
override def holds(current: WeightedLabeledPoint, next: WeightedLabeledPoint): Boolean = {
44-
current.label <= next.label
45-
}
46-
}
47-
48-
/**
49-
* Antitonic monotonicity constrain. Decreasing sequence
50-
*/
51-
case object Antitonic extends MonotonicityConstraint {
52-
override def holds(current: WeightedLabeledPoint, next: WeightedLabeledPoint): Boolean = {
53-
current.label >= next.label
54-
}
55-
}
56-
}
57-
58-
val Isotonic = MonotonicityConstraint.Isotonic
59-
val Antitonic = MonotonicityConstraint.Antitonic
60-
}
61-
6223
/**
6324
* Regression model for Isotonic regression
6425
*
6526
* @param predictions Weights computed for every feature.
66-
* @param monotonicityConstraint specifies if the sequence is increasing or decreasing
27+
* @param isotonic isotonic (increasing) or antitonic (decreasing) sequence
6728
*/
6829
class IsotonicRegressionModel(
6930
val predictions: Seq[WeightedLabeledPoint],
70-
val monotonicityConstraint: MonotonicityConstraint)
31+
val isotonic: Boolean)
7132
extends RegressionModel {
7233

7334
override def predict(testData: RDD[Vector]): RDD[Double] =
@@ -91,23 +52,23 @@ trait IsotonicRegressionAlgorithm
9152
*
9253
* @param predictions labels estimated using isotonic regression algorithm.
9354
* Used for predictions on new data points.
94-
* @param monotonicityConstraint isotonic or antitonic
55+
* @param isotonic isotonic (increasing) or antitonic (decreasing) sequence
9556
* @return isotonic regression model
9657
*/
9758
protected def createModel(
9859
predictions: Seq[WeightedLabeledPoint],
99-
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel
60+
isotonic: Boolean): IsotonicRegressionModel
10061

10162
/**
10263
* Run algorithm to obtain isotonic regression model
10364
*
10465
* @param input data
105-
* @param monotonicityConstraint ascending or descenting
66+
* @param isotonic isotonic (increasing) or antitonic (decreasing) sequence
10667
* @return isotonic regression model
10768
*/
10869
def run(
10970
input: RDD[WeightedLabeledPoint],
110-
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel
71+
isotonic: Boolean): IsotonicRegressionModel
11172
}
11273

11374
/**
@@ -118,16 +79,16 @@ class PoolAdjacentViolators private [mllib]
11879

11980
override def run(
12081
input: RDD[WeightedLabeledPoint],
121-
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel = {
82+
isotonic: Boolean): IsotonicRegressionModel = {
12283
createModel(
123-
parallelPoolAdjacentViolators(input, monotonicityConstraint),
124-
monotonicityConstraint)
84+
parallelPoolAdjacentViolators(input, isotonic),
85+
isotonic)
12586
}
12687

12788
override protected def createModel(
12889
predictions: Seq[WeightedLabeledPoint],
129-
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel = {
130-
new IsotonicRegressionModel(predictions, monotonicityConstraint)
90+
isotonic: Boolean): IsotonicRegressionModel = {
91+
new IsotonicRegressionModel(predictions, isotonic)
13192
}
13293

13394
/**
@@ -138,12 +99,12 @@ class PoolAdjacentViolators private [mllib]
13899
* Method in situ mutates input array
139100
*
140101
* @param in input data
141-
* @param monotonicityConstraint asc or desc
102+
* @param isotonic asc or desc
142103
* @return result
143104
*/
144105
private def poolAdjacentViolators(
145106
in: Array[WeightedLabeledPoint],
146-
monotonicityConstraint: MonotonicityConstraint): Array[WeightedLabeledPoint] = {
107+
isotonic: Boolean): Array[WeightedLabeledPoint] = {
147108

148109
// Pools sub array within given bounds assigning weighted average value to all elements
149110
def pool(in: Array[WeightedLabeledPoint], start: Int, end: Int): Unit = {
@@ -159,11 +120,17 @@ class PoolAdjacentViolators private [mllib]
159120

160121
var i = 0
161122

123+
val monotonicityConstrainter: (Double, Double) => Boolean = (x, y) => if(isotonic) {
124+
x <= y
125+
} else {
126+
x >= y
127+
}
128+
162129
while(i < in.length) {
163130
var j = i
164131

165132
// Find monotonicity violating sequence, if any
166-
while(j < in.length - 1 && !monotonicityConstraint.holds(in(j), in(j + 1))) {
133+
while(j < in.length - 1 && !monotonicityConstrainter(in(j).label, in(j + 1).label)) {
167134
j = j + 1
168135
}
169136

@@ -173,7 +140,7 @@ class PoolAdjacentViolators private [mllib]
173140
} else {
174141
// Otherwise pool the violating sequence
175142
// And check if pooling caused monotonicity violation in previously processed points
176-
while (i >= 0 && !monotonicityConstraint.holds(in(i), in(i + 1))) {
143+
while (i >= 0 && !monotonicityConstrainter(in(i).label, in(i + 1).label)) {
177144
pool(in, i, j)
178145
i = i - 1
179146
}
@@ -190,19 +157,19 @@ class PoolAdjacentViolators private [mllib]
190157
* Calls Pool adjacent violators on each partition and then again on the result
191158
*
192159
* @param testData input
193-
* @param monotonicityConstraint asc or desc
160+
* @param isotonic isotonic (increasing) or antitonic (decreasing) sequence
194161
* @return result
195162
*/
196163
private def parallelPoolAdjacentViolators(
197164
testData: RDD[WeightedLabeledPoint],
198-
monotonicityConstraint: MonotonicityConstraint): Seq[WeightedLabeledPoint] = {
165+
isotonic: Boolean): Seq[WeightedLabeledPoint] = {
199166

200167
poolAdjacentViolators(
201168
testData
202169
.sortBy(_.features.toArray.head)
203170
.cache()
204-
.mapPartitions(it => poolAdjacentViolators(it.toArray, monotonicityConstraint).toIterator)
205-
.collect(), monotonicityConstraint)
171+
.mapPartitions(it => poolAdjacentViolators(it.toArray, isotonic).toIterator)
172+
.collect(), isotonic)
206173
}
207174
}
208175

@@ -221,11 +188,11 @@ object IsotonicRegression {
221188
* Each point describes a row of the data
222189
* matrix A as well as the corresponding right hand side label y
223190
* and weight as number of measurements
224-
* @param monotonicityConstraint Isotonic (increasing) or Antitonic (decreasing) sequence
191+
* @param isotonic isotonic (increasing) or antitonic (decreasing) sequence
225192
*/
226193
def train(
227194
input: RDD[WeightedLabeledPoint],
228-
monotonicityConstraint: MonotonicityConstraint = Isotonic): IsotonicRegressionModel = {
229-
new PoolAdjacentViolators().run(input, monotonicityConstraint)
195+
isotonic: Boolean = true): IsotonicRegressionModel = {
196+
new PoolAdjacentViolators().run(input, isotonic)
230197
}
231198
}

mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ public void runIsotonicRegressionUsingConstructor() {
6262
new double[] {1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12})).cache();
6363

6464
IsotonicRegressionAlgorithm isotonicRegressionAlgorithm = new PoolAdjacentViolators();
65-
IsotonicRegressionModel model = isotonicRegressionAlgorithm.run(testRDD.rdd(), MonotonicityConstraint.Isotonic());
65+
IsotonicRegressionModel model = isotonicRegressionAlgorithm.run(testRDD.rdd(), true);
6666

6767
List<WeightedLabeledPoint> expected = IsotonicDataGenerator
6868
.generateIsotonicInputAsList(
@@ -77,7 +77,7 @@ public void runIsotonicRegressionUsingStaticMethod() {
7777
.generateIsotonicInputAsList(
7878
new double[] {1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12})).cache();
7979

80-
IsotonicRegressionModel model = IsotonicRegression.train(testRDD.rdd(), MonotonicityConstraint.Isotonic());
80+
IsotonicRegressionModel model = IsotonicRegression.train(testRDD.rdd(), true);
8181

8282
List<WeightedLabeledPoint> expected = IsotonicDataGenerator
8383
.generateIsotonicInputAsList(
@@ -92,7 +92,7 @@ public void testPredictJavaRDD() {
9292
.generateIsotonicInputAsList(
9393
new double[] {1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12})).cache();
9494

95-
IsotonicRegressionModel model = IsotonicRegression.train(testRDD.rdd(), MonotonicityConstraint.Isotonic());
95+
IsotonicRegressionModel model = IsotonicRegression.train(testRDD.rdd(), true);
9696

9797
JavaRDD<Vector> vectors = testRDD.map(new Function<WeightedLabeledPoint, Vector>() {
9898
@Override

mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
package org.apache.spark.mllib.regression
1919

2020
import org.apache.spark.mllib.linalg.Vectors
21-
import org.apache.spark.mllib.regression.MonotonicityConstraint.MonotonicityConstraint.{Antitonic, Isotonic}
2221
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
2322
import org.scalatest.{Matchers, FunSuite}
2423
import WeightedLabeledPointConversions._
@@ -37,7 +36,7 @@ class IsotonicRegressionSuite
3736
val testRDD = sc.parallelize(generateIsotonicInput(1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12, 14, 15, 17, 16, 17, 18, 19, 20)).cache()
3837

3938
val alg = new PoolAdjacentViolators
40-
val model = alg.run(testRDD, Isotonic)
39+
val model = alg.run(testRDD, true)
4140

4241
model.predictions should be(generateIsotonicInput(1, 2, 7d/3, 7d/3, 7d/3, 6, 7, 8, 10, 10, 10, 12, 14, 15, 16.5, 16.5, 17, 18, 19, 20))
4342
}
@@ -46,7 +45,7 @@ class IsotonicRegressionSuite
4645
val testRDD = sc.parallelize(List[WeightedLabeledPoint]()).cache()
4746

4847
val alg = new PoolAdjacentViolators
49-
val model = alg.run(testRDD, Isotonic)
48+
val model = alg.run(testRDD, true)
5049

5150
model.predictions should be(List())
5251
}
@@ -55,7 +54,7 @@ class IsotonicRegressionSuite
5554
val testRDD = sc.parallelize(generateIsotonicInput(1)).cache()
5655

5756
val alg = new PoolAdjacentViolators
58-
val model = alg.run(testRDD, Isotonic)
57+
val model = alg.run(testRDD, true)
5958

6059
model.predictions should be(generateIsotonicInput(1))
6160
}
@@ -64,7 +63,7 @@ class IsotonicRegressionSuite
6463
val testRDD = sc.parallelize(generateIsotonicInput(1, 2, 3, 4, 5)).cache()
6564

6665
val alg = new PoolAdjacentViolators
67-
val model = alg.run(testRDD, Isotonic)
66+
val model = alg.run(testRDD, true)
6867

6968
model.predictions should be(generateIsotonicInput(1, 2, 3, 4, 5))
7069
}
@@ -73,7 +72,7 @@ class IsotonicRegressionSuite
7372
val testRDD = sc.parallelize(generateIsotonicInput(5, 4, 3, 2, 1)).cache()
7473

7574
val alg = new PoolAdjacentViolators
76-
val model = alg.run(testRDD, Isotonic)
75+
val model = alg.run(testRDD, true)
7776

7877
model.predictions should be(generateIsotonicInput(3, 3, 3, 3, 3))
7978
}
@@ -82,7 +81,7 @@ class IsotonicRegressionSuite
8281
val testRDD = sc.parallelize(generateIsotonicInput(1, 2, 3, 4, 2)).cache()
8382

8483
val alg = new PoolAdjacentViolators
85-
val model = alg.run(testRDD, Isotonic)
84+
val model = alg.run(testRDD, true)
8685

8786
model.predictions should be(generateIsotonicInput(1, 2, 3, 3, 3))
8887
}
@@ -91,7 +90,7 @@ class IsotonicRegressionSuite
9190
val testRDD = sc.parallelize(generateIsotonicInput(4, 2, 3, 4, 5)).cache()
9291

9392
val alg = new PoolAdjacentViolators
94-
val model = alg.run(testRDD, Isotonic)
93+
val model = alg.run(testRDD, true)
9594

9695
model.predictions should be(generateIsotonicInput(3, 3, 3, 4, 5))
9796
}
@@ -100,7 +99,7 @@ class IsotonicRegressionSuite
10099
val testRDD = sc.parallelize(generateIsotonicInput(-1, -2, 0, 1, -1)).cache()
101100

102101
val alg = new PoolAdjacentViolators
103-
val model = alg.run(testRDD, Isotonic)
102+
val model = alg.run(testRDD, true)
104103

105104
model.predictions should be(generateIsotonicInput(-1.5, -1.5, 0, 0, 0))
106105
}
@@ -109,7 +108,7 @@ class IsotonicRegressionSuite
109108
val testRDD = sc.parallelize(generateIsotonicInput(1, 2, 3, 4, 5).reverse).cache()
110109

111110
val alg = new PoolAdjacentViolators
112-
val model = alg.run(testRDD, Isotonic)
111+
val model = alg.run(testRDD, true)
113112

114113
model.predictions should be(generateIsotonicInput(1, 2, 3, 4, 5))
115114
}
@@ -118,7 +117,7 @@ class IsotonicRegressionSuite
118117
val testRDD = sc.parallelize(generateWeightedIsotonicInput(Seq(1, 2, 3, 4, 2), Seq(1, 1, 1, 1, 2))).cache()
119118

120119
val alg = new PoolAdjacentViolators
121-
val model = alg.run(testRDD, Isotonic)
120+
val model = alg.run(testRDD, true)
122121

123122
model.predictions should be(generateWeightedIsotonicInput(Seq(1, 2, 2.75, 2.75,2.75), Seq(1, 1, 1, 1, 2)))
124123
}
@@ -127,7 +126,7 @@ class IsotonicRegressionSuite
127126
val testRDD = sc.parallelize(generateWeightedIsotonicInput(Seq(1, 2, 3, 2, 1), Seq(1, 1, 1, 0.1, 0.1))).cache()
128127

129128
val alg = new PoolAdjacentViolators
130-
val model = alg.run(testRDD, Isotonic)
129+
val model = alg.run(testRDD, true)
131130

132131
model.predictions.map(p => p.copy(label = round(p.label))) should be
133132
(generateWeightedIsotonicInput(Seq(1, 2, 3.3/1.2, 3.3/1.2, 3.3/1.2), Seq(1, 1, 1, 0.1, 0.1)))
@@ -137,7 +136,7 @@ class IsotonicRegressionSuite
137136
val testRDD = sc.parallelize(generateWeightedIsotonicInput(Seq(1, 2, 3, 2, 1), Seq(-1, 1, -3, 1, -5))).cache()
138137

139138
val alg = new PoolAdjacentViolators
140-
val model = alg.run(testRDD, Isotonic)
139+
val model = alg.run(testRDD, true)
141140

142141
model.predictions.map(p => p.copy(label = round(p.label))) should be
143142
(generateWeightedIsotonicInput(Seq(1, 10/6, 10/6, 10/6, 10/6), Seq(-1, 1, -3, 1, -5)))
@@ -147,7 +146,7 @@ class IsotonicRegressionSuite
147146
val testRDD = sc.parallelize(generateWeightedIsotonicInput(Seq(1, 2, 3, 2, 1), Seq(0, 0, 0, 1, 0))).cache()
148147

149148
val alg = new PoolAdjacentViolators
150-
val model = alg.run(testRDD, Isotonic)
149+
val model = alg.run(testRDD, true)
151150

152151
model.predictions should be(generateWeightedIsotonicInput(Seq(1, 2, 2, 2, 2), Seq(0, 0, 0, 1, 0)))
153152
}
@@ -156,7 +155,7 @@ class IsotonicRegressionSuite
156155
val testRDD = sc.parallelize(generateIsotonicInput(1, 2, 7, 1, 2)).cache()
157156

158157
val alg = new PoolAdjacentViolators
159-
val model = alg.run(testRDD, Isotonic)
158+
val model = alg.run(testRDD, true)
160159

161160
model.predict(Vectors.dense(0)) should be(1)
162161
model.predict(Vectors.dense(2)) should be(2)
@@ -168,7 +167,7 @@ class IsotonicRegressionSuite
168167
val testRDD = sc.parallelize(generateIsotonicInput(7, 5, 3, 5, 1)).cache()
169168

170169
val alg = new PoolAdjacentViolators
171-
val model = alg.run(testRDD, Antitonic)
170+
val model = alg.run(testRDD, false)
172171

173172
model.predict(Vectors.dense(0)) should be(7)
174173
model.predict(Vectors.dense(2)) should be(5)
@@ -183,7 +182,7 @@ class IsotonicRegressionSuite
183182
LabeledPoint(1, Vectors.dense(2)))).cache()
184183

185184
val alg = new PoolAdjacentViolators
186-
val model = alg.run(testRDD, Isotonic)
185+
val model = alg.run(testRDD, true)
187186

188187
model.predictions should be(generateIsotonicInput(1.5, 1.5))
189188
}
@@ -201,7 +200,7 @@ class IsotonicRegressionClusterSuite extends FunSuite with LocalClusterSparkCont
201200

202201
// If we serialize data directly in the task closure, the size of the serialized task would be
203202
// greater than 1MB and hence Spark would throw an error.
204-
val model = IsotonicRegression.train(points, Isotonic)
203+
val model = IsotonicRegression.train(points, true)
205204
val predictions = model.predict(points.map(_.features))
206205
}
207206
}

0 commit comments

Comments
 (0)