Skip to content

Commit e2c91c3

Browse files
committed
Fixes to mutliclass metics
1 parent d5ce981 commit e2c91c3

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,72 +60,75 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
6060
* @param label the label.
6161
* @return F1-measure.
6262
*/
63-
def f1Measure(label: Double): Double =
64-
2 * precision(label) * recall(label) / (precision(label) + recall(label))
63+
def f1Measure(label: Double): Double ={
64+
val p = precision(label)
65+
val r = recall(label)
66+
if((p + r) == 0) 0 else 2 * p * r / (p + r)
67+
}
6568

6669
/**
6770
* Returns micro-averaged Recall
6871
* (equals to microPrecision and microF1measure for multiclass classifier)
6972
* @return microRecall.
7073
*/
71-
def microRecall: Double =
72-
tpByClass.foldLeft(0L){case (sum,(_, tp)) => sum + tp}.toDouble / labelCount.toDouble
74+
lazy val microRecall: Double =
75+
tpByClass.foldLeft(0L){case (sum,(_, tp)) => sum + tp}.toDouble / labelCount
7376

7477
/**
7578
* Returns micro-averaged Precision
7679
* (equals to microPrecision and microF1measure for multiclass classifier)
7780
* @return microPrecision.
7881
*/
79-
def microPrecision: Double = microRecall
82+
lazy val microPrecision: Double = microRecall
8083

8184
/**
8285
* Returns micro-averaged F1-measure
8386
* (equals to microPrecision and microRecall for multiclass classifier)
8487
* @return microF1measure.
8588
*/
86-
def microF1Measure: Double = microRecall
89+
lazy val microF1Measure: Double = microRecall
8790

8891
/**
8992
* Returns weighted averaged Recall
9093
* @return weightedRecall.
9194
*/
92-
def weightedRecall: Double = labelCountByClass.foldLeft(0.0){case(wRecall, (category, count)) =>
93-
wRecall + recall(category) * count.toDouble / labelCount.toDouble}
95+
lazy val weightedRecall: Double = labelCountByClass.foldLeft(0.0){case(wRecall, (category, count)) =>
96+
wRecall + recall(category) * count.toDouble / labelCount}
9497

9598
/**
9699
* Returns weighted averaged Precision
97100
* @return weightedPrecision.
98101
*/
99-
def weightedPrecision: Double =
102+
lazy val weightedPrecision: Double =
100103
labelCountByClass.foldLeft(0.0){case(wPrecision, (category, count)) =>
101-
wPrecision + precision(category) * count.toDouble / labelCount.toDouble}
104+
wPrecision + precision(category) * count.toDouble / labelCount}
102105

103106
/**
104107
* Returns weighted averaged F1-measure
105108
* @return weightedF1Measure.
106109
*/
107-
def weightedF1Measure: Double =
110+
lazy val weightedF1Measure: Double =
108111
labelCountByClass.foldLeft(0.0){case(wF1measure, (category, count)) =>
109-
wF1measure + f1Measure(category) * count.toDouble / labelCount.toDouble}
112+
wF1measure + f1Measure(category) * count.toDouble / labelCount}
110113

111114
/**
112115
* Returns map with Precisions for individual classes
113116
* @return precisionPerClass.
114117
*/
115-
def precisionPerClass =
118+
lazy val precisionPerClass =
116119
labelCountByClass.map{case (category, _) => (category, precision(category))}.toMap
117120

118121
/**
119122
* Returns map with Recalls for individual classes
120123
* @return recallPerClass.
121124
*/
122-
def recallPerClass =
125+
lazy val recallPerClass =
123126
labelCountByClass.map{case (category, _) => (category, recall(category))}.toMap
124127

125128
/**
126129
* Returns map with F1-measures for individual classes
127130
* @return f1MeasurePerClass.
128131
*/
129-
def f1MeasurePerClass =
132+
lazy val f1MeasurePerClass =
130133
labelCountByClass.map{case (category, _) => (category, f1Measure(category))}.toMap
131134
}

0 commit comments

Comments
 (0)