Skip to content

Commit 63073d0

Browse files
committed
update to make generic FPGrowth Java-friendly
1 parent 737d8bb commit 63073d0

File tree

3 files changed

+86
-87
lines changed

mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,30 @@
1717

1818
package org.apache.spark.mllib.fpm
1919

20-
import java.lang.{Iterable => JavaIterable}
2120
import java.{util => ju}
21+
import java.lang.{Iterable => JavaIterable}
2222

23-
import scala.collection.JavaConverters._
2423
import scala.collection.mutable
24+
import scala.collection.JavaConverters._
2525
import scala.reflect.ClassTag
2626

27-
import org.apache.spark.api.java.JavaRDD
27+
import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
28+
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
29+
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
2830
import org.apache.spark.rdd.RDD
2931
import org.apache.spark.storage.StorageLevel
30-
import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
3132

32-
class FPGrowthModel[Item](val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
33-
def javaFreqItemsets(): JavaRDD[(Array[Item], Long)] = {
34-
freqItemsets.toJavaRDD()
33+
/**
34+
* Model trained by [[FPGrowth]], which holds frequent itemsets.
35+
* @param freqItemsets frequent itemset, which is an RDD of (itemset, frequency) pairs
36+
* @tparam Item item type
37+
*/
38+
class FPGrowthModel[Item: ClassTag](
39+
val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
40+
41+
/** Returns frequent itemsets as a [[org.apache.spark.api.java.JavaPairRDD]]. */
42+
def javaFreqItemsets(): JavaPairRDD[Array[Item], java.lang.Long] = {
43+
JavaPairRDD.fromRDD(freqItemsets).asInstanceOf[JavaPairRDD[Array[Item], java.lang.Long]]
3544
}
3645
}
3746

@@ -77,22 +86,22 @@ class FPGrowth private (
7786
* @param data input data set, each element contains a transaction
7887
* @return an [[FPGrowthModel]]
7988
*/
80-
def run[Item: ClassTag, Basket <: Iterable[Item]](data: RDD[Basket]): FPGrowthModel[Item] = {
89+
def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = {
8190
if (data.getStorageLevel == StorageLevel.NONE) {
8291
logWarning("Input data is not cached.")
8392
}
8493
val count = data.count()
8594
val minCount = math.ceil(minSupport * count).toLong
8695
val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
8796
val partitioner = new HashPartitioner(numParts)
88-
val freqItems = genFreqItems[Item, Basket](data, minCount, partitioner)
89-
val freqItemsets = genFreqItemsets[Item, Basket](data, minCount, freqItems, partitioner)
97+
val freqItems = genFreqItems(data, minCount, partitioner)
98+
val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
9099
new FPGrowthModel(freqItemsets)
91100
}
92101

93-
def run[Item: ClassTag, Basket <: JavaIterable[Item]](
94-
data: JavaRDD[Basket]): FPGrowthModel[Item] = {
95-
this.run(data.rdd.map(_.asScala))
102+
def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = {
103+
implicit val tag = fakeClassTag[Item]
104+
run(data.rdd.map(_.asScala.toArray))
96105
}
97106

98107
/**
@@ -101,8 +110,8 @@ class FPGrowth private (
101110
* @param partitioner partitioner used to distribute items
102111
* @return array of frequent pattern ordered by their frequencies
103112
*/
104-
private def genFreqItems[Item: ClassTag, Basket <: Iterable[Item]](
105-
data: RDD[Basket],
113+
private def genFreqItems[Item: ClassTag](
114+
data: RDD[Array[Item]],
106115
minCount: Long,
107116
partitioner: Partitioner): Array[Item] = {
108117
data.flatMap { t =>
@@ -127,8 +136,8 @@ class FPGrowth private (
127136
* @param partitioner partitioner used to distribute transactions
128137
* @return an RDD of (frequent itemset, count)
129138
*/
130-
private def genFreqItemsets[Item: ClassTag, Basket <: Iterable[Item]](
131-
data: RDD[Basket],
139+
private def genFreqItemsets[Item: ClassTag](
140+
data: RDD[Array[Item]],
132141
minCount: Long,
133142
freqItems: Array[Item],
134143
partitioner: Partitioner): RDD[(Array[Item], Long)] = {
@@ -152,13 +161,13 @@ class FPGrowth private (
152161
* @param partitioner partitioner used to distribute transactions
153162
* @return a map of (target partition, conditional transaction)
154163
*/
155-
private def genCondTransactions[Item: ClassTag, Basket <: Iterable[Item]](
156-
transaction: Basket,
164+
private def genCondTransactions[Item: ClassTag](
165+
transaction: Array[Item],
157166
itemToRank: Map[Item, Int],
158167
partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
159168
val output = mutable.Map.empty[Int, Array[Int]]
160169
// Filter the basket by frequent items pattern and sort their ranks.
161-
val filtered = transaction.flatMap(itemToRank.get).toArray
170+
val filtered = transaction.flatMap(itemToRank.get)
162171
ju.Arrays.sort(filtered)
163172
val n = filtered.length
164173
var i = n - 1

mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java

Lines changed: 45 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -19,78 +19,66 @@
1919

2020
import java.io.Serializable;
2121
import java.util.ArrayList;
22-
import java.util.List;
2322

2423
import org.junit.After;
2524
import org.junit.Before;
2625
import org.junit.Test;
27-
import static org.junit.Assert.*;
28-
2926
import com.google.common.collect.Lists;
27+
import static org.junit.Assert.*;
3028

3129
import org.apache.spark.api.java.JavaRDD;
3230
import org.apache.spark.api.java.JavaSparkContext;
3331

3432
public class JavaFPGrowthSuite implements Serializable {
35-
private transient JavaSparkContext sc;
33+
private transient JavaSparkContext sc;
34+
35+
@Before
36+
public void setUp() {
37+
sc = new JavaSparkContext("local", "JavaFPGrowth");
38+
}
3639

37-
@Before
38-
public void setUp() {
39-
sc = new JavaSparkContext("local", "JavaFPGrowth");
40-
}
40+
@After
41+
public void tearDown() {
42+
sc.stop();
43+
sc = null;
44+
}
4145

42-
@After
43-
public void tearDown() {
44-
sc.stop();
45-
sc = null;
46-
}
46+
@Test
47+
public void runFPGrowth() {
4748

48-
@Test
49-
public void runFPGrowth() {
50-
JavaRDD<ArrayList<String>> rdd = sc.parallelize(Lists.newArrayList(
51-
Lists.newArrayList("r z h k p".split(" ")),
52-
Lists.newArrayList("z y x w v u t s".split(" ")),
53-
Lists.newArrayList("s x o n r".split(" ")),
54-
Lists.newArrayList("x z y m t s q e".split(" ")),
55-
Lists.newArrayList("z".split(" ")),
56-
Lists.newArrayList("x z y r q t p".split(" "))), 2);
49+
@SuppressWarnings("unchecked")
50+
JavaRDD<ArrayList<String>> rdd = sc.parallelize(Lists.newArrayList(
51+
Lists.newArrayList("r z h k p".split(" ")),
52+
Lists.newArrayList("z y x w v u t s".split(" ")),
53+
Lists.newArrayList("s x o n r".split(" ")),
54+
Lists.newArrayList("x z y m t s q e".split(" ")),
55+
Lists.newArrayList("z".split(" ")),
56+
Lists.newArrayList("x z y r q t p".split(" "))), 2);
5757

58-
FPGrowth fpg = new FPGrowth();
58+
FPGrowth fpg = new FPGrowth();
5959

60-
/*
61-
FPGrowthModel model6 = fpg
62-
.setMinSupport(0.9)
63-
.setNumPartitions(1)
64-
.run(rdd);
65-
assert(model6.javaFreqItemsets().count() == 0);
60+
FPGrowthModel<String> model6 = fpg
61+
.setMinSupport(0.9)
62+
.setNumPartitions(1)
63+
.run(rdd);
64+
assertEquals(0, model6.javaFreqItemsets().count());
6665

67-
FPGrowthModel model3 = fpg
68-
.setMinSupport(0.5)
69-
.setNumPartitions(2)
70-
.run(rdd);
71-
val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
72-
(items.toSet, count)
73-
}
74-
val expected = Set(
75-
(Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
76-
(Set("r"), 3L),
77-
(Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
78-
(Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
79-
(Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
80-
(Set("t", "y", "x"), 3L),
81-
(Set("t", "y", "x", "z"), 3L))
82-
assert(freqItemsets3.toSet === expected)
66+
FPGrowthModel<String> model3 = fpg
67+
.setMinSupport(0.5)
68+
.setNumPartitions(2)
69+
.run(rdd);
70+
assertEquals(18, model3.javaFreqItemsets().count());
8371

84-
val model2 = fpg
85-
.setMinSupport(0.3)
86-
.setNumPartitions(4)
87-
.run[String](rdd)
88-
assert(model2.freqItemsets.count() == 54)
72+
FPGrowthModel<String> model2 = fpg
73+
.setMinSupport(0.3)
74+
.setNumPartitions(4)
75+
.run(rdd);
76+
assertEquals(54, model2.javaFreqItemsets().count());
8977

90-
val model1 = fpg
91-
.setMinSupport(0.1)
92-
.setNumPartitions(8)
93-
.run[String](rdd)
94-
assert(model1.freqItemsets.count() == 625) */
95-
}
96-
}
78+
FPGrowthModel<String> model1 = fpg
79+
.setMinSupport(0.1)
80+
.setNumPartitions(8)
81+
.run(rdd);
82+
assertEquals(625, model1.javaFreqItemsets().count());
83+
}
84+
}

mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,21 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
3131
"x z y m t s q e",
3232
"z",
3333
"x z y r q t p")
34-
.map(_.split(" ").toSeq)
34+
.map(_.split(" "))
3535
val rdd = sc.parallelize(transactions, 2).cache()
3636

3737
val fpg = new FPGrowth()
3838

3939
val model6 = fpg
4040
.setMinSupport(0.9)
4141
.setNumPartitions(1)
42-
.run[String, Seq[String]](rdd)
42+
.run(rdd)
4343
assert(model6.freqItemsets.count() === 0)
4444

4545
val model3 = fpg
4646
.setMinSupport(0.5)
4747
.setNumPartitions(2)
48-
.run[String, Seq[String]](rdd)
48+
.run(rdd)
4949
val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
5050
(items.toSet, count)
5151
}
@@ -62,13 +62,13 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
6262
val model2 = fpg
6363
.setMinSupport(0.3)
6464
.setNumPartitions(4)
65-
.run[String, Seq[String]](rdd)
65+
.run(rdd)
6666
assert(model2.freqItemsets.count() === 54)
6767

6868
val model1 = fpg
6969
.setMinSupport(0.1)
7070
.setNumPartitions(8)
71-
.run[String, Seq[String]](rdd)
71+
.run(rdd)
7272
assert(model1.freqItemsets.count() === 625)
7373
}
7474

@@ -81,21 +81,23 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
8181
"2 4",
8282
"1 3",
8383
"1 7")
84-
.map(_.split(" ").map(_.toInt).toList)
84+
.map(_.split(" ").map(_.toInt).toArray)
8585
val rdd = sc.parallelize(transactions, 2).cache()
8686

8787
val fpg = new FPGrowth()
8888

8989
val model6 = fpg
9090
.setMinSupport(0.9)
9191
.setNumPartitions(1)
92-
.run[Int, List[Int]](rdd)
92+
.run(rdd)
9393
assert(model6.freqItemsets.count() === 0)
9494

9595
val model3 = fpg
9696
.setMinSupport(0.5)
9797
.setNumPartitions(2)
98-
.run[Int, List[Int]](rdd)
98+
.run(rdd)
99+
assert(model3.freqItemsets.first()._1.getClass === Array(1).getClass,
100+
"frequent itemsets should use primitive arrays")
99101
val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
100102
(items.toSet, count)
101103
}
@@ -108,13 +110,13 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
108110
val model2 = fpg
109111
.setMinSupport(0.3)
110112
.setNumPartitions(4)
111-
.run[Int, List[Int]](rdd)
113+
.run(rdd)
112114
assert(model2.freqItemsets.count() === 15)
113115

114116
val model1 = fpg
115117
.setMinSupport(0.1)
116118
.setNumPartitions(8)
117-
.run[Int, List[Int]](rdd)
119+
.run(rdd)
118120
assert(model1.freqItemsets.count() === 65)
119121
}
120122
}

0 commit comments

Comments (0)