Commit f377431

sboeschhuawei authored and mengxr committed
[SPARK-4259][MLlib]: Add Power Iteration Clustering Algorithm with Gaussian Similarity Function
Add single pseudo-eigenvector PIC Including documentations and updated pom.xml with the following codes:

mllib/src/main/scala/org/apache/spark/mllib/clustering/PIClustering.scala
mllib/src/test/scala/org/apache/spark/mllib/clustering/PIClusteringSuite.scala

Author: sboeschhuawei <[email protected]>
Author: Fan Jiang <[email protected]>
Author: Jiang Fan <[email protected]>
Author: Stephen Boesch <[email protected]>
Author: Xiangrui Meng <[email protected]>

Closes apache#4254 from fjiang6/PIC and squashes the following commits:

4550850 [sboeschhuawei] Removed pic test data
f292f31 [Stephen Boesch] Merge pull request #44 from mengxr/SPARK-4259
4b78aaf [Xiangrui Meng] refactor PIC
24fbf52 [sboeschhuawei] Updated API to be similar to KMeans plus other changes requested by Xiangrui on the PR
c12dfc8 [sboeschhuawei] Removed examples files and added pic_data.txt. Revamped testcases yet to come
92d4752 [sboeschhuawei] Move the Guassian/ Affinity matrix calcs out of PIC. Presently in the test suite
7ebd149 [sboeschhuawei] Incorporate Xiangrui's first set of PR comments except restructure PIC.run to take Graph but do not remove Gaussian
121e4d5 [sboeschhuawei] Remove unused testing data files
1c3a62e [sboeschhuawei] removed matplot.py and reordered all private methods to bottom of PIC
218a49d [sboeschhuawei] Applied Xiangrui's comments - especially removing RDD/PICLinalg classes and making noncritical methods private
43ab10b [sboeschhuawei] Change last two println's to log4j logger
88aacc8 [sboeschhuawei] Add assert to testcase on cluster sizes
24f438e [sboeschhuawei] fixed incorrect markdown in clustering doc
060e6bf [sboeschhuawei] Added link to PIC doc from the main clustering md doc
be659e3 [sboeschhuawei] Added mllib specific log4j
90e7fa4 [sboeschhuawei] Converted from custom Linalg routines to Breeze: added JavaDoc comments; added Markdown documentation
bea48ea [sboeschhuawei] Converted custom Linear Algebra datatypes/routines to use Breeze.
b29c0db [Fan Jiang] Update PIClustering.scala
ace9749 [Fan Jiang] Update PIClustering.scala
a112f38 [sboeschhuawei] Added graphx main and test jars as dependencies to mllib/pom.xml
f656c34 [sboeschhuawei] Added iris dataset
b7dbcbe [sboeschhuawei] Added axes and combined into single plot for matplotlib
a2b1e57 [sboeschhuawei] Revert inadvertent update to KMeans
9294263 [sboeschhuawei] Added visualization/plotting of input/output data
e5df2b8 [sboeschhuawei] First end to end working PIC
0700335 [sboeschhuawei] First end to end working version: but has bad performance issue
32a90dc [sboeschhuawei] Update circles test data values
0ef163f [sboeschhuawei] Added ConcentricCircles data generation and KMeans clustering
3fd5bc8 [sboeschhuawei] PIClustering is running in new branch (up to the pseudo-eigenvector convergence step)
d5aae20 [Jiang Fan] Adding Power Iteration Clustering and Suite test
a3c5fbe [Jiang Fan] Adding Power Iteration Clustering
1 parent 6ee8338 commit f377431

5 files changed: +334 -0 lines changed


docs/mllib-clustering.md

Lines changed: 20 additions & 0 deletions
@@ -34,6 +34,26 @@ a given dataset, the algorithm returns the best clustering result).
 * *initializationSteps* determines the number of steps in the k-means\|\| algorithm.
 * *epsilon* determines the distance threshold within which we consider k-means to have converged.
 
+### Power Iteration Clustering
+
+Power iteration clustering is a scalable and efficient algorithm for clustering points given pointwise mutual affinity values. Internally the algorithm:
+
+* accepts a [Graph](https://spark.apache.org/docs/0.9.2/api/graphx/index.html#org.apache.spark.graphx.Graph) that represents a normalized pairwise affinity between all input points,
+* calculates the principal eigenvalue and eigenvector,
+* clusters each of the input points according to its principal eigenvector component value.
+
+Details of this algorithm are found in [Power Iteration Clustering, Lin and Cohen](http://www.icml2010.org/papers/387.pdf).
+
+Example output from our implementation for a dataset inspired by the paper, but with five clusters instead of three:
+
+<p style="text-align: center;">
+  <img src="img/PIClusteringFiveCirclesInputsAndOutputs.png"
+       title="The Property Graph"
+       alt="The Property Graph"
+       width="50%" />
+  <!-- Images are downsized intentionally to improve quality on retina displays -->
+</p>
+
 ### Examples
 
 <div class="codetabs">
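Reviewer note: the commit title mentions a Gaussian similarity function, but squashed commit 92d4752 moved the affinity calculation out of PIC, so callers are expected to supply the `(i, j, s_ij)` tuples themselves. The following is only a minimal plain-Scala sketch of how such tuples might be produced; `GaussianAffinity`, `squaredDistance`, `similarities`, and the `sigma` bandwidth parameter are hypothetical names that are not part of this commit.

```scala
// Hypothetical helper, not part of this commit: builds Gaussian affinities
// s_ij = exp(-||x_i - x_j||^2 / (2 * sigma^2)) for a small in-memory data set.
object GaussianAffinity {

  /** Squared Euclidean distance between two points of equal dimension. */
  def squaredDistance(x: Array[Double], y: Array[Double]): Double = {
    require(x.length == y.length, "points must have the same dimension")
    x.zip(y).map { case (a, b) => (a - b) * (a - b) }.sum
  }

  /**
   * Emits one (i, j, s_ij) tuple per unordered pair i < j.
   * Only one direction is emitted because the affinity matrix is symmetric,
   * and pairs with i == j are skipped because they would be ignored anyway.
   */
  def similarities(
      points: IndexedSeq[Array[Double]],
      sigma: Double): Seq[(Long, Long, Double)] = {
    require(sigma > 0.0, "sigma must be positive")
    for {
      i <- points.indices
      j <- (i + 1) until points.length
    } yield {
      val s = math.exp(-squaredDistance(points(i), points(j)) / (2.0 * sigma * sigma))
      (i.toLong, j.toLong, s)
    }
  }
}
```

The resulting sequence has the tuple type that `run` accepts once it is turned into an RDD, for example with `sc.parallelize(...)` as the test suite in this commit does. The all-pairs loop is quadratic in the number of points, so it is only suitable for small inputs.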

mllib/pom.xml

Lines changed: 5 additions & 0 deletions
@@ -50,6 +50,11 @@
       <artifactId>spark-sql_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-graphx_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
     <dependency>
       <groupId>org.jblas</groupId>
       <artifactId>jblas</artifactId>
Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.impl.GraphImpl
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * Model produced by [[PowerIterationClustering]].
+ *
+ * @param k number of clusters
+ * @param assignments an RDD of (vertexID, clusterID) pairs
+ */
+class PowerIterationClusteringModel(
+    val k: Int,
+    val assignments: RDD[(Long, Int)]) extends Serializable
+
+/**
+ * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by Lin and
+ * Cohen (see http://www.icml2010.org/papers/387.pdf). From the abstract: PIC finds a very
+ * low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise
+ * similarity matrix of the data.
+ *
+ * @param k Number of clusters.
+ * @param maxIterations Maximum number of iterations of the PIC algorithm.
+ */
+class PowerIterationClustering private[clustering] (
+    private var k: Int,
+    private var maxIterations: Int) extends Serializable {
+
+  import org.apache.spark.mllib.clustering.PowerIterationClustering._
+
+  /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100}. */
+  def this() = this(k = 2, maxIterations = 100)
+
+  /**
+   * Set the number of clusters.
+   */
+  def setK(k: Int): this.type = {
+    this.k = k
+    this
+  }
+
+  /**
+   * Set maximum number of iterations of the power iteration loop
+   */
+  def setMaxIterations(maxIterations: Int): this.type = {
+    this.maxIterations = maxIterations
+    this
+  }
+
+  /**
+   * Run the PIC algorithm.
+   *
+   * @param similarities an RDD of (i, j, s_ij_) tuples representing the affinity matrix, which is
+   *                     the matrix A in the PIC paper. The similarity s_ij_ must be nonnegative.
+   *                     This is a symmetric matrix and hence s_ij_ = s_ji_. For any (i, j) with
+   *                     nonzero similarity, there should be either (i, j, s_ij_) or (j, i, s_ji_)
+   *                     in the input. Tuples with i = j are ignored, because we assume s_ij_ = 0.0.
+   *
+   * @return a [[PowerIterationClusteringModel]] that contains the clustering result
+   */
+  def run(similarities: RDD[(Long, Long, Double)]): PowerIterationClusteringModel = {
+    val w = normalize(similarities)
+    val w0 = randomInit(w)
+    pic(w0)
+  }
+
+  /**
+   * Runs the PIC algorithm.
+   *
+   * @param w The normalized affinity matrix, which is the matrix W in the PIC paper with
+   *          w_ij_ = a_ij_ / d_ii_ as its edge properties and the initial vector of the power
+   *          iteration as its vertex properties.
+   */
+  private def pic(w: Graph[Double, Double]): PowerIterationClusteringModel = {
+    val v = powerIter(w, maxIterations)
+    val assignments = kMeans(v, k)
+    new PowerIterationClusteringModel(k, assignments)
+  }
+}
+
+private[clustering] object PowerIterationClustering extends Logging {
+  /**
+   * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W).
+   */
+  def normalize(similarities: RDD[(Long, Long, Double)]): Graph[Double, Double] = {
+    val edges = similarities.flatMap { case (i, j, s) =>
+      if (s < 0.0) {
+        throw new SparkException(s"Similarity must be nonnegative but found s($i, $j) = $s.")
+      }
+      if (i != j) {
+        Seq(Edge(i, j, s), Edge(j, i, s))
+      } else {
+        None
+      }
+    }
+    val gA = Graph.fromEdges(edges, 0.0)
+    val vD = gA.aggregateMessages[Double](
+      sendMsg = ctx => {
+        ctx.sendToSrc(ctx.attr)
+      },
+      mergeMsg = _ + _,
+      TripletFields.EdgeOnly)
+    GraphImpl.fromExistingRDDs(vD, gA.edges)
+      .mapTriplets(
+        e => e.attr / math.max(e.srcAttr, MLUtils.EPSILON),
+        TripletFields.Src)
+  }
+
+  /**
+   * Generates random vertex properties (v0) to start power iteration.
+   *
+   * @param g a graph representing the normalized affinity matrix (W)
+   * @return a graph with edges representing W and vertices representing a random vector
+   *         with unit 1-norm
+   */
+  def randomInit(g: Graph[Double, Double]): Graph[Double, Double] = {
+    val r = g.vertices.mapPartitionsWithIndex(
+      (part, iter) => {
+        val random = new XORShiftRandom(part)
+        iter.map { case (id, _) =>
+          (id, random.nextGaussian())
+        }
+      }, preservesPartitioning = true).cache()
+    val sum = r.values.map(math.abs).sum()
+    val v0 = r.mapValues(x => x / sum)
+    GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges)
+  }
+
+  /**
+   * Runs power iteration.
+   * @param g input graph with edges representing the normalized affinity matrix (W) and vertices
+   *          representing the initial vector of the power iterations.
+   * @param maxIterations maximum number of iterations
+   * @return a [[VertexRDD]] representing the pseudo-eigenvector
+   */
+  def powerIter(
+      g: Graph[Double, Double],
+      maxIterations: Int): VertexRDD[Double] = {
+    // the default tolerance used in the PIC paper, with a lower bound 1e-8
+    val tol = math.max(1e-5 / g.vertices.count(), 1e-8)
+    var prevDelta = Double.MaxValue
+    var diffDelta = Double.MaxValue
+    var curG = g
+    for (iter <- 0 until maxIterations if math.abs(diffDelta) > tol) {
+      val msgPrefix = s"Iteration $iter"
+      // multiply W by vt
+      val v = curG.aggregateMessages[Double](
+        sendMsg = ctx => ctx.sendToSrc(ctx.attr * ctx.dstAttr),
+        mergeMsg = _ + _,
+        TripletFields.Dst).cache()
+      // normalize v
+      val norm = v.values.map(math.abs).sum()
+      logInfo(s"$msgPrefix: norm(v) = $norm.")
+      val v1 = v.mapValues(x => x / norm)
+      // compare difference
+      val delta = curG.joinVertices(v1) { case (_, x, y) =>
+        math.abs(x - y)
+      }.vertices.values.sum()
+      logInfo(s"$msgPrefix: delta = $delta.")
+      diffDelta = math.abs(delta - prevDelta)
+      logInfo(s"$msgPrefix: diff(delta) = $diffDelta.")
+      // update v
+      curG = GraphImpl.fromExistingRDDs(VertexRDD(v1), g.edges)
+      prevDelta = delta
+    }
+    curG.vertices
+  }
+
+  /**
+   * Runs k-means clustering.
+   * @param v a [[VertexRDD]] representing the pseudo-eigenvector
+   * @param k number of clusters
+   * @return a [[VertexRDD]] representing the clustering assignments
+   */
+  def kMeans(v: VertexRDD[Double], k: Int): VertexRDD[Int] = {
+    val points = v.mapValues(x => Vectors.dense(x)).cache()
+    val model = new KMeans()
+      .setK(k)
+      .setRuns(5)
+      .setSeed(0L)
+      .run(points.values)
+    points.mapValues(p => model.predict(p)).cache()
+  }
+}
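Reviewer note: the stopping rule in `powerIter` is easier to follow without the GraphX plumbing. The sketch below restates the same loop on a small dense matrix: multiply by W, rescale to unit 1-norm, and stop once the change in `delta` between consecutive iterations is at most `tol`. It assumes a dense row-major `Array[Array[Double]]` in place of the graph; `PowerIterationSketch`, `multiply`, and `normalize1` are hypothetical names, and this is an illustration rather than the committed implementation.

```scala
// Hypothetical dense restatement of the power-iteration loop, for illustration only.
object PowerIterationSketch {

  /** Multiplies a dense row-major matrix by a vector. */
  def multiply(w: Array[Array[Double]], v: Array[Double]): Array[Double] =
    w.map(row => row.zip(v).map { case (a, b) => a * b }.sum)

  /** Scales v so that the sum of absolute values is 1. */
  def normalize1(v: Array[Double]): Array[Double] = {
    val norm = v.map(math.abs).sum
    v.map(_ / norm)
  }

  /**
   * Iterates v <- W v / ||W v||_1 until the change in delta between two
   * consecutive iterations is at most tol, or maxIterations is reached.
   * Returns the final vector and the number of iterations performed.
   */
  def powerIter(
      w: Array[Array[Double]],
      v0: Array[Double],
      maxIterations: Int): (Array[Double], Int) = {
    // same tolerance formula as the diff: 1e-5 / n with a lower bound of 1e-8
    val tol = math.max(1e-5 / v0.length, 1e-8)
    var v = v0
    var prevDelta = Double.MaxValue
    var diffDelta = Double.MaxValue
    var iter = 0
    while (iter < maxIterations && math.abs(diffDelta) > tol) {
      val v1 = normalize1(multiply(w, v))
      // delta is the 1-norm of the change in v during this iteration
      val delta = v.zip(v1).map { case (x, y) => math.abs(x - y) }.sum
      diffDelta = math.abs(delta - prevDelta)
      v = v1
      prevDelta = delta
      iter += 1
    }
    (v, iter)
  }
}
```

Because the loop tests the change in `delta` rather than `delta` itself, it stops when the per-iteration change levels off, which can happen before `v` has reached the true principal eigenvector. That matches the truncated power iteration described in the class comment.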
Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import scala.collection.mutable
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.graphx.{Edge, Graph}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext {
+
+  import org.apache.spark.mllib.clustering.PowerIterationClustering._
+
+  test("power iteration clustering") {
+    /*
+     We use the following graph to test PIC. All edges are assigned similarity 1.0 except 0.1 for
+     edge (3, 4).
+
+     15-14 -13 -12
+     |           |
+     4 . 3 - 2  11
+     |   | x |   |
+     5   0 - 1  10
+     |           |
+     6 - 7 - 8 - 9
+     */
+
+    val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0),
+      (1, 3, 1.0), (2, 3, 1.0), (3, 4, 0.1), // (3, 4) is a weak edge
+      (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0),
+      (10, 11, 1.0), (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0))
+    val model = new PowerIterationClustering()
+      .setK(2)
+      .run(sc.parallelize(similarities, 2))
+    val predictions = Array.fill(2)(mutable.Set.empty[Long])
+    model.assignments.collect().foreach { case (i, c) =>
+      predictions(c) += i
+    }
+    assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
+  }
+
+  test("normalize and powerIter") {
+    /*
+     Test normalize() with the following graph:
+
+     0 - 3
+     | \ |
+     1 - 2
+
+     The affinity matrix (A) is
+
+     0 1 1 1
+     1 0 1 0
+     1 1 0 1
+     1 0 1 0
+
+     D is diag(3, 2, 3, 2) and hence W is
+
+       0 1/3 1/3 1/3
+     1/2   0 1/2   0
+     1/3 1/3   0 1/3
+     1/2   0 1/2   0
+     */
+    val similarities = Seq[(Long, Long, Double)](
+      (0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (2, 3, 1.0))
+    val expected = Array(
+      Array(0.0,     1.0/3.0, 1.0/3.0, 1.0/3.0),
+      Array(1.0/2.0,     0.0, 1.0/2.0,     0.0),
+      Array(1.0/3.0, 1.0/3.0,     0.0, 1.0/3.0),
+      Array(1.0/2.0,     0.0, 1.0/2.0,     0.0))
+    val w = normalize(sc.parallelize(similarities, 2))
+    w.edges.collect().foreach { case Edge(i, j, x) =>
+      assert(x ~== expected(i.toInt)(j.toInt) absTol 1e-14)
+    }
+    val v0 = sc.parallelize(Seq[(Long, Double)]((0, 0.1), (1, 0.2), (2, 0.3), (3, 0.4)), 2)
+    val w0 = Graph(v0, w.edges)
+    val v1 = powerIter(w0, maxIterations = 1).collect()
+    val u = Array(0.3, 0.2, 0.7/3.0, 0.2)
+    val norm = u.sum
+    val u1 = u.map(x => x / norm)
+    v1.foreach { case (i, x) =>
+      assert(x ~== u1(i.toInt) absTol 1e-14)
+    }
+  }
+}
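Reviewer note: the expected values in the "normalize and powerIter" test can be checked by hand. The sketch below rebuilds A from the same five tuples, divides each row by its row sum to get W, and multiplies by v0 = (0.1, 0.2, 0.3, 0.4), which should give the test's `u` before rescaling. It assumes dense arrays and `Int` vertex ids; `NormalizeSketch`, `affinity`, and `rowNormalize` are hypothetical names, and the zero-row guard is a simplification of the `MLUtils.EPSILON` guard in the diff.

```scala
// Hypothetical dense counterpart of normalize(), for checking the test values by hand.
object NormalizeSketch {

  /** Builds the symmetric affinity matrix A from (i, j, s_ij) tuples, skipping i == j. */
  def affinity(n: Int, similarities: Seq[(Int, Int, Double)]): Array[Array[Double]] = {
    val a = Array.fill(n, n)(0.0)
    similarities.foreach { case (i, j, s) =>
      require(s >= 0.0, s"Similarity must be nonnegative but found s($i, $j) = $s.")
      if (i != j) {
        // store both directions, mirroring the two Edge objects emitted per tuple
        a(i)(j) = s
        a(j)(i) = s
      }
    }
    a
  }

  /** Divides each row of A by its row sum d_ii, giving W with w_ij = a_ij / d_ii. */
  def rowNormalize(a: Array[Array[Double]]): Array[Array[Double]] =
    a.map { row =>
      val d = row.sum
      // simplification: leave an all-zero row unchanged instead of dividing by epsilon
      if (d > 0.0) row.map(_ / d) else row
    }
}
```

With the test's tuples the row sums are 3, 2, 3, 2, and the first component of W v0 is (0.2 + 0.3 + 0.4) / 3 = 0.3, in line with `u(0)` in the test.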
