Skip to content

Commit 6a224c3

Browse files
douglaz authored and pwendell committed
SPARK-1868: Users should be allowed to cogroup at least 4 RDDs
Adds cogroup for 4 RDDs.

Author: Allan Douglas R. de Oliveira <[email protected]>

Closes apache#813 from douglaz/more_cogroups and squashes the following commits:

f8d6273 [Allan Douglas R. de Oliveira] Test python groupWith for one more case
0e9009c [Allan Douglas R. de Oliveira] Added scala tests
c3ffcdd [Allan Douglas R. de Oliveira] Added java tests
517a67f [Allan Douglas R. de Oliveira] Added tests for python groupWith
2f402d5 [Allan Douglas R. de Oliveira] Removed TODO
17474f4 [Allan Douglas R. de Oliveira] Use new cogroup function
7877a2a [Allan Douglas R. de Oliveira] Fixed code
ba02414 [Allan Douglas R. de Oliveira] Added varargs cogroup to pyspark
c4a8a51 [Allan Douglas R. de Oliveira] Added java cogroup 4
e94963c [Allan Douglas R. de Oliveira] Fixed spacing
f1ee57b [Allan Douglas R. de Oliveira] Fixed scala style issues
d7196f1 [Allan Douglas R. de Oliveira] Allow the cogroup of 4 RDDs
1 parent d484dde commit 6a224c3

File tree

6 files changed

+223
-17
lines changed

6 files changed

+223
-17
lines changed

core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -543,6 +543,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
543543
partitioner: Partitioner): JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
544544
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner)))
545545

546+
/**
547+
* For each key k in `this` or `other1` or `other2` or `other3`,
548+
* return a resulting RDD that contains a tuple with the list of values
549+
* for that key in `this`, `other1`, `other2` and `other3`.
550+
*/
551+
def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
552+
other2: JavaPairRDD[K, W2],
553+
other3: JavaPairRDD[K, W3],
554+
partitioner: Partitioner)
555+
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
556+
fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, partitioner)))
557+
546558
/**
547559
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
548560
* list of values for that key in `this` as well as `other`.
@@ -558,6 +570,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
558570
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
559571
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2)))
560572

573+
/**
574+
* For each key k in `this` or `other1` or `other2` or `other3`,
575+
* return a resulting RDD that contains a tuple with the list of values
576+
* for that key in `this`, `other1`, `other2` and `other3`.
577+
*/
578+
def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
579+
other2: JavaPairRDD[K, W2],
580+
other3: JavaPairRDD[K, W3])
581+
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
582+
fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3)))
583+
561584
/**
562585
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
563586
* list of values for that key in `this` as well as `other`.
@@ -574,6 +597,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
574597
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
575598
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions)))
576599

600+
/**
601+
* For each key k in `this` or `other1` or `other2` or `other3`,
602+
* return a resulting RDD that contains a tuple with the list of values
603+
* for that key in `this`, `other1`, `other2` and `other3`.
604+
*/
605+
def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
606+
other2: JavaPairRDD[K, W2],
607+
other3: JavaPairRDD[K, W3],
608+
numPartitions: Int)
609+
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
610+
fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, numPartitions)))
611+
577612
/** Alias for cogroup. */
578613
def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JIterable[V], JIterable[W])] =
579614
fromRDD(cogroupResultToJava(rdd.groupWith(other)))
@@ -583,6 +618,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
583618
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
584619
fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2)))
585620

621+
/** Alias for cogroup. */
622+
def groupWith[W1, W2, W3](other1: JavaPairRDD[K, W1],
623+
other2: JavaPairRDD[K, W2],
624+
other3: JavaPairRDD[K, W3])
625+
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
626+
fromRDD(cogroupResult3ToJava(rdd.groupWith(other1, other2, other3)))
627+
586628
/**
587629
* Return the list of values in the RDD for key `key`. This operation is done efficiently if the
588630
* RDD has a known partitioner by only searching the partition that the key maps to.
@@ -786,6 +828,15 @@ object JavaPairRDD {
786828
.mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3)))
787829
}
788830

831+
private[spark]
832+
def cogroupResult3ToJava[K: ClassTag, V, W1, W2, W3](
833+
rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))])
834+
: RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3]))] = {
835+
rddToPairRDDFunctions(rdd)
836+
.mapValues(x =>
837+
(asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3), asJavaIterable(x._4)))
838+
}
839+
789840
def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = {
790841
new JavaPairRDD[K, V](rdd)
791842
}

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -567,6 +567,28 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
567567
new FlatMappedValuesRDD(self, cleanF)
568568
}
569569

570+
/**
571+
* For each key k in `this` or `other1` or `other2` or `other3`,
572+
* return a resulting RDD that contains a tuple with the list of values
573+
* for that key in `this`, `other1`, `other2` and `other3`.
574+
*/
575+
def cogroup[W1, W2, W3](other1: RDD[(K, W1)],
576+
other2: RDD[(K, W2)],
577+
other3: RDD[(K, W3)],
578+
partitioner: Partitioner)
579+
: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
580+
if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) {
581+
throw new SparkException("Default partitioner cannot partition array keys.")
582+
}
583+
val cg = new CoGroupedRDD[K](Seq(self, other1, other2, other3), partitioner)
584+
cg.mapValues { case Seq(vs, w1s, w2s, w3s) =>
585+
(vs.asInstanceOf[Seq[V]],
586+
w1s.asInstanceOf[Seq[W1]],
587+
w2s.asInstanceOf[Seq[W2]],
588+
w3s.asInstanceOf[Seq[W3]])
589+
}
590+
}
591+
570592
/**
571593
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
572594
* list of values for that key in `this` as well as `other`.
@@ -599,6 +621,16 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
599621
}
600622
}
601623

624+
/**
625+
* For each key k in `this` or `other1` or `other2` or `other3`,
626+
* return a resulting RDD that contains a tuple with the list of values
627+
* for that key in `this`, `other1`, `other2` and `other3`.
628+
*/
629+
def cogroup[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)])
630+
: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
631+
cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3))
632+
}
633+
602634
/**
603635
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
604636
* list of values for that key in `this` as well as `other`.
@@ -633,6 +665,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
633665
cogroup(other1, other2, new HashPartitioner(numPartitions))
634666
}
635667

668+
/**
669+
* For each key k in `this` or `other1` or `other2` or `other3`,
670+
* return a resulting RDD that contains a tuple with the list of values
671+
* for that key in `this`, `other1`, `other2` and `other3`.
672+
*/
673+
def cogroup[W1, W2, W3](other1: RDD[(K, W1)],
674+
other2: RDD[(K, W2)],
675+
other3: RDD[(K, W3)],
676+
numPartitions: Int)
677+
: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
678+
cogroup(other1, other2, other3, new HashPartitioner(numPartitions))
679+
}
680+
636681
/** Alias for cogroup. */
637682
def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = {
638683
cogroup(other, defaultPartitioner(self, other))
@@ -644,6 +689,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
644689
cogroup(other1, other2, defaultPartitioner(self, other1, other2))
645690
}
646691

692+
/** Alias for cogroup. */
693+
def groupWith[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)])
694+
: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
695+
cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3))
696+
}
697+
647698
/**
648699
* Return an RDD with the pairs from `this` whose keys are not in `other`.
649700
*

core/src/test/java/org/apache/spark/JavaAPISuite.java

Lines changed: 63 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,9 @@
2121
import java.util.*;
2222

2323
import scala.Tuple2;
24+
import scala.Tuple3;
25+
import scala.Tuple4;
26+
2427

2528
import com.google.common.collect.Iterables;
2629
import com.google.common.collect.Iterators;
@@ -304,6 +307,66 @@ public void cogroup() {
304307
cogrouped.collect();
305308
}
306309

310+
@SuppressWarnings("unchecked")
311+
@Test
312+
public void cogroup3() {
313+
JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
314+
new Tuple2<String, String>("Apples", "Fruit"),
315+
new Tuple2<String, String>("Oranges", "Fruit"),
316+
new Tuple2<String, String>("Oranges", "Citrus")
317+
));
318+
JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList(
319+
new Tuple2<String, Integer>("Oranges", 2),
320+
new Tuple2<String, Integer>("Apples", 3)
321+
));
322+
JavaPairRDD<String, Integer> quantities = sc.parallelizePairs(Arrays.asList(
323+
new Tuple2<String, Integer>("Oranges", 21),
324+
new Tuple2<String, Integer>("Apples", 42)
325+
));
326+
327+
JavaPairRDD<String, Tuple3<Iterable<String>, Iterable<Integer>, Iterable<Integer>>> cogrouped =
328+
categories.cogroup(prices, quantities);
329+
Assert.assertEquals("[Fruit, Citrus]",
330+
Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
331+
Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
332+
Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3()));
333+
334+
335+
cogrouped.collect();
336+
}
337+
338+
@SuppressWarnings("unchecked")
339+
@Test
340+
public void cogroup4() {
341+
JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
342+
new Tuple2<String, String>("Apples", "Fruit"),
343+
new Tuple2<String, String>("Oranges", "Fruit"),
344+
new Tuple2<String, String>("Oranges", "Citrus")
345+
));
346+
JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList(
347+
new Tuple2<String, Integer>("Oranges", 2),
348+
new Tuple2<String, Integer>("Apples", 3)
349+
));
350+
JavaPairRDD<String, Integer> quantities = sc.parallelizePairs(Arrays.asList(
351+
new Tuple2<String, Integer>("Oranges", 21),
352+
new Tuple2<String, Integer>("Apples", 42)
353+
));
354+
JavaPairRDD<String, String> countries = sc.parallelizePairs(Arrays.asList(
355+
new Tuple2<String, String>("Oranges", "BR"),
356+
new Tuple2<String, String>("Apples", "US")
357+
));
358+
359+
JavaPairRDD<String, Tuple4<Iterable<String>, Iterable<Integer>, Iterable<Integer>, Iterable<String>>> cogrouped =
360+
categories.cogroup(prices, quantities, countries);
361+
Assert.assertEquals("[Fruit, Citrus]",
362+
Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
363+
Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
364+
Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3()));
365+
Assert.assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4()));
366+
367+
cogrouped.collect();
368+
}
369+
307370
@SuppressWarnings("unchecked")
308371
@Test
309372
public void leftOuterJoin() {

core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -249,6 +249,39 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
249249
))
250250
}
251251

252+
test("groupWith3") {
253+
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
254+
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
255+
val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd')))
256+
val joined = rdd1.groupWith(rdd2, rdd3).collect()
257+
assert(joined.size === 4)
258+
val joinedSet = joined.map(x => (x._1,
259+
(x._2._1.toList, x._2._2.toList, x._2._3.toList))).toSet
260+
assert(joinedSet === Set(
261+
(1, (List(1, 2), List('x'), List('a'))),
262+
(2, (List(1), List('y', 'z'), List())),
263+
(3, (List(1), List(), List('b'))),
264+
(4, (List(), List('w'), List('c', 'd')))
265+
))
266+
}
267+
268+
test("groupWith4") {
269+
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
270+
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
271+
val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd')))
272+
val rdd4 = sc.parallelize(Array((2, '@')))
273+
val joined = rdd1.groupWith(rdd2, rdd3, rdd4).collect()
274+
assert(joined.size === 4)
275+
val joinedSet = joined.map(x => (x._1,
276+
(x._2._1.toList, x._2._2.toList, x._2._3.toList, x._2._4.toList))).toSet
277+
assert(joinedSet === Set(
278+
(1, (List(1, 2), List('x'), List('a'), List())),
279+
(2, (List(1), List('y', 'z'), List(), List('@'))),
280+
(3, (List(1), List(), List('b'), List())),
281+
(4, (List(), List('w'), List('c', 'd'), List()))
282+
))
283+
}
284+
252285
test("zero-partition RDD") {
253286
val emptyDir = Files.createTempDir()
254287
emptyDir.deleteOnExit()

python/pyspark/join.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -79,15 +79,15 @@ def dispatch(seq):
7979
return _do_python_join(rdd, other, numPartitions, dispatch)
8080

8181

82-
def python_cogroup(rdd, other, numPartitions):
83-
vs = rdd.map(lambda (k, v): (k, (1, v)))
84-
ws = other.map(lambda (k, v): (k, (2, v)))
82+
def python_cogroup(rdds, numPartitions):
83+
def make_mapper(i):
84+
return lambda (k, v): (k, (i, v))
85+
vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)]
86+
union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds)
87+
rdd_len = len(vrdds)
8588
def dispatch(seq):
86-
vbuf, wbuf = [], []
89+
bufs = [[] for i in range(rdd_len)]
8790
for (n, v) in seq:
88-
if n == 1:
89-
vbuf.append(v)
90-
elif n == 2:
91-
wbuf.append(v)
92-
return (ResultIterable(vbuf), ResultIterable(wbuf))
93-
return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch)
91+
bufs[n].append(v)
92+
return tuple(map(ResultIterable, bufs))
93+
return union_vrdds.groupByKey(numPartitions).mapValues(dispatch)

python/pyspark/rdd.py

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1233,7 +1233,7 @@ def _mergeCombiners(iterator):
12331233
combiners[k] = mergeCombiners(combiners[k], v)
12341234
return combiners.iteritems()
12351235
return shuffled.mapPartitions(_mergeCombiners)
1236-
1236+
12371237
def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
12381238
"""
12391239
Aggregate the values of each key, using given combine functions and a neutral "zero value".
@@ -1245,7 +1245,7 @@ def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
12451245
"""
12461246
def createZero():
12471247
return copy.deepcopy(zeroValue)
1248-
1248+
12491249
return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions)
12501250

12511251
def foldByKey(self, zeroValue, func, numPartitions=None):
@@ -1323,12 +1323,20 @@ def mapValues(self, f):
13231323
map_values_fn = lambda (k, v): (k, f(v))
13241324
return self.map(map_values_fn, preservesPartitioning=True)
13251325

1326-
# TODO: support varargs cogroup of several RDDs.
1327-
def groupWith(self, other):
1326+
def groupWith(self, other, *others):
13281327
"""
1329-
Alias for cogroup.
1328+
Alias for cogroup but with support for multiple RDDs.
1329+
1330+
>>> w = sc.parallelize([("a", 5), ("b", 6)])
1331+
>>> x = sc.parallelize([("a", 1), ("b", 4)])
1332+
>>> y = sc.parallelize([("a", 2)])
1333+
>>> z = sc.parallelize([("b", 42)])
1334+
>>> map((lambda (x,y): (x, (list(y[0]), list(y[1]), list(y[2]), list(y[3])))), \
1335+
sorted(list(w.groupWith(x, y, z).collect())))
1336+
[('a', ([5], [1], [2], [])), ('b', ([6], [4], [], [42]))]
1337+
13301338
"""
1331-
return self.cogroup(other)
1339+
return python_cogroup((self, other) + others, numPartitions=None)
13321340

13331341
# TODO: add variant with custom partitioner
13341342
def cogroup(self, other, numPartitions=None):
@@ -1342,7 +1350,7 @@ def cogroup(self, other, numPartitions=None):
13421350
>>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect())))
13431351
[('a', ([1], [2])), ('b', ([4], []))]
13441352
"""
1345-
return python_cogroup(self, other, numPartitions)
1353+
return python_cogroup((self, other), numPartitions)
13461354

13471355
def subtractByKey(self, other, numPartitions=None):
13481356
"""

0 commit comments

Comments
 (0)