@@ -23,18 +23,21 @@ import org.apache.spark.SparkContext._
23
23
24
24
/**
25
25
* Evaluator for multiclass classification.
26
+ * NB: type Double both for prediction and label is retained
27
+ * for compatibility with model.predict that returns Double
28
+ * and MLUtils.loadLibSVMFile that loads class labels as Double
26
29
*
27
- * @param scoreAndLabels an RDD of (score , label) pairs.
30
+ * @param predictionsAndLabels an RDD of (prediction , label) pairs.
28
31
*/
29
- class MulticlassMetrics (scoreAndLabels : RDD [(Double , Double )]) extends Logging {
32
+ class MulticlassMetrics (predictionsAndLabels : RDD [(Double , Double )]) extends Logging {
30
33
31
34
/* class = category; label = instance of class; prediction = instance of class */
32
35
33
- private lazy val labelCountByClass = scoreAndLabels .values.countByValue()
36
+ private lazy val labelCountByClass = predictionsAndLabels .values.countByValue()
34
37
private lazy val labelCount = labelCountByClass.foldLeft(0L ){case (sum, (_, count)) => sum + count}
35
- private lazy val tpByClass = scoreAndLabels .map{ case (prediction, label) =>
38
+ private lazy val tpByClass = predictionsAndLabels .map{ case (prediction, label) =>
36
39
(label, if (label == prediction) 1 else 0 ) }.reduceByKey{_ + _}.collectAsMap
37
- private lazy val fpByClass = scoreAndLabels .map{ case (prediction, label) =>
40
+ private lazy val fpByClass = predictionsAndLabels .map{ case (prediction, label) =>
38
41
(prediction, if (prediction != label) 1 else 0 ) }.reduceByKey{_ + _}.collectAsMap
39
42
40
43
/**
0 commit comments