diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 5a8e5bb1f721a..e8124d225fedd 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -212,8 +212,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements * mapping to that key. */ - def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JIterable[T]] = { - implicit val ctagK: ClassTag[K] = fakeClassTag + def groupBy[U](f: JFunction[T, U]): JavaPairRDD[U, JIterable[T]] = { + implicit val ctagK: ClassTag[U] = fakeClassTag implicit val ctagV: ClassTag[JList[T]] = fakeClassTag JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(fakeClassTag))) } @@ -222,10 +222,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements * mapping to that key. */ - def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JIterable[T]] = { - implicit val ctagK: ClassTag[K] = fakeClassTag + def groupBy[U](f: JFunction[T, U], numPartitions: Int): JavaPairRDD[U, JIterable[T]] = { + implicit val ctagK: ClassTag[U] = fakeClassTag implicit val ctagV: ClassTag[JList[T]] = fakeClassTag - JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(fakeClassTag[K]))) + JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(fakeClassTag[U]))) } /** @@ -459,8 +459,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Creates tuples of the elements in this RDD by applying `f`. */ - def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = { - implicit val ctag: ClassTag[K] = fakeClassTag + def keyBy[U](f: JFunction[T, U]): JavaPairRDD[U, T] = { + implicit val ctag: ClassTag[U] = fakeClassTag JavaPairRDD.fromRDD(rdd.keyBy(f)) } diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 59c86eecac5e8..95782f2627221 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -322,6 +322,42 @@ public Boolean call(Integer x) { Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds } + + @Test + public void groupByOnPairRDD() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); + Function, Boolean> areOdd = new Function, Boolean>() { + @Override + public Boolean call(scala.Tuple2 x) { + return x._1 % 2 == 0 && x._2 % 2 == 0; + } + }; + JavaPairRDD pairrdd = rdd.zip(rdd); + JavaPairRDD>> oddsAndEvens = pairrdd.groupBy(areOdd); + Assert.assertEquals(2, oddsAndEvens.count()); + Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens + Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds + + oddsAndEvens = pairrdd.groupBy(areOdd, 1); + Assert.assertEquals(2, oddsAndEvens.count()); + Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens + Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds + } + + @Test + public void keyByOnPairRDD() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); + Function, String> areOdd = new Function, String>() { + @Override + public String call(scala.Tuple2 x) { + return ""+(x._1 +x._2); + } + }; + JavaPairRDD pairrdd = rdd.zip(rdd); + JavaPairRDD> keyed = pairrdd.keyBy(areOdd); + Assert.assertEquals(7, keyed.count()); + Assert.assertEquals(1, (long)keyed.lookup("2").get(0)._1); + } @SuppressWarnings("unchecked") @Test