Skip to content

Commit e23c2e5

Browse files
committed
added regression support
Signed-off-by: Manish Amde <[email protected]>
1 parent c8f6d60 commit e23c2e5

File tree

7 files changed

+231
-76
lines changed

7 files changed

+231
-76
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 194 additions & 67 deletions
Large diffs are not rendered by default.

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
*/
1717
package org.apache.spark.mllib.tree
1818

19+
import org.apache.spark.SparkContext._
1920
import org.apache.spark.{Logging, SparkContext}
2021
import org.apache.spark.mllib.tree.impurity.{Gini,Entropy,Variance}
2122
import org.apache.spark.rdd.RDD
@@ -95,6 +96,9 @@ object DecisionTreeRunner extends Logging {
9596
val accuracy = accuracyScore(model, testData)
9697
logDebug("accuracy = " + accuracy)
9798

99+
val mse = meanSquaredError(model,testData)
100+
logDebug("mean square error = " + mse)
101+
98102
sc.stop()
99103
}
100104

@@ -126,6 +130,14 @@ object DecisionTreeRunner extends Logging {
126130
correctCount.toDouble / count
127131
}
128132

133+
//TODO: Make these generic MLTable metrics
134+
def meanSquaredError(tree : DecisionTreeModel, data : RDD[LabeledPoint]) : Double = {
135+
val meanSumOfSquares = data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)).mean()
136+
println("meanSumOfSquares = " + meanSumOfSquares)
137+
meanSumOfSquares
138+
}
139+
140+
129141

130142

131143
}

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
*/
1717
package org.apache.spark.mllib.tree.impurity
1818

19+
import javax.naming.OperationNotSupportedException
20+
1921
object Entropy extends Impurity {
2022

2123
def log2(x: Double) = scala.math.log(x) / scala.math.log(2)
@@ -31,4 +33,6 @@ object Entropy extends Impurity {
3133
}
3234
}
3335

34-
}
36+
def calculate(count: Double, sum: Double, sumSquares: Double): Double =
37+
throw new OperationNotSupportedException("Entropy.calculate")
38+
}

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
*/
1717
package org.apache.spark.mllib.tree.impurity
1818

19+
import javax.naming.OperationNotSupportedException
20+
1921
object Gini extends Impurity {
2022

2123
def calculate(c0 : Double, c1 : Double): Double = {
@@ -29,4 +31,5 @@ object Gini extends Impurity {
2931
}
3032
}
3133

32-
}
34+
def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new OperationNotSupportedException("Gini.calculate")
35+
}

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,6 @@ trait Impurity extends Serializable {
2020

2121
def calculate(c0 : Double, c1 : Double): Double
2222

23+
def calculate(count : Double, sum : Double, sumSquares : Double) : Double
24+
2325
}

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,14 @@
1717
package org.apache.spark.mllib.tree.impurity
1818

1919
import javax.naming.OperationNotSupportedException
20+
import org.apache.spark.Logging
2021

21-
object Variance extends Impurity {
22+
object Variance extends Impurity with Logging {
2223
def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate")
23-
}
24+
25+
def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
26+
val squaredLoss = sumSquares - (sum*sum)/count
27+
squaredLoss/count
28+
}
29+
30+
}

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
4949
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
5050
assert(arr.length == 1000)
5151
val rdd = sc.parallelize(arr)
52-
val strategy = new Strategy(Regression,Gini,3,100)
52+
val strategy = new Strategy(Classification,Gini,3,100)
5353
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
5454
assert(splits.length==2)
5555
assert(bins.length==2)
@@ -62,7 +62,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
6262
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
6363
assert(arr.length == 1000)
6464
val rdd = sc.parallelize(arr)
65-
val strategy = new Strategy(Regression,Gini,3,100)
65+
val strategy = new Strategy(Classification,Gini,3,100)
6666
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
6767
assert(splits.length==2)
6868
assert(splits(0).length==99)
@@ -88,7 +88,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
8888
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
8989
assert(arr.length == 1000)
9090
val rdd = sc.parallelize(arr)
91-
val strategy = new Strategy(Regression,Gini,3,100)
91+
val strategy = new Strategy(Classification,Gini,3,100)
9292
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
9393
assert(splits.length==2)
9494
assert(splits(0).length==99)
@@ -114,7 +114,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
114114
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
115115
assert(arr.length == 1000)
116116
val rdd = sc.parallelize(arr)
117-
val strategy = new Strategy(Regression,Entropy,3,100)
117+
val strategy = new Strategy(Classification,Entropy,3,100)
118118
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
119119
assert(splits.length==2)
120120
assert(splits(0).length==99)
@@ -139,7 +139,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
139139
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
140140
assert(arr.length == 1000)
141141
val rdd = sc.parallelize(arr)
142-
val strategy = new Strategy(Regression,Entropy,3,100)
142+
val strategy = new Strategy(Classification,Entropy,3,100)
143143
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
144144
assert(splits.length==2)
145145
assert(splits(0).length==99)

0 commit comments

Comments
 (0)