 package org.apache.spark.mllib.fpm
 
+import java.lang.{Iterable => JavaIterable}
 import java.{util => ju}
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.reflect.ClassTag
 
-import org.apache.spark.{SparkException, HashPartitioner, Logging, Partitioner}
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
 
-class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable
+class FPGrowthModel[Item](val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
+  def javaFreqItemsets(): JavaRDD[(Array[Item], Long)] = {
+    freqItemsets.toJavaRDD()
+  }
+}
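Parameterizing FPGrowthModel by Item is what lets callers mine arbitrary item types instead of String only. A minimal usage sketch, assuming the public no-argument FPGrowth constructor and the setMinSupport setter from the surrounding class (neither is shown in this diff) plus a live SparkContext sc; the baskets and threshold are invented for illustration:

// Sketch only: Int items instead of String; any type with a ClassTag works.
val transactions = sc.parallelize(Seq(
  Seq(1, 2, 3),
  Seq(1, 2),
  Seq(2, 3)))
val model: FPGrowthModel[Int] = new FPGrowth()
  .setMinSupport(0.5)
  .run[Int, Seq[Int]](transactions)  // explicit type args select the Scala overload
model.freqItemsets.collect().foreach { case (items, count) =>
  println(s"${items.mkString("{", ",", "}")}: $count")
}

Java callers instead go through the new JavaRDD overload in the next hunk and read results back via javaFreqItemsets().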
 
 /**
  * This class implements Parallel FP-growth algorithm to do frequent pattern matching on input data.
@@ -69,32 +77,36 @@ class FPGrowth private (
    * @param data input data set, each element contains a transaction
    * @return an [[FPGrowthModel]]
    */
-  def run(data: RDD[Array[String]]): FPGrowthModel = {
+  def run[Item: ClassTag, Basket <: Iterable[Item]](data: RDD[Basket]): FPGrowthModel[Item] = {
     if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
     val count = data.count()
     val minCount = math.ceil(minSupport * count).toLong
     val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
     val partitioner = new HashPartitioner(numParts)
-    val freqItems = genFreqItems(data, minCount, partitioner)
-    val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
+    val freqItems = genFreqItems[Item, Basket](data, minCount, partitioner)
+    val freqItemsets = genFreqItemsets[Item, Basket](data, minCount, freqItems, partitioner)
     new FPGrowthModel(freqItemsets)
   }
 
+  def run[Item: ClassTag, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = {
+    this.run(data.rdd.map(_.asScala))
+  }
+
   /**
    * Generates frequent items by filtering the input data using minimal support level.
    * @param minCount minimum count for frequent itemsets
    * @param partitioner partitioner used to distribute items
    * @return array of frequent pattern ordered by their frequencies
    */
-  private def genFreqItems(
-      data: RDD[Array[String]],
+  private def genFreqItems[Item: ClassTag, Basket <: Iterable[Item]](
+      data: RDD[Basket],
       minCount: Long,
-      partitioner: Partitioner): Array[String] = {
+      partitioner: Partitioner): Array[Item] = {
     data.flatMap { t =>
       val uniq = t.toSet
-      if (t.length != uniq.size) {
+      if (t.size != uniq.size) {
         throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.")
       }
       t
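Two details in this hunk are easy to miss. First, the fractional minSupport is turned into an absolute count once, up front: with minSupport = 0.4 over 10 transactions, math.ceil(0.4 * 10).toLong gives 4, so an item must appear in at least 4 baskets to survive. Second, the duplicate check had to change from t.length to t.size because an arbitrary Iterable has no length. A sketch with invented baskets:

// size vs. distinct size: works for any Iterable, not just Array.
val ok  = Seq("a", "b", "c")  // t.size == 3, t.toSet.size == 3 -> accepted
val bad = Seq("a", "a", "b")  // t.size == 3, t.toSet.size == 2 -> SparkException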
@@ -114,11 +126,11 @@ class FPGrowth private (
    * @param partitioner partitioner used to distribute transactions
    * @return an RDD of (frequent itemset, count)
    */
-  private def genFreqItemsets(
-      data: RDD[Array[String]],
+  private def genFreqItemsets[Item: ClassTag, Basket <: Iterable[Item]](
+      data: RDD[Basket],
       minCount: Long,
-      freqItems: Array[String],
-      partitioner: Partitioner): RDD[(Array[String], Long)] = {
+      freqItems: Array[Item],
+      partitioner: Partitioner): RDD[(Array[Item], Long)] = {
     val itemToRank = freqItems.zipWithIndex.toMap
     data.flatMap { transaction =>
       genCondTransactions(transaction, itemToRank, partitioner)
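The rank encoding built from freqItems is what keeps the shuffle cheap: each frequent item is replaced by its index in the frequency-ordered array, so conditional transactions carry small Ints rather than arbitrary Item values. A sketch with invented items:

// freqItems arrives ordered by descending frequency (from genFreqItems).
val freqItems = Array("b", "a", "c")           // hypothetical order
val itemToRank = freqItems.zipWithIndex.toMap  // Map(b -> 0, a -> 1, c -> 2)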
@@ -139,13 +151,13 @@ class FPGrowth private (
    * @param partitioner partitioner used to distribute transactions
    * @return a map of (target partition, conditional transaction)
    */
-  private def genCondTransactions(
-      transaction: Array[String],
-      itemToRank: Map[String, Int],
+  private def genCondTransactions[Item: ClassTag, Basket <: Iterable[Item]](
+      transaction: Basket,
+      itemToRank: Map[Item, Int],
       partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
     val output = mutable.Map.empty[Int, Array[Int]]
     // Filter the basket by frequent items pattern and sort their ranks.
-    val filtered = transaction.flatMap(itemToRank.get)
+    val filtered = transaction.flatMap(itemToRank.get).toArray
     ju.Arrays.sort(filtered)
     val n = filtered.length
     var i = n - 1
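The trailing .toArray is what keeps this line compiling after the generalization: transaction is now an arbitrary Iterable, while ju.Arrays.sort needs a concrete Array[Int]. Continuing the invented itemToRank example above:

// itemToRank.get returns None for infrequent items, so flatMap drops them.
val transaction = Seq("d", "a", "b")                        // "d" is infrequent
val filtered = transaction.flatMap(itemToRank.get).toArray  // Array(1, 0)
ju.Arrays.sort(filtered)                                    // filtered is now Array(0, 1)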