Skip to content

Commit 1397ab5

Browse files
committed
use sqlContext from LocalSparkContext instead of TestSQLContext
1 parent 6ffc389 commit 1397ab5

File tree

3 files changed

+16
-33
lines changed

3 files changed

+16
-33
lines changed

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,17 @@
1717

1818
package org.apache.spark.ml.classification
1919

20-
import org.scalatest.{BeforeAndAfterAll, FunSuite}
20+
import org.scalatest.FunSuite
2121

2222
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
23+
import org.apache.spark.mllib.util.LocalSparkContext
2324
import org.apache.spark.sql.SchemaRDD
24-
import org.apache.spark.sql.test.TestSQLContext._
2525

26-
class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with Serializable {
26+
class LogisticRegressionSuite extends FunSuite with LocalSparkContext {
2727

28-
@transient var dataset: SchemaRDD = _
28+
import sqlContext._
2929

30-
override def beforeAll(): Unit = {
31-
super.beforeAll()
32-
val points = generateLogisticInput(1.0, 1.0, 100, 42)
33-
val rdd = sparkContext.parallelize(points)
34-
dataset = createSchemaRDD(rdd)
35-
}
30+
val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)
3631

3732
test("logistic regression") {
3833
val lr = new LogisticRegression

mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,19 @@
1717

1818
package org.apache.spark.ml.tuning
1919

20-
import org.scalatest.{BeforeAndAfterAll, FunSuite}
20+
import org.scalatest.FunSuite
2121

2222
import org.apache.spark.ml.classification.LogisticRegression
2323
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
2424
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
25+
import org.apache.spark.mllib.util.LocalSparkContext
2526
import org.apache.spark.sql.SchemaRDD
26-
import org.apache.spark.sql.test.TestSQLContext._
2727

28-
class CrossValidatorSuite extends FunSuite with BeforeAndAfterAll with Serializable {
28+
class CrossValidatorSuite extends FunSuite with LocalSparkContext {
2929

30-
var dataset: SchemaRDD = _
30+
import sqlContext._
3131

32-
override def beforeAll(): Unit = {
33-
val points = generateLogisticInput(1.0, 1.0, 100, 42)
34-
dataset = sparkContext.parallelize(points, 2)
35-
}
32+
val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)
3633

3734
test("cross validation with logistic regression") {
3835
val lr = new LogisticRegression

mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,17 @@
1717

1818
package org.apache.spark.mllib.util
1919

20-
import org.scalatest.Suite
21-
import org.scalatest.BeforeAndAfterAll
20+
import org.scalatest.{BeforeAndAfterAll, Suite}
2221

23-
import org.apache.spark.{SparkConf, SparkContext}
22+
import org.apache.spark.SparkContext
23+
import org.apache.spark.sql.SQLContext
2424

2525
trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
26-
@transient var sc: SparkContext = _
27-
28-
override def beforeAll() {
29-
val conf = new SparkConf()
30-
.setMaster("local")
31-
.setAppName("test")
32-
sc = new SparkContext(conf)
33-
super.beforeAll()
34-
}
26+
@transient val sc = new SparkContext("local[2]", "test")
27+
@transient lazy val sqlContext = new SQLContext(sc)
3528

3629
override def afterAll() {
37-
if (sc != null) {
38-
sc.stop()
39-
}
30+
sc.stop()
4031
super.afterAll()
4132
}
4233
}

0 commit comments

Comments
 (0)