
Commit 31f0b07

freeman-lab authored and tdas committed
[SPARK-3128][MLLIB] Use streaming test suite for StreamingLR
Refactored tests for streaming linear regression to use existing streaming test utilities.

Summary of changes:
- Made ``mllib`` depend on tests from ``streaming``
- Rewrote accuracy and convergence tests to use ``setupStreams`` and ``runStreams``
- Added a new test for the accuracy of predictions generated by ``predictOnValues``

These tests should run faster, be easier to extend/maintain, and provide a reference for new tests.

mengxr tdas

Author: freeman <[email protected]>

Closes apache#2037 from freeman-lab/streamingLR-predict-tests and squashes the following commits:

e851ca7 [freeman] Fixed long lines
50eb0bf [freeman] Refactored tests to use streaming test tools
32c43c2 [freeman] Added test for prediction
1 parent cbfc26b commit 31f0b07
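
Editor's note: the refactor is built on the harness in ``TestSuiteBase`` (made visible to ``mllib`` by the ``spark-streaming`` test-jar added below): ``setupStreams`` turns an in-memory sequence of batches into a ``DStream`` and applies the operation under test, and ``runStreams`` executes a fixed number of batches and collects one ``Seq`` of results per batch. A minimal sketch of the pattern, using a hypothetical doubling operation and suite name rather than anything from this commit:

import org.scalatest.FunSuite
import org.apache.spark.streaming.TestSuiteBase
import org.apache.spark.streaming.dstream.DStream

// Hypothetical example suite, not part of this commit
class HarnessExampleSuite extends FunSuite with TestSuiteBase {
  test("doubling each element") {
    // each inner Seq becomes one batch of the input DStream
    val input = (0 until 5).map(i => Seq(i))
    val ssc = setupStreams(input, (ds: DStream[Int]) => ds.map(_ * 2))
    // run 5 batches, expect 5 batches of output, collected per batch
    val output: Seq[Seq[Int]] = runStreams(ssc, 5, 5)
    assert(output.flatten == (0 until 5).map(_ * 2))
  }
}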

File tree

3 files changed (+77, -55 lines)


mllib/pom.xml

Lines changed: 7 additions & 0 deletions
@@ -91,6 +91,13 @@
       <artifactId>junit-interface</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
   <profiles>
     <profile>

mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala

Lines changed: 67 additions & 54 deletions
@@ -17,20 +17,19 @@
 
 package org.apache.spark.mllib.regression
 
-import java.io.File
-import java.nio.charset.Charset
-
 import scala.collection.mutable.ArrayBuffer
 
-import com.google.common.io.Files
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
-import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-import org.apache.spark.util.Utils
+import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.TestSuiteBase
+
+class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
 
-class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
+  // use longer wait time to ensure job completion
+  override def maxWaitTimeMillis = 20000
 
   // Assert that two values are equal within tolerance epsilon
   def assertEqual(v1: Double, v2: Double, epsilon: Double) {
@@ -49,35 +48,26 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
-  test("streaming linear regression parameter accuracy") {
+  test("parameter accuracy") {
 
-    val testDir = Files.createTempDir()
-    val numBatches = 10
-    val batchDuration = Milliseconds(1000)
-    val ssc = new StreamingContext(sc, batchDuration)
-    val data = ssc.textFileStream(testDir.toString).map(LabeledPoint.parse)
+    // create model
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(0.0, 0.0))
       .setStepSize(0.1)
-      .setNumIterations(50)
+      .setNumIterations(25)
 
-    model.trainOn(data)
-
-    ssc.start()
-
-    // write data to a file stream
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(
-        0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
-      val file = new File(testDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
+    // generate sequence of simulated data
+    val numBatches = 10
+    val input = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
     }
 
-    ssc.stop(stopSparkContext=false)
-
-    System.clearProperty("spark.driver.port")
-    Utils.deleteRecursively(testDir)
+    // apply model training to input stream
+    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
 
     // check accuracy of final parameter estimates
     assertEqual(model.latestModel().intercept, 0.0, 0.1)
@@ -91,39 +81,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   // Test that parameter estimates improve when learning Y = 10*X1 on streaming data
-  test("streaming linear regression parameter convergence") {
+  test("parameter convergence") {
 
-    val testDir = Files.createTempDir()
-    val batchDuration = Milliseconds(2000)
-    val ssc = new StreamingContext(sc, batchDuration)
-    val numBatches = 5
-    val data = ssc.textFileStream(testDir.toString()).map(LabeledPoint.parse)
+    // create model
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(0.0))
       .setStepSize(0.1)
-      .setNumIterations(50)
-
-    model.trainOn(data)
-
-    ssc.start()
+      .setNumIterations(25)
 
-    // write data to a file stream
-    val history = new ArrayBuffer[Double](numBatches)
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
-      val file = new File(testDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
-      // wait an extra few seconds to make sure the update finishes before new data arrive
-      Thread.sleep(4000)
-      history.append(math.abs(model.latestModel().weights(0) - 10.0))
+    // generate sequence of simulated data
+    val numBatches = 10
+    val input = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
     }
 
-    ssc.stop(stopSparkContext=false)
+    // create buffer to store intermediate fits
+    val history = new ArrayBuffer[Double](numBatches)
 
-    System.clearProperty("spark.driver.port")
-    Utils.deleteRecursively(testDir)
+    // apply model training to input stream, storing the intermediate results
+    // (we add a count to ensure the result is a DStream)
+    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0)))
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
 
+    // compute change in error
     val deltas = history.drop(1).zip(history.dropRight(1))
     // check error stability (it always either shrinks, or increases with small tol)
    assert(deltas.forall(x => (x._1 - x._2) <= 0.1))
@@ -132,4 +116,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
 
   }
 
+  // Test predictions on a stream
+  test("predictions") {
+
+    // create model initialized with true weights
+    val model = new StreamingLinearRegressionWithSGD()
+      .setInitialWeights(Vectors.dense(10.0, 10.0))
+      .setStepSize(0.1)
+      .setNumIterations(25)
+
+    // generate sequence of simulated data for testing
+    val numBatches = 10
+    val nPoints = 100
+    val testInput = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
+    }
+
+    // apply model predictions to test stream
+    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+      model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+    })
+    // collect the output as (true, estimated) tuples
+    val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+
+    // compute the mean absolute error and check that it's always less than 0.1
+    val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints)
+    assert(errors.forall(x => x <= 0.1))
+
+  }
+
 }
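
Editor's note: the new ``predictions`` test exercises the same API an application would use. As a rough sketch (not from this commit), assuming imports as in the suite above plus an existing ``StreamingContext`` ``ssc`` and hypothetical ``DStream[LabeledPoint]`` inputs ``trainingData`` and ``testData``:

// hypothetical application-side usage of the model under test
val model = new StreamingLinearRegressionWithSGD()
  .setInitialWeights(Vectors.dense(0.0, 0.0))

model.trainOn(trainingData)  // update weights on each training batch
// emit (label, prediction) pairs for each test batch
model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()

ssc.start()
ssc.awaitTermination()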

streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala

Lines changed: 3 additions & 1 deletion
@@ -242,7 +242,9 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
     logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput)
 
     // Get the output buffer
-    val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
+    val outputStream = ssc.graph.getOutputStreams.
+      filter(_.isInstanceOf[TestOutputStreamWithPartitions[_]]).
+      head.asInstanceOf[TestOutputStreamWithPartitions[V]]
     val output = outputStream.output
 
     try {
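
Editor's note on why this filter is needed (my reading; the commit does not state it): the operation under test may register output streams of its own, so the ``TestOutputStreamWithPartitions`` that ``runStreams`` must read is no longer guaranteed to be ``getOutputStreams.head``. The convergence test above is a concrete case:

// sketch, taken from the convergence test: this operation ends up
// registering several output streams on the graph -- one from trainOn
// (which consumes the stream internally), one from the explicit
// foreachRDD, and the TestOutputStream wrapping count() -- so taking
// `head` blindly could grab the wrong stream
val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
  model.trainOn(inputDStream)
  inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0)))
  inputDStream.count()
})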
