|
18 | 18 | package org.apache.spark.mllib.util
|
19 | 19 |
|
20 | 20 | import java.io.File
|
| 21 | +import scala.math |
| 22 | +import scala.util.Random |
21 | 23 |
|
22 | 24 | import org.scalatest.FunSuite
|
23 | 25 |
|
@@ -136,19 +138,30 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
|
136 | 138 | new LinearRegressionModel(Array(1.0), 0)
|
137 | 139 | }
|
138 | 140 |
|
139 |
| - test("kfoldRdd") { |
| 141 | + test("kFold") { |
140 | 142 | val data = sc.parallelize(1 to 100, 2)
|
141 | 143 | val collectedData = data.collect().sorted
|
142 |
| - val twoFoldedRdd = MLUtils.kFoldRdds(data, 2, 1) |
| 144 | + val twoFoldedRdd = MLUtils.kFold(data, 2, 1) |
143 | 145 | assert(twoFoldedRdd(0)._1.collect().sorted === twoFoldedRdd(1)._2.collect().sorted)
|
144 | 146 | assert(twoFoldedRdd(0)._2.collect().sorted === twoFoldedRdd(1)._1.collect().sorted)
|
145 | 147 | for (folds <- 2 to 10) {
|
146 | 148 | for (seed <- 1 to 5) {
|
147 |
| - val foldedRdds = MLUtils.kFoldRdds(data, folds, seed) |
| 149 | + val foldedRdds = MLUtils.kFold(data, folds, seed) |
148 | 150 | assert(foldedRdds.size === folds)
|
149 | 151 | foldedRdds.map{case (test, train) =>
|
150 | 152 | val result = test.union(train).collect().sorted
|
151 |
| - assert(test.collect().size > 0, "Non empty test data") |
| 153 | + val testSize = test.collect().size.toFloat |
| 154 | + assert(testSize > 0, "Non empty test data") |
| 155 | + val p = 1 / folds.toFloat |
| 156 | + // Within 3 standard deviations of the mean |
| 157 | + val range = 3 * math.sqrt(100 * p * (1-p)) |
| 158 | + val expected = 100 * p |
| 159 | + val lowerBound = expected - range |
| 160 | + val upperBound = expected + range |
| 161 | + assert(testSize > lowerBound, |
| 162 | + "Test data (" + testSize + ") smaller than expected (" + lowerBound +")" ) |
| 163 | + assert(testSize < upperBound, |
| 164 | + "Test data (" + testSize + ") larger than expected (" + upperBound +")" ) |
152 | 165 | assert(train.collect().size > 0, "Non empty training data")
|
153 | 166 | assert(result === collectedData,
|
154 | 167 | "Each training+test set combined contains all of the data")
|
|
0 commit comments