@@ -17,21 +17,30 @@

package org.apache.spark.mllib.fpm

- import java.lang.{Iterable => JavaIterable}
import java.{util => ju}
+ import java.lang.{Iterable => JavaIterable}

- import scala.collection.JavaConverters._
import scala.collection.mutable
+ import scala.collection.JavaConverters._
import scala.reflect.ClassTag

- import org.apache.spark.api.java.JavaRDD
+ import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
+ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
+ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
- import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}

- class FPGrowthModel[Item](val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
-   def javaFreqItemsets(): JavaRDD[(Array[Item], Long)] = {
-     freqItemsets.toJavaRDD()
+ /**
+  * Model trained by [[FPGrowth]], which holds frequent itemsets.
+  * @param freqItemsets frequent itemsets, stored as an RDD of (itemset, frequency) pairs
+  * @tparam Item item type
+  */
+ class FPGrowthModel[Item: ClassTag](
+     val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
+
+   /** Returns frequent itemsets as a [[org.apache.spark.api.java.JavaPairRDD]]. */
+   def javaFreqItemsets(): JavaPairRDD[Array[Item], java.lang.Long] = {
+     JavaPairRDD.fromRDD(freqItemsets).asInstanceOf[JavaPairRDD[Array[Item], java.lang.Long]]
  }
}
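For orientation, a minimal usage sketch of the reworked API. The transactions, the 0.5 support threshold, and a live SparkContext `sc` are illustrative assumptions, not part of this commit; the builder-style setMinSupport/setNumPartitions setters are assumed from FPGrowth's existing surface.

// Usage sketch (not part of this diff): train on an RDD[Array[String]]
// and read the result back through both the Scala and Java views.
import org.apache.spark.mllib.fpm.FPGrowth

val transactions = sc.parallelize(Seq(
  Array("a", "b", "c"),
  Array("a", "b"),
  Array("b", "c"),
  Array("a", "c")))

val model = new FPGrowth()
  .setMinSupport(0.5)
  .setNumPartitions(2)
  .run(transactions)

// Scala view: RDD of (itemset, frequency) pairs.
model.freqItemsets.collect().foreach { case (itemset, freq) =>
  println(itemset.mkString("{", ",", "}") + ": " + freq)
}

// Java-friendly view added by this change: boxed counts in a JavaPairRDD.
val javaView = model.javaFreqItemsets()

Returning a JavaPairRDD with java.lang.Long values spares Java callers from dealing with Scala tuples and primitive Long boxing.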
@@ -77,22 +86,22 @@ class FPGrowth private (
 * @param data input data set, each element contains a transaction
 * @return an [[FPGrowthModel]]
 */
- def run[Item: ClassTag, Basket <: Iterable[Item]](data: RDD[Basket]): FPGrowthModel[Item] = {
+ def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = {
  if (data.getStorageLevel == StorageLevel.NONE) {
    logWarning("Input data is not cached.")
  }
  val count = data.count()
  val minCount = math.ceil(minSupport * count).toLong
  val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
  val partitioner = new HashPartitioner(numParts)
- val freqItems = genFreqItems[Item, Basket](data, minCount, partitioner)
- val freqItemsets = genFreqItemsets[Item, Basket](data, minCount, freqItems, partitioner)
+ val freqItems = genFreqItems(data, minCount, partitioner)
+ val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
  new FPGrowthModel(freqItemsets)
}

- def run[Item: ClassTag, Basket <: JavaIterable[Item]](
-     data: JavaRDD[Basket]): FPGrowthModel[Item] = {
-   this.run(data.rdd.map(_.asScala))
+ def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = {
+   implicit val tag = fakeClassTag[Item]
+   run(data.rdd.map(_.asScala.toArray))
}

/**
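The reworked Java overload can be exercised from Scala too. A hedged sketch, again assuming a live `sc`; java.util.Arrays.asList produces the JavaIterable subtype that the Basket bound requires:

import java.util.Arrays

// Baskets as java.util.List[String], i.e. Basket <: java.lang.Iterable[String].
val javaBaskets = sc.parallelize(Seq(
  Arrays.asList("a", "b"),
  Arrays.asList("b", "c"))).toJavaRDD()

// fakeClassTag supplies the ClassTag that Java callers cannot provide;
// each basket is copied into an Array[Item] before the Scala run() takes over.
val javaModel = new FPGrowth().setMinSupport(0.5).run(javaBaskets)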
@@ -101,8 +110,8 @@ class FPGrowth private (
 * @param partitioner partitioner used to distribute items
 * @return array of frequent patterns ordered by their frequencies
 */
- private def genFreqItems[Item: ClassTag, Basket <: Iterable[Item]](
-     data: RDD[Basket],
+ private def genFreqItems[Item: ClassTag](
+     data: RDD[Array[Item]],
      minCount: Long,
      partitioner: Partitioner): Array[Item] = {
    data.flatMap { t =>
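The flatMap body is cut off at the hunk boundary. As a rough reconstruction of its shape (an assumption, not the committed code), single-item counting could proceed along these lines:

// Sketch: count each distinct item per basket, keep those that reach
// minCount, and order the survivors from most to least frequent.
data.flatMap(t => t.toSet)                        // de-duplicate within a basket
  .map(item => (item, 1L))
  .reduceByKey(partitioner, _ + _)                // partition-aware counting
  .filter { case (_, count) => count >= minCount }
  .collect()
  .sortBy { case (_, count) => -count }           // most frequent first
  .map { case (item, _) => item }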
@@ -127,8 +136,8 @@ class FPGrowth private (
 * @param partitioner partitioner used to distribute transactions
 * @return an RDD of (frequent itemset, count)
 */
- private def genFreqItemsets[Item: ClassTag, Basket <: Iterable[Item]](
-     data: RDD[Basket],
+ private def genFreqItemsets[Item: ClassTag](
+     data: RDD[Array[Item]],
      minCount: Long,
      freqItems: Array[Item],
      partitioner: Partitioner): RDD[(Array[Item], Long)] = {
@@ -152,13 +161,13 @@ class FPGrowth private (
 * @param partitioner partitioner used to distribute transactions
 * @return a map of (target partition, conditional transaction)
 */
- private def genCondTransactions[Item: ClassTag, Basket <: Iterable[Item]](
-     transaction: Basket,
+ private def genCondTransactions[Item: ClassTag](
+     transaction: Array[Item],
      itemToRank: Map[Item, Int],
      partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
    val output = mutable.Map.empty[Int, Array[Int]]
    // Filter the basket down to frequent items and sort their ranks.
-   val filtered = transaction.flatMap(itemToRank.get).toArray
+   val filtered = transaction.flatMap(itemToRank.get)
    ju.Arrays.sort(filtered)
    val n = filtered.length
    var i = n - 1
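To make the rank filtering concrete, a tiny worked example with made-up items and ranks (`ju` is this file's alias for java.util):

// Illustration only: rank 0 goes to the most frequent item.
val itemToRank = Map("b" -> 0, "a" -> 1, "c" -> 2)
val transaction = Array("a", "x", "c", "b")      // "x" is not frequent

// flatMap(itemToRank.get) drops "x" and maps the rest to Array(1, 2, 0);
// sorting yields Array(0, 1, 2), i.e. ranks in descending frequency.
val filtered = transaction.flatMap(itemToRank.get)
ju.Arrays.sort(filtered)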