@@ -22,24 +22,27 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 
+import scala.collection.Map
+
 /**
+ * ::Experimental::
  * Evaluator for multiclass classification.
  *
  * @param predictionsAndLabels an RDD of (prediction, label) pairs.
  */
 @Experimental
 class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Logging {
 
-  private lazy val labelCountByClass = predictionsAndLabels.values.countByValue()
-  private lazy val labelCount = labelCountByClass.values.sum
-  private lazy val tpByClass = predictionsAndLabels
-    .map{ case (prediction, label) =>
-      (label, if (label == prediction) 1 else 0)
+  private lazy val labelCountByClass: Map[Double, Long] = predictionsAndLabels.values.countByValue()
+  private lazy val labelCount: Long = labelCountByClass.values.sum
+  private lazy val tpByClass: Map[Double, Int] = predictionsAndLabels
+    .map { case (prediction, label) =>
+      (label, if (label == prediction) 1 else 0)
     }.reduceByKey(_ + _)
     .collectAsMap()
-  private lazy val fpByClass = predictionsAndLabels
-    .map{ case (prediction, label) =>
-      (prediction, if (prediction != label) 1 else 0)
+  private lazy val fpByClass: Map[Double, Int] = predictionsAndLabels
+    .map { case (prediction, label) =>
+      (prediction, if (prediction != label) 1 else 0)
     }.reduceByKey(_ + _)
     .collectAsMap()
 
@@ -63,35 +66,41 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
    * Returns f-measure for a given label (category)
    * @param label the label.
    */
-  def fMeasure(label: Double, beta: Double = 1.0): Double = {
+  def fMeasure(label: Double, beta: Double): Double = {
     val p = precision(label)
     val r = recall(label)
     val betaSqrd = beta * beta
     if (p + r == 0) 0 else (1 + betaSqrd) * p * r / (betaSqrd * p + r)
   }
 
   /**
-   * Returns micro-averaged recall
-   * (equals to microPrecision and microF1measure for multiclass classifier)
+   * Returns f1-measure for a given label (category)
+   * @param label the label.
+   */
+  def fMeasure(label: Double): Double = fMeasure(label, 1.0)
+
+  /**
+   * Returns precision
    */
-  lazy val recall: Double =
-    tpByClass.values.sum.toDouble / labelCount
+  lazy val precision: Double = tpByClass.values.sum.toDouble / labelCount
 
   /**
-   * Returns micro-averaged precision
-   * (equals to microPrecision and microF1measure for multiclass classifier)
+   * Returns recall
+   * (equals to precision for multiclass classifier
+   * because sum of all false positives is equal to sum
+   * of all false negatives)
    */
-  lazy val precision: Double = recall
+  lazy val recall: Double = precision
 
   /**
-   * Returns micro-averaged f-measure
-   * (equals to microPrecision and microRecall for multiclass classifier)
+   * Returns f-measure
+   * (equals to precision and recall because precision equals recall)
    */
-  lazy val fMeasure: Double = recall
+  lazy val fMeasure: Double = precision
 
   /**
    * Returns weighted averaged recall
-   * (equals to micro-averaged precision, recall and f-measure)
+   * (equals to precision, recall and f-measure)
    */
   lazy val weightedRecall: Double = labelCountByClass.map { case (category, count) =>
     recall(category) * count.toDouble / labelCount
@@ -114,6 +123,5 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
   /**
    * Returns the sequence of labels in ascending order
    */
-  lazy val labels = tpByClass.unzip._1.toSeq.sorted
-
+  lazy val labels: Array[Double] = tpByClass.keys.toArray.sorted
 }
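
As a usage illustration (not part of the patch): a minimal, self-contained sketch of the API after this change, run on toy data. The object name, the `local[2]` master, and the sample (prediction, label) pairs are made up for the example; the package `org.apache.spark.mllib.evaluation` is assumed from where `MulticlassMetrics` lives in MLlib.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.MulticlassMetrics

// Hypothetical driver for illustration only.
object MulticlassMetricsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("metrics-sketch"))

    // Toy (prediction, label) pairs over three classes 0.0, 1.0, 2.0; 5 of 9 are correct.
    val predictionsAndLabels = sc.parallelize(Seq(
      (0.0, 0.0), (0.0, 1.0), (1.0, 0.0),
      (1.0, 1.0), (1.0, 1.0), (2.0, 2.0),
      (2.0, 1.0), (0.0, 2.0), (2.0, 2.0)
    ), 2)

    val metrics = new MulticlassMetrics(predictionsAndLabels)

    // Per-label metrics, using both the explicit-beta overload and the new f1 overload.
    metrics.labels.foreach { l =>
      println(s"label $l: precision=${metrics.precision(l)} recall=${metrics.recall(l)} " +
        s"f1=${metrics.fMeasure(l)} f0.5=${metrics.fMeasure(l, 0.5)}")
    }

    // Overall metrics: each misclassified point is one false positive for the predicted
    // class and one false negative for the true class, so total FP == total FN and the
    // micro-averaged precision, recall and F-measure all equal TP / N (here 5/9).
    println(s"precision=${metrics.precision} recall=${metrics.recall} fMeasure=${metrics.fMeasure}")
    println(s"weightedRecall=${metrics.weightedRecall}")

    sc.stop()
  }
}
```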