 
 package org.apache.spark.mllib.regression
 
-import java.io.File
-import java.nio.charset.Charset
-
 import scala.collection.mutable.ArrayBuffer
 
-import com.google.common.io.Files
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
-import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-import org.apache.spark.util.Utils
+import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.TestSuiteBase
+
+class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
 
-class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
+  // use longer wait time to ensure job completion
+  override def maxWaitTimeMillis = 20000
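+  // (maxWaitTimeMillis is the timeout TestSuiteBase uses when waiting for
+  // streaming output to appear)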
 
   // Assert that two values are equal within tolerance epsilon
   def assertEqual(v1: Double, v2: Double, epsilon: Double) {
@@ -49,35 +48,26 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
-  test("streaming linear regression parameter accuracy") {
+  test("parameter accuracy") {
 
-    val testDir = Files.createTempDir()
-    val numBatches = 10
-    val batchDuration = Milliseconds(1000)
-    val ssc = new StreamingContext(sc, batchDuration)
-    val data = ssc.textFileStream(testDir.toString).map(LabeledPoint.parse)
+    // create model
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(0.0, 0.0))
       .setStepSize(0.1)
-      .setNumIterations(50)
+      .setNumIterations(25)
 
-    model.trainOn(data)
-
-    ssc.start()
-
-    // write data to a file stream
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(
-        0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
-      val file = new File(testDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
+    // generate sequence of simulated data
+    val numBatches = 10
+    val input = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
     }
 
-    ssc.stop(stopSparkContext=false)
-
-    System.clearProperty("spark.driver.port")
-    Utils.deleteRecursively(testDir)
+    // apply model training to input stream
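+    // (setupStreams and runStreams are TestSuiteBase helpers: setupStreams wires
+    // the batches in `input` into the given operation as a test input stream, and
+    // runStreams executes those batches and collects the output)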
+    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
 
     // check accuracy of final parameter estimates
     assertEqual(model.latestModel().intercept, 0.0, 0.1)
@@ -91,39 +81,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   // Test that parameter estimates improve when learning Y = 10*X1 on streaming data
-  test("streaming linear regression parameter convergence") {
+  test("parameter convergence") {
 
-    val testDir = Files.createTempDir()
-    val batchDuration = Milliseconds(2000)
-    val ssc = new StreamingContext(sc, batchDuration)
-    val numBatches = 5
-    val data = ssc.textFileStream(testDir.toString()).map(LabeledPoint.parse)
+    // create model
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(0.0))
       .setStepSize(0.1)
-      .setNumIterations(50)
-
-    model.trainOn(data)
-
-    ssc.start()
+      .setNumIterations(25)
 
-    // write data to a file stream
-    val history = new ArrayBuffer[Double](numBatches)
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
-      val file = new File(testDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
-      // wait an extra few seconds to make sure the update finishes before new data arrive
-      Thread.sleep(4000)
-      history.append(math.abs(model.latestModel().weights(0) - 10.0))
+    // generate sequence of simulated data
+    val numBatches = 10
+    val input = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
     }
 
-    ssc.stop(stopSparkContext=false)
+    // create buffer to store intermediate fits
+    val history = new ArrayBuffer[Double](numBatches)
 
-    System.clearProperty("spark.driver.port")
-    Utils.deleteRecursively(testDir)
+    // apply model training to input stream, storing the intermediate results
+    // (we add a count to ensure the result is a DStream)
+    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0)))
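+      // (foreachRDD runs once per batch, so history records the error in the
+      // fitted weight after each model update)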
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
 
+    // compute change in error
     val deltas = history.drop(1).zip(history.dropRight(1))
     // check error stability (it always either shrinks, or increases with small tol)
     assert(deltas.forall(x => (x._1 - x._2) <= 0.1))
@@ -132,4 +116,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
 
   }
 
+  // Test predictions on a stream
+  test("predictions") {
+
+    // create model initialized with true weights
+    val model = new StreamingLinearRegressionWithSGD()
+      .setInitialWeights(Vectors.dense(10.0, 10.0))
+      .setStepSize(0.1)
+      .setNumIterations(25)
+
+    // generate sequence of simulated data for testing
+    val numBatches = 10
+    val nPoints = 100
+    val testInput = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
+    }
+
+    // apply model predictions to test stream
+    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+      model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
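+      // (predictOnValues keeps each key, here the true label, and replaces the
+      // value with the model's prediction for its features)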
+    })
+    // collect the output as (true, estimated) tuples
+    val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+
+    // compute the mean absolute error and check that it's always less than 0.1
+    val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints)
+    assert(errors.forall(x => x <= 0.1))
+
+  }
+
 }