Skip to content

Commit e8741a7

Browse files
committed
CR feedback
1 parent b78804e commit e8741a7

File tree

2 files changed

+24
-15
lines changed

2 files changed

+24
-15
lines changed

mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV,
2222

2323
import org.apache.spark.SparkContext
2424
import org.apache.spark.rdd.RDD
25-
import scala.reflect._
25+
import scala.reflect.ClassTag
2626

2727
import org.apache.spark.SparkContext
2828
import org.apache.spark.rdd.RDD
@@ -181,21 +181,17 @@ object MLUtils {
181181
dataStr.saveAsTextFile(dir)
182182
}
183183

184-
def meanSquaredError(a: Double, b: Double): Double = {
185-
(a-b)*(a-b)
186-
}
187-
188184
/**
189185
* Return a k element list of pairs of RDDs with the first element of each pair
190-
* containing a unique 1/Kth of the data and the second element contain the composite of that.
186+
* containing a unique 1/Kth of the data and the second element containing the complement of that.
191187
*/
192-
def kFoldRdds[T : ClassTag](rdd: RDD[T], folds: Int, seed: Int): List[Pair[RDD[T], RDD[T]]] = {
188+
def kFold[T : ClassTag](rdd: RDD[T], folds: Int, seed: Int): List[Pair[RDD[T], RDD[T]]] = {
193189
val foldsF = folds.toFloat
194190
1.to(folds).map(fold => ((
195-
new PartitionwiseSampledRDD(rdd, new BernoulliSampler[T]((fold-1)/foldsF,fold/foldsF, false),
196-
seed),
197-
new PartitionwiseSampledRDD(rdd, new BernoulliSampler[T]((fold-1)/foldsF,fold/foldsF, true),
198-
seed)
191+
new PartitionwiseSampledRDD(rdd, new BernoulliSampler[T]((fold-1)/foldsF,fold/foldsF,
192+
complement = false), seed),
193+
new PartitionwiseSampledRDD(rdd, new BernoulliSampler[T]((fold-1)/foldsF,fold/foldsF,
194+
complement = true), seed)
199195
))).toList
200196
}
201197

mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
package org.apache.spark.mllib.util
1919

2020
import java.io.File
21+
import scala.math
22+
import scala.util.Random
2123

2224
import org.scalatest.FunSuite
2325

@@ -136,19 +138,30 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
136138
new LinearRegressionModel(Array(1.0), 0)
137139
}
138140

139-
test("kfoldRdd") {
141+
test("kFold") {
140142
val data = sc.parallelize(1 to 100, 2)
141143
val collectedData = data.collect().sorted
142-
val twoFoldedRdd = MLUtils.kFoldRdds(data, 2, 1)
144+
val twoFoldedRdd = MLUtils.kFold(data, 2, 1)
143145
assert(twoFoldedRdd(0)._1.collect().sorted === twoFoldedRdd(1)._2.collect().sorted)
144146
assert(twoFoldedRdd(0)._2.collect().sorted === twoFoldedRdd(1)._1.collect().sorted)
145147
for (folds <- 2 to 10) {
146148
for (seed <- 1 to 5) {
147-
val foldedRdds = MLUtils.kFoldRdds(data, folds, seed)
149+
val foldedRdds = MLUtils.kFold(data, folds, seed)
148150
assert(foldedRdds.size === folds)
149151
foldedRdds.map{case (test, train) =>
150152
val result = test.union(train).collect().sorted
151-
assert(test.collect().size > 0, "Non empty test data")
153+
val testSize = test.collect().size.toFloat
154+
assert(testSize > 0, "Non empty test data")
155+
val p = 1 / folds.toFloat
156+
// Within 3 standard deviations of the mean
157+
val range = 3 * math.sqrt(100 * p * (1-p))
158+
val expected = 100 * p
159+
val lowerBound = expected - range
160+
val upperBound = expected + range
161+
assert(testSize > lowerBound,
162+
"Test data (" + testSize + ") smaller than expected (" + lowerBound +")" )
163+
assert(testSize < upperBound,
164+
"Test data (" + testSize + ") larger than expected (" + upperBound +")" )
152165
assert(train.collect().size > 0, "Non empty training data")
153166
assert(result === collectedData,
154167
"Each training+test set combined contains all of the data")

0 commit comments

Comments
 (0)