
Commit 98a46f9

yanboliang authored and mengxr committed

[SPARK-6094] [MLLIB] Add MultilabelMetrics in PySpark/MLlib

Add MultilabelMetrics in PySpark/MLlib

Author: Yanbo Liang <[email protected]>

Closes apache#6276 from yanboliang/spark-6094 and squashes the following commits:

b8e3343 [Yanbo Liang] Add MultilabelMetrics in PySpark/MLlib

1 parent 589b12f · commit 98a46f9

File tree

2 files changed: +125 −0 lines

mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala

8 additions & 0 deletions

@@ -19,6 +19,7 @@ package org.apache.spark.mllib.evaluation

 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
+import org.apache.spark.sql.DataFrame

 /**
  * Evaluator for multilabel classification.
@@ -27,6 +28,13 @@ import org.apache.spark.SparkContext._
  */
 class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {

+  /**
+   * An auxiliary constructor taking a DataFrame.
+   * @param predictionAndLabels a DataFrame with two double array columns: prediction and label
+   */
+  private[mllib] def this(predictionAndLabels: DataFrame) =
+    this(predictionAndLabels.map(r => (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray)))
+
   private lazy val numDocs: Long = predictionAndLabels.count()

   private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) =>
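
Why the auxiliary constructor: Py4J can pass a DataFrame's underlying Java object across the Python/JVM boundary, whereas a plain Python RDD of tuples of double arrays has no such handle, so the PySpark wrapper in the next file funnels its input through a DataFrame. A condensed, hypothetical sketch of that call path (the app name and sample row are illustrative, not from the patch):

    # Hypothetical sketch of the Py4J bridge the new constructor enables;
    # the actual wrapper code is in python/pyspark/mllib/evaluation.py below.
    from pyspark import SparkContext
    from pyspark.sql import SQLContext

    sc = SparkContext(appName="MultilabelMetricsBridge")  # illustrative name
    sql_ctx = SQLContext(sc)
    # One (prediction, label) pair; schema inference is assumed to yield
    # two array<double> columns.
    df = sql_ctx.createDataFrame(sc.parallelize([([0.0, 1.0], [0.0, 2.0])]))
    # Hand the underlying Java DataFrame to the Scala auxiliary constructor.
    java_metrics = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics(df._jdf)
    print(java_metrics.microPrecision())  # Scala lazy vals surface as 0-arg calls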

python/pyspark/mllib/evaluation.py

117 additions & 0 deletions

@@ -343,6 +343,123 @@ def ndcgAt(self, k):
         return self.call("ndcgAt", int(k))


+class MultilabelMetrics(JavaModelWrapper):
+    """
+    Evaluator for multilabel classification.
+
+    >>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
+    ...                                       ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
+    ...                                       ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])])
+    >>> metrics = MultilabelMetrics(predictionAndLabels)
+    >>> metrics.precision(0.0)
+    1.0
+    >>> metrics.recall(1.0)
+    0.66...
+    >>> metrics.f1Measure(2.0)
+    0.5
+    >>> metrics.precision()
+    0.66...
+    >>> metrics.recall()
+    0.64...
+    >>> metrics.f1Measure()
+    0.63...
+    >>> metrics.microPrecision
+    0.72...
+    >>> metrics.microRecall
+    0.66...
+    >>> metrics.microF1Measure
+    0.69...
+    >>> metrics.hammingLoss
+    0.33...
+    >>> metrics.subsetAccuracy
+    0.28...
+    >>> metrics.accuracy
+    0.54...
+    """
+
+    def __init__(self, predictionAndLabels):
+        sc = predictionAndLabels.ctx
+        sql_ctx = SQLContext(sc)
+        df = sql_ctx.createDataFrame(predictionAndLabels,
+                                     schema=sql_ctx._inferSchema(predictionAndLabels))
+        java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics
+        java_model = java_class(df._jdf)
+        super(MultilabelMetrics, self).__init__(java_model)
+
+    def precision(self, label=None):
+        """
+        Returns precision or precision for a given label (category) if specified.
+        """
+        if label is None:
+            return self.call("precision")
+        else:
+            return self.call("precision", float(label))
+
+    def recall(self, label=None):
+        """
+        Returns recall or recall for a given label (category) if specified.
+        """
+        if label is None:
+            return self.call("recall")
+        else:
+            return self.call("recall", float(label))
+
+    def f1Measure(self, label=None):
+        """
+        Returns f1Measure or f1Measure for a given label (category) if specified.
+        """
+        if label is None:
+            return self.call("f1Measure")
+        else:
+            return self.call("f1Measure", float(label))
+
+    @property
+    def microPrecision(self):
+        """
+        Returns micro-averaged label-based precision.
+        (equals to micro-averaged document-based precision)
+        """
+        return self.call("microPrecision")
+
+    @property
+    def microRecall(self):
+        """
+        Returns micro-averaged label-based recall.
+        (equals to micro-averaged document-based recall)
+        """
+        return self.call("microRecall")
+
+    @property
+    def microF1Measure(self):
+        """
+        Returns micro-averaged label-based f1-measure.
+        (equals to micro-averaged document-based f1-measure)
+        """
+        return self.call("microF1Measure")
+
+    @property
+    def hammingLoss(self):
+        """
+        Returns Hamming-loss.
+        """
+        return self.call("hammingLoss")
+
+    @property
+    def subsetAccuracy(self):
+        """
+        Returns subset accuracy.
+        (for equal sets of labels)
+        """
+        return self.call("subsetAccuracy")
+
+    @property
+    def accuracy(self):
+        """
+        Returns accuracy.
+        """
+        return self.call("accuracy")
+
+
 def _test():
     import doctest
     from pyspark import SparkContext
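
As a sanity check on the doctest constants above: the micro-averaged figures follow from pooled counts over the seven example documents, namely 8 true-positive labels against 11 predicted and 12 actual labels. A plain-Python sketch (no Spark needed; the variable names are ours, not from the patch):

    # Recompute the micro-averaged doctest values by hand (plain Python, no Spark).
    pairs = [([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
             ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
             ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])]

    tp = sum(len(set(pred) & set(labels)) for pred, labels in pairs)  # 8 true positives
    n_pred = sum(len(pred) for pred, _ in pairs)                      # 11 predicted labels
    n_true = sum(len(labels) for _, labels in pairs)                  # 12 actual labels

    print(tp / float(n_pred))            # microPrecision: 8/11  ~ 0.7272 -> "0.72..."
    print(tp / float(n_true))            # microRecall:    8/12  ~ 0.6666 -> "0.66..."
    print(2.0 * tp / (n_pred + n_true))  # microF1Measure: 16/23 ~ 0.6956 -> "0.69..."

The same per-document counting explains why metrics.precision(0.0) is exactly 1.0: label 0.0 is predicted in four documents and appears in the true label set of all four.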
