Skip to content

Commit faeb41d

Browse files
ankurdave authored and rxin committed
[SPARK-3936] Add aggregateMessages, which supersedes mapReduceTriplets
aggregateMessages enables neighborhood computation similarly to mapReduceTriplets, but it introduces two API improvements: 1. Messages are sent using an imperative interface based on EdgeContext rather than by returning an iterator of messages. 2. Rather than attempting bytecode inspection, the required triplet fields must be explicitly specified by the user by passing a TripletFields object. This fixes SPARK-3936. Additionally, this PR includes the following optimizations for aggregateMessages and EdgePartition: 1. EdgePartition now stores local vertex ids instead of global ids. This avoids hash lookups when looking up vertex attributes and aggregating messages. 2. Internal iterators in aggregateMessages are inlined into a while loop. In total, these optimizations were tested to provide a 37% speedup on PageRank (uk-2007-05 graph, 10 iterations, 16 r3.2xlarge machines, sped up from 513 s to 322 s). Subsumes apache#2815. Also fixes SPARK-4173. Author: Ankur Dave <[email protected]> Closes apache#3100 from ankurdave/aggregateMessages and squashes the following commits: f5b65d0 [Ankur Dave] Address @rxin comments on apache#3054 and apache#3100 1e80aca [Ankur Dave] Add aggregateMessages, which supersedes mapReduceTriplets 194a2df [Ankur Dave] Test triplet iterator in EdgePartition serialization test e0f8ecc [Ankur Dave] Take activeSet in ExistingEdgePartitionBuilder c85076d [Ankur Dave] Readability improvements b567be2 [Ankur Dave] iter.foreach -> while loop 4a566dc [Ankur Dave] Optimizations for mapReduceTriplets and EdgePartition
1 parent 2ef016b commit faeb41d

15 files changed

+766
-376
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.graphx
19+
20+
/**
21+
* Represents an edge along with its neighboring vertices and allows sending messages along the
22+
* edge. Used in [[Graph#aggregateMessages]].
23+
*/
24+
abstract class EdgeContext[VD, ED, A] {
25+
/** The vertex id of the edge's source vertex. */
26+
def srcId: VertexId
27+
/** The vertex id of the edge's destination vertex. */
28+
def dstId: VertexId
29+
/** The vertex attribute of the edge's source vertex. */
30+
def srcAttr: VD
31+
/** The vertex attribute of the edge's destination vertex. */
32+
def dstAttr: VD
33+
/** The attribute associated with the edge. */
34+
def attr: ED
35+
36+
/** Sends a message to the source vertex. */
37+
def sendToSrc(msg: A): Unit
38+
/** Sends a message to the destination vertex. */
39+
def sendToDst(msg: A): Unit
40+
41+
/** Converts the edge and vertex properties into an [[EdgeTriplet]] for convenience. */
42+
def toEdgeTriplet: EdgeTriplet[VD, ED] = {
43+
val et = new EdgeTriplet[VD, ED]
44+
et.srcId = srcId
45+
et.srcAttr = srcAttr
46+
et.dstId = dstId
47+
et.dstAttr = dstAttr
48+
et.attr = attr
49+
et
50+
}
51+
}

graphx/src/main/scala/org/apache/spark/graphx/Graph.scala

Lines changed: 124 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,39 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
207207
* }}}
208208
*
209209
*/
210-
def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
211-
mapTriplets((pid, iter) => iter.map(map))
210+
def mapTriplets[ED2: ClassTag](
211+
map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
212+
mapTriplets((pid, iter) => iter.map(map), TripletFields.All)
213+
}
214+
215+
/**
216+
* Transforms each edge attribute using the map function, passing it the adjacent vertex
217+
* attributes as well. If adjacent vertex values are not required,
218+
* consider using `mapEdges` instead.
219+
*
220+
* @note This does not change the structure of the
221+
* graph or modify the values of this graph. As a consequence
222+
* the underlying index structures can be reused.
223+
*
224+
* @param map the function from an edge object to a new edge value.
225+
* @param tripletFields which fields should be included in the edge triplet passed to the map
226+
* function. If not all fields are needed, specifying this can improve performance.
227+
*
228+
* @tparam ED2 the new edge data type
229+
*
230+
* @example This function might be used to initialize edge
231+
* attributes based on the attributes associated with each vertex.
232+
* {{{
233+
* val rawGraph: Graph[Int, Int] = someLoadFunction()
234+
* val graph = rawGraph.mapTriplets[Int]( edge =>
235+
* edge.srcAttr - edge.dstAttr)
236+
* }}}
237+
*
238+
*/
239+
def mapTriplets[ED2: ClassTag](
240+
map: EdgeTriplet[VD, ED] => ED2,
241+
tripletFields: TripletFields): Graph[VD, ED2] = {
242+
mapTriplets((pid, iter) => iter.map(map), tripletFields)
212243
}
213244

214245
/**
@@ -223,12 +254,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
223254
* the underlying index structures can be reused.
224255
*
225256
* @param map the iterator transform
257+
* @param tripletFields which fields should be included in the edge triplet passed to the map
258+
* function. If not all fields are needed, specifying this can improve performance.
226259
*
227260
* @tparam ED2 the new edge data type
228261
*
229262
*/
230-
def mapTriplets[ED2: ClassTag](map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2])
231-
: Graph[VD, ED2]
263+
def mapTriplets[ED2: ClassTag](
264+
map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2],
265+
tripletFields: TripletFields): Graph[VD, ED2]
232266

233267
/**
234268
* Reverses all edges in the graph. If this graph contains an edge from a to b then the returned
@@ -287,6 +321,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
287321
* "sent" to either vertex in the edge. The `reduceFunc` is then used to combine the output of
288322
* the map phase destined to each vertex.
289323
*
324+
* This function is deprecated in 1.2.0 because of SPARK-3936. Use aggregateMessages instead.
325+
*
290326
* @tparam A the type of "message" to be sent to each vertex
291327
*
292328
* @param mapFunc the user defined map function which returns 0 or
@@ -296,13 +332,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
296332
* be commutative and associative and is used to combine the output
297333
* of the map phase
298334
*
299-
* @param activeSetOpt optionally, a set of "active" vertices and a direction of edges to
300-
* consider when running `mapFunc`. If the direction is `In`, `mapFunc` will only be run on
301-
* edges with destination in the active set. If the direction is `Out`,
302-
* `mapFunc` will only be run on edges originating from vertices in the active set. If the
303-
* direction is `Either`, `mapFunc` will be run on edges with *either* vertex in the active set
304-
* . If the direction is `Both`, `mapFunc` will be run on edges with *both* vertices in the
305-
* active set. The active set must have the same index as the graph's vertices.
335+
* @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if
336+
* desired. This is done by specifying a set of "active" vertices and an edge direction. The
337+
* `mapFunc` function will then run only on edges connected to active vertices by edges in the
338+
* specified direction. If the direction is `In`, `mapFunc` will only be run on edges with
339+
* destination in the active set. If the direction is `Out`, `mapFunc` will only be run on edges
340+
* originating from vertices in the active set. If the direction is `Either`, `mapFunc` will be
341+
* run on edges with *either* vertex in the active set. If the direction is `Both`, `mapFunc`
342+
* will be run on edges with *both* vertices in the active set. The active set must have the
343+
* same index as the graph's vertices.
306344
*
307345
* @example We can use this function to compute the in-degree of each
308346
* vertex
@@ -319,15 +357,88 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
319357
* predicate or implement PageRank.
320358
*
321359
*/
360+
@deprecated("use aggregateMessages", "1.2.0")
322361
def mapReduceTriplets[A: ClassTag](
323362
mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
324363
reduceFunc: (A, A) => A,
325364
activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None)
326365
: VertexRDD[A]
327366

328367
/**
329-
* Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`. The
330-
* input table should contain at most one entry for each vertex. If no entry in `other` is
368+
* Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied
369+
* `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be
370+
* sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages
371+
* destined to the same vertex.
372+
*
373+
* @tparam A the type of message to be sent to each vertex
374+
*
375+
* @param sendMsg runs on each edge, sending messages to neighboring vertices using the
376+
* [[EdgeContext]].
377+
* @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This
378+
* combiner should be commutative and associative.
379+
* @param tripletFields which fields should be included in the [[EdgeContext]] passed to the
380+
* `sendMsg` function. If not all fields are needed, specifying this can improve performance.
381+
*
382+
* @example We can use this function to compute the in-degree of each
383+
* vertex
384+
* {{{
385+
* val rawGraph: Graph[_, _] = Graph.textFile("twittergraph")
386+
* val inDeg: RDD[(VertexId, Int)] =
387+
* rawGraph.aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _)
388+
* }}}
389+
*
390+
* @note By expressing computation at the edge level we achieve
391+
* maximum parallelism. This is one of the core functions in the
392+
* Graph API that enables neighborhood-level computation. For
393+
* example this function can be used to count neighbors satisfying a
394+
* predicate or implement PageRank.
395+
*
396+
*/
397+
def aggregateMessages[A: ClassTag](
398+
sendMsg: EdgeContext[VD, ED, A] => Unit,
399+
mergeMsg: (A, A) => A,
400+
tripletFields: TripletFields = TripletFields.All)
401+
: VertexRDD[A] = {
402+
aggregateMessagesWithActiveSet(sendMsg, mergeMsg, tripletFields, None)
403+
}
404+
405+
/**
406+
* Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied
407+
* `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be
408+
* sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages
409+
* destined to the same vertex.
410+
*
411+
* This variant can take an active set to restrict the computation and is intended for internal
412+
* use only.
413+
*
414+
* @tparam A the type of message to be sent to each vertex
415+
*
416+
* @param sendMsg runs on each edge, sending messages to neighboring vertices using the
417+
* [[EdgeContext]].
418+
* @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This
419+
* combiner should be commutative and associative.
420+
* @param tripletFields which fields should be included in the [[EdgeContext]] passed to the
421+
* `sendMsg` function. If not all fields are needed, specifying this can improve performance.
422+
* @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if
423+
* desired. This is done by specifying a set of "active" vertices and an edge direction. The
424+
* `sendMsg` function will then run only on edges connected to active vertices by edges in the
425+
* specified direction. If the direction is `In`, `sendMsg` will only be run on edges with
426+
* destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges
427+
* originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be
428+
* run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg`
429+
* will be run on edges with *both* vertices in the active set. The active set must have the
430+
* same index as the graph's vertices.
431+
*/
432+
private[graphx] def aggregateMessagesWithActiveSet[A: ClassTag](
433+
sendMsg: EdgeContext[VD, ED, A] => Unit,
434+
mergeMsg: (A, A) => A,
435+
tripletFields: TripletFields,
436+
activeSetOpt: Option[(VertexRDD[_], EdgeDirection)])
437+
: VertexRDD[A]
438+
439+
/**
440+
* Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`.
441+
* The input table should contain at most one entry for each vertex. If no entry in `other` is
331442
* provided for a particular vertex in the graph, the map function receives `None`.
332443
*
333444
* @tparam U the type of entry in the table of updates

graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala

Lines changed: 46 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,12 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
6969
*/
7070
private def degreesRDD(edgeDirection: EdgeDirection): VertexRDD[Int] = {
7171
if (edgeDirection == EdgeDirection.In) {
72-
graph.mapReduceTriplets(et => Iterator((et.dstId,1)), _ + _)
72+
graph.aggregateMessages(_.sendToDst(1), _ + _, TripletFields.None)
7373
} else if (edgeDirection == EdgeDirection.Out) {
74-
graph.mapReduceTriplets(et => Iterator((et.srcId,1)), _ + _)
74+
graph.aggregateMessages(_.sendToSrc(1), _ + _, TripletFields.None)
7575
} else { // EdgeDirection.Either
76-
graph.mapReduceTriplets(et => Iterator((et.srcId,1), (et.dstId,1)), _ + _)
76+
graph.aggregateMessages(ctx => { ctx.sendToSrc(1); ctx.sendToDst(1) }, _ + _,
77+
TripletFields.None)
7778
}
7879
}
7980

@@ -88,18 +89,17 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
8889
def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] = {
8990
val nbrs =
9091
if (edgeDirection == EdgeDirection.Either) {
91-
graph.mapReduceTriplets[Array[VertexId]](
92-
mapFunc = et => Iterator((et.srcId, Array(et.dstId)), (et.dstId, Array(et.srcId))),
93-
reduceFunc = _ ++ _
94-
)
92+
graph.aggregateMessages[Array[VertexId]](
93+
ctx => { ctx.sendToSrc(Array(ctx.dstId)); ctx.sendToDst(Array(ctx.srcId)) },
94+
_ ++ _, TripletFields.None)
9595
} else if (edgeDirection == EdgeDirection.Out) {
96-
graph.mapReduceTriplets[Array[VertexId]](
97-
mapFunc = et => Iterator((et.srcId, Array(et.dstId))),
98-
reduceFunc = _ ++ _)
96+
graph.aggregateMessages[Array[VertexId]](
97+
ctx => ctx.sendToSrc(Array(ctx.dstId)),
98+
_ ++ _, TripletFields.None)
9999
} else if (edgeDirection == EdgeDirection.In) {
100-
graph.mapReduceTriplets[Array[VertexId]](
101-
mapFunc = et => Iterator((et.dstId, Array(et.srcId))),
102-
reduceFunc = _ ++ _)
100+
graph.aggregateMessages[Array[VertexId]](
101+
ctx => ctx.sendToDst(Array(ctx.srcId)),
102+
_ ++ _, TripletFields.None)
103103
} else {
104104
throw new SparkException("It doesn't make sense to collect neighbor ids without a " +
105105
"direction. (EdgeDirection.Both is not supported; use EdgeDirection.Either instead.)")
@@ -122,22 +122,27 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
122122
* @return the vertex set of neighboring vertex attributes for each vertex
123123
*/
124124
def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = {
125-
val nbrs = graph.mapReduceTriplets[Array[(VertexId,VD)]](
126-
edge => {
127-
val msgToSrc = (edge.srcId, Array((edge.dstId, edge.dstAttr)))
128-
val msgToDst = (edge.dstId, Array((edge.srcId, edge.srcAttr)))
129-
edgeDirection match {
130-
case EdgeDirection.Either => Iterator(msgToSrc, msgToDst)
131-
case EdgeDirection.In => Iterator(msgToDst)
132-
case EdgeDirection.Out => Iterator(msgToSrc)
133-
case EdgeDirection.Both =>
134-
throw new SparkException("collectNeighbors does not support EdgeDirection.Both. Use" +
135-
"EdgeDirection.Either instead.")
136-
}
137-
},
138-
(a, b) => a ++ b)
139-
140-
graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) =>
125+
val nbrs = edgeDirection match {
126+
case EdgeDirection.Either =>
127+
graph.aggregateMessages[Array[(VertexId,VD)]](
128+
ctx => {
129+
ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr)))
130+
ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr)))
131+
},
132+
(a, b) => a ++ b, TripletFields.SrcDstOnly)
133+
case EdgeDirection.In =>
134+
graph.aggregateMessages[Array[(VertexId,VD)]](
135+
ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))),
136+
(a, b) => a ++ b, TripletFields.SrcOnly)
137+
case EdgeDirection.Out =>
138+
graph.aggregateMessages[Array[(VertexId,VD)]](
139+
ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))),
140+
(a, b) => a ++ b, TripletFields.DstOnly)
141+
case EdgeDirection.Both =>
142+
throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
143+
"EdgeDirection.Either instead.")
144+
}
145+
graph.vertices.leftJoin(nbrs) { (vid, vdata, nbrsOpt) =>
141146
nbrsOpt.getOrElse(Array.empty[(VertexId, VD)])
142147
}
143148
} // end of collectNeighbor
@@ -160,18 +165,20 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
160165
def collectEdges(edgeDirection: EdgeDirection): VertexRDD[Array[Edge[ED]]] = {
161166
edgeDirection match {
162167
case EdgeDirection.Either =>
163-
graph.mapReduceTriplets[Array[Edge[ED]]](
164-
edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))),
165-
(edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
166-
(a, b) => a ++ b)
168+
graph.aggregateMessages[Array[Edge[ED]]](
169+
ctx => {
170+
ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr)))
171+
ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr)))
172+
},
173+
(a, b) => a ++ b, TripletFields.EdgeOnly)
167174
case EdgeDirection.In =>
168-
graph.mapReduceTriplets[Array[Edge[ED]]](
169-
edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
170-
(a, b) => a ++ b)
175+
graph.aggregateMessages[Array[Edge[ED]]](
176+
ctx => ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))),
177+
(a, b) => a ++ b, TripletFields.EdgeOnly)
171178
case EdgeDirection.Out =>
172-
graph.mapReduceTriplets[Array[Edge[ED]]](
173-
edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
174-
(a, b) => a ++ b)
179+
graph.aggregateMessages[Array[Edge[ED]]](
180+
ctx => ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))),
181+
(a, b) => a ++ b, TripletFields.EdgeOnly)
175182
case EdgeDirection.Both =>
176183
throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
177184
"EdgeDirection.Either instead.")

0 commit comments

Comments
 (0)