
Commit a7e8bf0

Addressing reviewer's comments (mengxr)
1 parent c3a77ad commit a7e8bf0

File tree

2 files changed: +32 additions, -24 deletions

mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala

Lines changed: 30 additions & 22 deletions

@@ -22,24 +22,27 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 
+import scala.collection.Map
+
 /**
+ * ::Experimental::
  * Evaluator for multiclass classification.
  *
  * @param predictionsAndLabels an RDD of (prediction, label) pairs.
  */
 @Experimental
 class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Logging {
 
-  private lazy val labelCountByClass = predictionsAndLabels.values.countByValue()
-  private lazy val labelCount = labelCountByClass.values.sum
-  private lazy val tpByClass = predictionsAndLabels
-    .map{ case (prediction, label) =>
-      (label, if (label == prediction) 1 else 0)
+  private lazy val labelCountByClass: Map[Double, Long] = predictionsAndLabels.values.countByValue()
+  private lazy val labelCount: Long = labelCountByClass.values.sum
+  private lazy val tpByClass: Map[Double, Int] = predictionsAndLabels
+    .map { case (prediction, label) =>
+      (label, if (label == prediction) 1 else 0)
     }.reduceByKey(_ + _)
     .collectAsMap()
-  private lazy val fpByClass = predictionsAndLabels
-    .map{ case (prediction, label) =>
-      (prediction, if (prediction != label) 1 else 0)
+  private lazy val fpByClass: Map[Double, Int] = predictionsAndLabels
+    .map { case (prediction, label) =>
+      (prediction, if (prediction != label) 1 else 0)
     }.reduceByKey(_ + _)
     .collectAsMap()
 
@@ -63,35 +66,41 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
    * Returns f-measure for a given label (category)
    * @param label the label.
    */
-  def fMeasure(label: Double, beta:Double = 1.0): Double = {
+  def fMeasure(label: Double, beta: Double): Double = {
     val p = precision(label)
     val r = recall(label)
     val betaSqrd = beta * beta
     if (p + r == 0) 0 else (1 + betaSqrd) * p * r / (betaSqrd * p + r)
   }
 
   /**
-   * Returns micro-averaged recall
-   * (equals to microPrecision and microF1measure for multiclass classifier)
+   * Returns f1-measure for a given label (category)
+   * @param label the label.
+   */
+  def fMeasure(label: Double): Double = fMeasure(label, 1.0)
+
+  /**
+   * Returns precision
    */
-  lazy val recall: Double =
-    tpByClass.values.sum.toDouble / labelCount
+  lazy val precision: Double = tpByClass.values.sum.toDouble / labelCount
 
   /**
-   * Returns micro-averaged precision
-   * (equals to microPrecision and microF1measure for multiclass classifier)
+   * Returns recall
+   * (equals to precision for multiclass classifier
+   * because sum of all false positives is equal to sum
+   * of all false negatives)
    */
-  lazy val precision: Double = recall
+  lazy val recall: Double = precision
 
   /**
-   * Returns micro-averaged f-measure
-   * (equals to microPrecision and microRecall for multiclass classifier)
+   * Returns f-measure
+   * (equals to precision and recall because precision equals recall)
   */
-  lazy val fMeasure: Double = recall
+  lazy val fMeasure: Double = precision
 
   /**
    * Returns weighted averaged recall
-   * (equals to micro-averaged precision, recall and f-measure)
+   * (equals to precision, recall and f-measure)
   */
   lazy val weightedRecall: Double = labelCountByClass.map { case (category, count) =>
     recall(category) * count.toDouble / labelCount
@@ -114,6 +123,5 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
   /**
    * Returns the sequence of labels in ascending order
   */
-  lazy val labels = tpByClass.unzip._1.toSeq.sorted
-
+  lazy val labels:Array[Double] = tpByClass.keys.toArray.sorted
 }
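For orientation, a minimal usage sketch of the API as it stands after this commit. The SparkContext `sc` and the (prediction, label) pairs are illustrative assumptions, not part of the change; the member names (`precision`, `recall`, `fMeasure`, `labels`, and the per-label overloads) come from the diff above.

// Minimal usage sketch (hypothetical data; assumes an existing SparkContext `sc`).
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.rdd.RDD

val predictionsAndLabels: RDD[(Double, Double)] = sc.parallelize(Seq(
  (0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (2.0, 2.0), (0.0, 2.0)))

val metrics = new MulticlassMetrics(predictionsAndLabels)

// Overall metrics: all three coincide, because across classes the summed
// false positives equal the summed false negatives.
println(metrics.precision)
println(metrics.recall)
println(metrics.fMeasure)

// Per-label metrics; the new fMeasure(label) overload fixes beta = 1.0.
metrics.labels.foreach { label =>
  println((label, metrics.precision(label), metrics.recall(label),
    metrics.fMeasure(label), metrics.fMeasure(label, 0.5)))
}

Splitting fMeasure into an explicit one-argument and two-argument method, rather than keeping a default value for beta, likely also makes the API friendlier to Java callers, which do not see Scala default arguments.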

mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala

Lines changed: 2 additions & 2 deletions

@@ -30,7 +30,7 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
     * |0|0|1| true class2 (1 instance)
     *
     */
-    val labels = Seq(0.0, 1.0, 2.0)
+    val labels = Array(0.0, 1.0, 2.0)
     val scoreAndLabels = sc.parallelize(
       Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
         (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2)
@@ -65,6 +65,6 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
       ((4.0 / 9.0) * recall0 + (4.0 / 9.0) * recall1 + (1.0 / 9.0) * recall2)) < delta)
     assert(math.abs(metrics.weightedF1Measure -
       ((4.0 / 9.0) * f1measure0 + (4.0 / 9.0) * f1measure1 + (1.0 / 9.0) * f1measure2)) < delta)
-    assert(metrics.labels == labels)
+    assert(metrics.labels.sameElements(labels))
   }
 }
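A note on the assertion change: `labels` is now an Array, and Scala Arrays are plain JVM arrays, so `==` compares references rather than contents; `sameElements` performs the element-wise comparison. A quick sketch of the difference:

// Array equality is reference equality; sameElements compares contents.
Array(0.0, 1.0, 2.0) == Array(0.0, 1.0, 2.0)             // false
Array(0.0, 1.0, 2.0).sameElements(Array(0.0, 1.0, 2.0))  // true

// Seq, by contrast, has structural equality, which is why the old assertion worked.
Seq(0.0, 1.0, 2.0) == Seq(0.0, 1.0, 2.0)                 // true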
