@@ -343,6 +343,123 @@ def ndcgAt(self, k):
         return self.call("ndcgAt", int(k))
 
 
+class MultilabelMetrics(JavaModelWrapper):
+    """
+    Evaluator for multilabel classification.
+
+    >>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
+    ...     ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
+    ...     ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])])
+    >>> metrics = MultilabelMetrics(predictionAndLabels)
+    >>> metrics.precision(0.0)
+    1.0
+    >>> metrics.recall(1.0)
+    0.66...
+    >>> metrics.f1Measure(2.0)
+    0.5
+    >>> metrics.precision()
+    0.66...
+    >>> metrics.recall()
+    0.64...
+    >>> metrics.f1Measure()
+    0.63...
+    >>> metrics.microPrecision
+    0.72...
+    >>> metrics.microRecall
+    0.66...
+    >>> metrics.microF1Measure
+    0.69...
+    >>> metrics.hammingLoss
+    0.33...
+    >>> metrics.subsetAccuracy
+    0.28...
+    >>> metrics.accuracy
+    0.54...
+    """
+
+    def __init__(self, predictionAndLabels):
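+        # The JVM-side MultilabelMetrics constructor takes a DataFrame, so the
+        # incoming RDD of (predictions, labels) pairs is first wrapped in one
+        # using an inferred schema, then handed over via its Java handle.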
+        sc = predictionAndLabels.ctx
+        sql_ctx = SQLContext(sc)
+        df = sql_ctx.createDataFrame(predictionAndLabels,
+                                     schema=sql_ctx._inferSchema(predictionAndLabels))
+        java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics
+        java_model = java_class(df._jdf)
+        super(MultilabelMetrics, self).__init__(java_model)
+
+    def precision(self, label=None):
+        """
+        Returns precision or precision for a given label (category) if specified.
+        """
+        if label is None:
+            return self.call("precision")
+        else:
+            return self.call("precision", float(label))
+
+    def recall(self, label=None):
+        """
+        Returns recall or recall for a given label (category) if specified.
+        """
+        if label is None:
+            return self.call("recall")
+        else:
+            return self.call("recall", float(label))
+
+    def f1Measure(self, label=None):
+        """
+        Returns f1Measure or f1Measure for a given label (category) if specified.
+        """
+        if label is None:
+            return self.call("f1Measure")
+        else:
+            return self.call("f1Measure", float(label))
+
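+    # The properties below are micro-averaged: true positives, false positives,
+    # and false negatives are pooled across all labels before the ratio is
+    # taken, rather than averaging per-document scores as the methods above do.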
+    @property
+    def microPrecision(self):
+        """
+        Returns micro-averaged label-based precision
+        (equal to micro-averaged document-based precision).
+        """
+        return self.call("microPrecision")
+
+    @property
+    def microRecall(self):
+        """
+        Returns micro-averaged label-based recall
+        (equal to micro-averaged document-based recall).
+        """
+        return self.call("microRecall")
+
+    @property
+    def microF1Measure(self):
+        """
+        Returns micro-averaged label-based f1-measure
+        (equal to micro-averaged document-based f1-measure).
+        """
+        return self.call("microF1Measure")
+
+    @property
+    def hammingLoss(self):
+        """
+        Returns the Hamming loss.
+        """
+        return self.call("hammingLoss")
+
+    @property
+    def subsetAccuracy(self):
+        """
+        Returns subset accuracy, the fraction of documents for which the
+        predicted label set matches the true label set exactly.
+        """
+        return self.call("subsetAccuracy")
+
+    @property
+    def accuracy(self):
+        """
+        Returns accuracy.
+        """
+        return self.call("accuracy")
+
+
 def _test():
     import doctest
     from pyspark import SparkContext
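For anyone who wants to try the new API end to end outside of the doctest, here is a minimal sketch; the application name and sample data are illustrative assumptions, and the import path is the one this patch adds:

from pyspark import SparkContext
from pyspark.mllib.evaluation import MultilabelMetrics

sc = SparkContext(appName="MultilabelMetricsExample")  # app name is an arbitrary choice

# Each element pairs a predicted label set with the true label set.
predictionAndLabels = sc.parallelize([
    ([0.0, 1.0], [0.0, 2.0]),
    ([0.0, 2.0], [0.0, 1.0]),
    ([], [0.0]),
    ([2.0], [2.0]),
])

metrics = MultilabelMetrics(predictionAndLabels)
print(metrics.precision())       # averaged across all documents
print(metrics.precision(0.0))    # per-label precision for label 0.0
print(metrics.microF1Measure)    # note: a property, not a method call
sc.stop()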