Skip to content

Commit 5a54ea4

Browse files
Merge pull request #2 from mengxr/isotonic-fix-java
fix java tests
2 parents e3c0e44 + 37ba24e commit 5a54ea4

File tree

2 files changed

+12
-28
lines changed

2 files changed

+12
-28
lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class IsotonicRegressionModel (
7575
* @return Predicted labels.
7676
*/
7777
def predict(testData: JavaDoubleRDD): JavaDoubleRDD = {
78-
JavaDoubleRDD.fromRDD(predict(testData.rdd.asInstanceOf[RDD[Double]]))
78+
JavaDoubleRDD.fromRDD(predict(testData.rdd.retag.asInstanceOf[RDD[Double]]))
7979
}
8080

8181
/**
@@ -194,7 +194,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali
194194
* @return Isotonic regression model.
195195
*/
196196
def run(input: JavaRDD[(JDouble, JDouble, JDouble)]): IsotonicRegressionModel = {
197-
run(input.rdd.asInstanceOf[RDD[(Double, Double, Double)]])
197+
run(input.rdd.retag.asInstanceOf[RDD[(Double, Double, Double)]])
198198
}
199199

200200
/**

mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,48 +18,36 @@
1818
package org.apache.spark.mllib.regression;
1919

2020
import java.io.Serializable;
21-
import java.util.ArrayList;
22-
import java.util.Arrays;
2321
import java.util.List;
2422

25-
import org.apache.spark.api.java.JavaDoubleRDD;
2623
import scala.Tuple3;
2724

25+
import com.google.common.collect.Lists;
2826
import org.junit.After;
2927
import org.junit.Assert;
3028
import org.junit.Before;
3129
import org.junit.Test;
3230

31+
import org.apache.spark.api.java.JavaDoubleRDD;
3332
import org.apache.spark.api.java.JavaRDD;
3433
import org.apache.spark.api.java.JavaSparkContext;
3534

3635
public class JavaIsotonicRegressionSuite implements Serializable {
3736
private transient JavaSparkContext sc;
3837

3938
private List<Tuple3<Double, Double, Double>> generateIsotonicInput(double[] labels) {
40-
List<Tuple3<Double, Double, Double>> input = new ArrayList<>();
39+
List<Tuple3<Double, Double, Double>> input = Lists.newArrayList();
4140

42-
for(int i = 1; i <= labels.length; i++) {
43-
input.add(new Tuple3(labels[i-1], (double)i, 1d));
41+
for (int i = 1; i <= labels.length; i++) {
42+
input.add(new Tuple3<Double, Double, Double>(labels[i-1], (double) i, 1d));
4443
}
4544

4645
return input;
4746
}
4847

49-
private double difference(List<Tuple3<Double, Double, Double>> expected, IsotonicRegressionModel model) {
50-
double diff = 0;
51-
52-
for(int i = 0; i < model.predictions().length; i++) {
53-
Tuple3<Double, Double, Double> exp = expected.get(i);
54-
diff += Math.abs(model.predict(exp._2()) - exp._1());
55-
}
56-
57-
return diff;
58-
}
59-
6048
private IsotonicRegressionModel runIsotonicRegression(double[] labels) {
6149
JavaRDD<Tuple3<Double, Double, Double>> trainRDD =
62-
sc.parallelize(generateIsotonicInput(labels)).cache();
50+
sc.parallelize(generateIsotonicInput(labels), 2).cache();
6351

6452
return new IsotonicRegression().run(trainRDD);
6553
}
@@ -80,20 +68,16 @@ public void testIsotonicRegressionJavaRDD() {
8068
IsotonicRegressionModel model =
8169
runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
8270

83-
List<Tuple3<Double, Double, Double>> expected =
84-
generateIsotonicInput(new double[] {1, 2, 7d/3, 7d/3, 7d/3, 6, 7, 8, 10, 10, 10, 12});
85-
86-
Assert.assertTrue(difference(expected, model) == 0);
71+
Assert.assertArrayEquals(
72+
new double[] {1, 2, 7d/3, 7d/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1e-14);
8773
}
8874

8975
@Test
9076
public void testIsotonicRegressionPredictionsJavaRDD() {
9177
IsotonicRegressionModel model =
9278
runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
9379

94-
JavaDoubleRDD testRDD =
95-
sc.parallelizeDoubles(Arrays.asList(new Double[] {0.0, 1.0, 9.5, 12.0, 13.0}));
96-
80+
JavaDoubleRDD testRDD = sc.parallelizeDoubles(Lists.newArrayList(0.0, 1.0, 9.5, 12.0, 13.0));
9781
List<Double> predictions = model.predict(testRDD).collect();
9882

9983
Assert.assertTrue(predictions.get(0) == 1d);
@@ -102,4 +86,4 @@ public void testIsotonicRegressionPredictionsJavaRDD() {
10286
Assert.assertTrue(predictions.get(3) == 12d);
10387
Assert.assertTrue(predictions.get(4) == 12d);
10488
}
105-
}
89+
}

0 commit comments

Comments
 (0)