
Commit 7783351

Author: Jacky Li

add generic support in FPGrowth

1 parent bebf4c4

File tree: 2 files changed (+82, -23 lines)


mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala

Lines changed: 29 additions & 17 deletions
@@ -17,15 +17,23 @@
 
 package org.apache.spark.mllib.fpm
 
+import java.lang.{Iterable => JavaIterable}
 import java.{util => ju}
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.reflect.ClassTag
 
-import org.apache.spark.{SparkException, HashPartitioner, Logging, Partitioner}
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
 
-class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable
+class FPGrowthModel[Item](val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
+  def javaFreqItemsets(): JavaRDD[(Array[Item], Long)] = {
+    freqItemsets.toJavaRDD()
+  }
+}
 
 /**
  * This class implements Parallel FP-growth algorithm to do frequent pattern matching on input data.
@@ -69,32 +77,36 @@ class FPGrowth private (
    * @param data input data set, each element contains a transaction
    * @return an [[FPGrowthModel]]
    */
-  def run(data: RDD[Array[String]]): FPGrowthModel = {
+  def run[Item: ClassTag, Basket <: Iterable[Item]](data: RDD[Basket]): FPGrowthModel[Item] = {
     if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
     val count = data.count()
     val minCount = math.ceil(minSupport * count).toLong
     val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
     val partitioner = new HashPartitioner(numParts)
-    val freqItems = genFreqItems(data, minCount, partitioner)
-    val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
+    val freqItems = genFreqItems[Item, Basket](data, minCount, partitioner)
+    val freqItemsets = genFreqItemsets[Item, Basket](data, minCount, freqItems, partitioner)
     new FPGrowthModel(freqItemsets)
   }
 
+  def run[Item: ClassTag, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = {
+    this.run(data.rdd.map(_.asScala))
+  }
+
   /**
    * Generates frequent items by filtering the input data using minimal support level.
    * @param minCount minimum count for frequent itemsets
   * @param partitioner partitioner used to distribute items
    * @return array of frequent pattern ordered by their frequencies
    */
-  private def genFreqItems(
-      data: RDD[Array[String]],
+  private def genFreqItems[Item: ClassTag, Basket <: Iterable[Item]](
+      data: RDD[Basket],
       minCount: Long,
-      partitioner: Partitioner): Array[String] = {
+      partitioner: Partitioner): Array[Item] = {
     data.flatMap { t =>
       val uniq = t.toSet
-      if (t.length != uniq.size) {
+      if (t.size != uniq.size) {
        throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.")
       }
       t
@@ -114,11 +126,11 @@ class FPGrowth private (
    * @param partitioner partitioner used to distribute transactions
    * @return an RDD of (frequent itemset, count)
    */
-  private def genFreqItemsets(
-      data: RDD[Array[String]],
+  private def genFreqItemsets[Item: ClassTag, Basket <: Iterable[Item]](
+      data: RDD[Basket],
       minCount: Long,
-      freqItems: Array[String],
-      partitioner: Partitioner): RDD[(Array[String], Long)] = {
+      freqItems: Array[Item],
+      partitioner: Partitioner): RDD[(Array[Item], Long)] = {
     val itemToRank = freqItems.zipWithIndex.toMap
     data.flatMap { transaction =>
       genCondTransactions(transaction, itemToRank, partitioner)
@@ -139,13 +151,13 @@ class FPGrowth private (
    * @param partitioner partitioner used to distribute transactions
    * @return a map of (target partition, conditional transaction)
    */
-  private def genCondTransactions(
-      transaction: Array[String],
-      itemToRank: Map[String, Int],
+  private def genCondTransactions[Item: ClassTag, Basket <: Iterable[Item]](
+      transaction: Basket,
+      itemToRank: Map[Item, Int],
       partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
     val output = mutable.Map.empty[Int, Array[Int]]
     // Filter the basket by frequent items pattern and sort their ranks.
-    val filtered = transaction.flatMap(itemToRank.get)
+    val filtered = transaction.flatMap(itemToRank.get).toArray
     ju.Arrays.sort(filtered)
     val n = filtered.length
     var i = n - 1
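
Note (not part of the commit): a minimal usage sketch of the generalized API above. The sample transactions and the SparkContext handle `sc` are assumptions; the type parameters and method signatures follow the diff.

import org.apache.spark.mllib.fpm.FPGrowth

// Baskets can now be any Iterable of any item type, e.g. Seq[Int];
// items within a basket must still be unique, and run() warns if
// the input RDD is not cached.
val transactions = sc.parallelize(Seq(
  Seq(1, 2, 3),
  Seq(1, 2),
  Seq(2, 3))).cache()

val model = new FPGrowth()
  .setMinSupport(0.5)
  .setNumPartitions(2)
  .run[Int, Seq[Int]](transactions)   // Item = Int, Basket = Seq[Int]

model.freqItemsets.collect().foreach { case (items, count) =>
  println(items.mkString("{", ", ", "}") + ": " + count)
}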

mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala

Lines changed: 53 additions & 6 deletions
@@ -22,29 +22,30 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 
 class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
 
-  test("FP-Growth") {
+
+  test("FP-Growth using String type") {
     val transactions = Seq(
       "r z h k p",
       "z y x w v u t s",
       "s x o n r",
       "x z y m t s q e",
       "z",
       "x z y r q t p")
-      .map(_.split(" "))
+      .map(_.split(" ").toSeq)
     val rdd = sc.parallelize(transactions, 2).cache()
 
     val fpg = new FPGrowth()
 
     val model6 = fpg
       .setMinSupport(0.9)
       .setNumPartitions(1)
-      .run(rdd)
+      .run[String, Seq[String]](rdd)
     assert(model6.freqItemsets.count() === 0)
 
     val model3 = fpg
       .setMinSupport(0.5)
       .setNumPartitions(2)
-      .run(rdd)
+      .run[String, Seq[String]](rdd)
     val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
       (items.toSet, count)
     }
@@ -61,13 +62,59 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = fpg
       .setMinSupport(0.3)
       .setNumPartitions(4)
-      .run(rdd)
+      .run[String, Seq[String]](rdd)
     assert(model2.freqItemsets.count() === 54)
 
     val model1 = fpg
       .setMinSupport(0.1)
       .setNumPartitions(8)
-      .run(rdd)
+      .run[String, Seq[String]](rdd)
     assert(model1.freqItemsets.count() === 625)
   }
+
+  test("FP-Growth using Int type") {
+    val transactions = Seq(
+      "1 2 3",
+      "1 2 3 4",
+      "5 4 3 2 1",
+      "6 5 4 3 2 1",
+      "2 4",
+      "1 3",
+      "1 7")
+      .map(_.split(" ").map(_.toInt).toList)
+    val rdd = sc.parallelize(transactions, 2).cache()
+
+    val fpg = new FPGrowth()
+
+    val model6 = fpg
+      .setMinSupport(0.9)
+      .setNumPartitions(1)
+      .run[Int, List[Int]](rdd)
+    assert(model6.freqItemsets.count() === 0)
+
+    val model3 = fpg
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run[Int, List[Int]](rdd)
+    val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
+      (items.toSet, count)
+    }
+    val expected = Set(
+      (Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L),
+      (Set(1, 2), 4L), (Set(1, 3), 5L), (Set(2, 3), 4L),
+      (Set(2, 4), 4L), (Set(1, 2, 3), 4L))
+    assert(freqItemsets3.toSet === expected)
+
+    val model2 = fpg
+      .setMinSupport(0.3)
+      .setNumPartitions(4)
+      .run[Int, List[Int]](rdd)
+    assert(model2.freqItemsets.count() === 15)
+
+    val model1 = fpg
+      .setMinSupport(0.1)
+      .setNumPartitions(8)
+      .run[Int, List[Int]](rdd)
+    assert(model1.freqItemsets.count() === 65)
+  }
 }
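
Note (not part of the commit): a sketch of the Java-facing entry points added above, driven from Scala. The JavaSparkContext, the sample baskets, and the `JArrays`/`JList` aliases are assumptions; per the diff, the JavaRDD overload of run converts each basket via JavaConverters and delegates to the Scala method, and javaFreqItemsets() wraps the result for Java callers.

import java.util.{Arrays => JArrays, List => JList}
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.mllib.fpm.FPGrowth

val jsc = new JavaSparkContext(sc)   // assumes an existing SparkContext `sc`
val javaData = jsc.parallelize(JArrays.asList(
  JArrays.asList("a", "b", "c"),
  JArrays.asList("a", "b"),
  JArrays.asList("b", "c"))).cache()

// JList[String] satisfies the bound Basket <: java.lang.Iterable[String],
// so this resolves to the new JavaRDD overload.
val model = new FPGrowth()
  .setMinSupport(0.5)
  .run[String, JList[String]](javaData)

// Frequent itemsets wrapped as a JavaRDD for Java callers.
val javaItemsets = model.javaFreqItemsets()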
