Skip to content

Commit 88aacc8

Browse files
committed
Add assert to testcase on cluster sizes
1 parent 24f438e commit 88aacc8

File tree

1 file changed

+8
-24
lines changed

1 file changed

+8
-24
lines changed

mllib/src/test/scala/org/apache/spark/mllib/clustering/PIClusteringSuite.scala

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.mllib.clustering
1919

20+
import org.apache.log4j.Logger
2021
import org.apache.spark.{SparkConf, SparkContext}
2122
import org.apache.spark.graphx._
2223
import org.apache.spark.mllib.clustering.PICLinalg.DMatrix
@@ -27,6 +28,8 @@ import scala.util.Random
2728

2829
class PIClusteringSuite extends FunSuite with LocalSparkContext {
2930

31+
val logger = Logger.getLogger(getClass.getName)
32+
3033
import org.apache.spark.mllib.clustering.PIClusteringSuite._
3134

3235
val PIC = PIClustering
@@ -38,6 +41,7 @@ class PIClusteringSuite extends FunSuite with LocalSparkContext {
3841
concentricCirclesTest()
3942
}
4043

44+
4145
def concentricCirclesTest() = {
4246
val sigma = 1.0
4347
val nIterations = 10
@@ -63,33 +67,13 @@ class PIClusteringSuite extends FunSuite with LocalSparkContext {
6367
val (ccenters, estCollected) = PIC.run(sc, vertices, nClusters, nIterations)
6468
println(s"Cluster centers: ${ccenters.mkString(",")} " +
6569
s"\nEstimates: ${estCollected.mkString("[", ",", "]")}")
66-
assert(ccenters.size == circleSpecs.length,"Did not get correct number of centers")
67-
val clustGroupsList = estCollected.groupBy{ case ((vid, eigenV), clustNum) =>
68-
clustNum
69-
}.mapValues{
70-
_.map{ case ((vid, eigenV), clustNum) =>
71-
(vid, clustNum)
72-
}}.toList.sortBy(_._1)
73-
74-
75-
val ccentersOrdered = ccenters.sortBy(-1.0 * _._2(0))
76-
77-
// val joinedGroups = ccentersOrdered.(clustGroupsList.toMap)
78-
//
79-
// val clustValids = clustGroupsList.map{ case (clustNum, vidEigensList) =>
80-
// (clustNum, vidEigensList.size, vidEigensList.map{ (_._1 / 1000).toLong }}
81-
// assert(clustGroups.map{_._2.size} == circleSpecs.map{ p => p.nPoints },
82-
// "Incorrect match on clusterGroupsSize")
83-
// val matchedCentersAndPoints = ccentersOrdered.map{ case (groupId, loc) => groupId}.zip(clustGroups)
84-
// assert(matchedCentersAndPoints.map{_._2.size} == circleSpecs.map{ p => p.nPoints },
85-
// "Incorrect match on clusterGroupsSize
86-
//
87-
// assert(estCollected == circleSpecs.length,"Did not get correct number of centers")
70+
assert(ccenters.size == circleSpecs.length, "Did not get correct number of centers")
71+
8872
}
8973
}
9074

91-
def join[T <: Comparable[T]](a: Map[T,_], b: Map[T,_]) = {
92-
(a.toSeq++b.toSeq).groupBy(_._1).mapValues(_.map(_._2).toList)
75+
def join[T <: Comparable[T]](a: Map[T, _], b: Map[T, _]) = {
76+
(a.toSeq ++ b.toSeq).groupBy(_._1).mapValues(_.map(_._2).toList)
9377
}
9478

9579
ignore("irisData") {

0 commit comments

Comments
 (0)