Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 095aa3a

Browse files
committed
Modified the code according to the review comments.
1 parent baa2885 commit 095aa3a

File tree

1 file changed

+44
-43
lines changed

1 file changed

+44
-43
lines changed

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 44 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.mllib.fpm
1919

20+
import scala.collection.mutable.ArrayBuffer
21+
2022
import org.apache.spark.Logging
2123
import org.apache.spark.annotation.Experimental
2224
import org.apache.spark.rdd.RDD
@@ -43,7 +45,7 @@ class PrefixSpan private (
4345
private var minSupport: Double,
4446
private var maxPatternLength: Int) extends Logging with Serializable {
4547

46-
private val minPatternsBeforeShuffle: Int = 20
48+
private val minPatternsBeforeLocalProcessing: Int = 20
4749

4850
/**
4951
* Constructs a default instance with default parameters
@@ -88,66 +90,65 @@ class PrefixSpan private (
8890
val prefixSuffixPairs = getPrefixSuffixPairs(
8991
lengthOnePatternsAndCounts.map(_._1).collect(), sequences)
9092
var patternsCount: Long = lengthOnePatternsAndCounts.count()
91-
var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2))
93+
var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (ArrayBuffer(x._1), x._2))
9294
var currentPrefixSuffixPairs = prefixSuffixPairs
93-
while (patternsCount <= minPatternsBeforeShuffle && currentPrefixSuffixPairs.count() != 0) {
95+
var patternLength: Int = 1
96+
while (patternLength < maxPatternLength &&
97+
patternsCount <= minPatternsBeforeLocalProcessing &&
98+
currentPrefixSuffixPairs.count() != 0) {
9499
val (nextPatternAndCounts, nextPrefixSuffixPairs) =
95100
getPatternCountsAndPrefixSuffixPairs(minCount, currentPrefixSuffixPairs)
96-
patternsCount = nextPatternAndCounts.count().toInt
101+
patternsCount = nextPatternAndCounts.count()
97102
currentPrefixSuffixPairs = nextPrefixSuffixPairs
98103
allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
104+
patternLength = patternLength + 1
99105
}
100-
if (patternsCount > 0) {
106+
if (patternLength < maxPatternLength && patternsCount > 0) {
101107
val projectedDatabase = currentPrefixSuffixPairs
102108
.map(x => (x._1.toSeq, x._2))
103109
.groupByKey()
104110
.map(x => (x._1.toArray, x._2.toArray))
105111
val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase)
106112
allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
107113
}
108-
allPatternAndCounts
114+
allPatternAndCounts.map { case (pattern, count) => (pattern.toArray, count) }
109115
}
110116

111117
/**
112118
* Get the pattern and counts, and prefix suffix pairs
113119
* @param minCount minimum count
114-
* @param prefixSuffixPairs prefix and suffix pairs,
115-
* @return pattern and counts, and prefix suffix pairs
116-
* (Array[pattern, count], RDD[prefix, suffix ])
120+
* @param prefixSuffixPairs prefix (length n) and suffix pairs,
121+
* @return pattern (length n+1) and counts, and prefix (length n+1) and suffix pairs
122+
* (RDD[pattern, count], RDD[prefix, suffix ])
117123
*/
118124
private def getPatternCountsAndPrefixSuffixPairs(
119125
minCount: Long,
120-
prefixSuffixPairs: RDD[(Array[Int], Array[Int])]):
121-
(RDD[(Array[Int], Long)], RDD[(Array[Int], Array[Int])]) = {
122-
val prefixAndFreqentItemAndCounts = prefixSuffixPairs
123-
.flatMap { case (prefix, suffix) =>
124-
suffix.distinct.map(y => ((prefix.toSeq, y), 1L))
125-
}.reduceByKey(_ + _)
126+
prefixSuffixPairs: RDD[(ArrayBuffer[Int], Array[Int])]):
127+
(RDD[(ArrayBuffer[Int], Long)], RDD[(ArrayBuffer[Int], Array[Int])]) = {
128+
val prefixAndFrequentItemAndCounts = prefixSuffixPairs
129+
.flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
130+
.reduceByKey(_ + _)
126131
.filter(_._2 >= minCount)
127-
val patternAndCounts = prefixAndFreqentItemAndCounts
128-
.map{ case ((prefix, item), count) => (prefix.toArray :+ item, count) }
129-
val prefixlength = prefixSuffixPairs.first()._1.length
130-
if (prefixlength + 1 >= maxPatternLength) {
131-
(patternAndCounts, prefixSuffixPairs.filter(x => false))
132-
} else {
133-
val frequentItemsMap = prefixAndFreqentItemAndCounts
134-
.keys
135-
.groupByKey()
136-
.mapValues(_.toSet)
137-
.collect
138-
.toMap
139-
val nextPrefixSuffixPairs = prefixSuffixPairs
140-
.filter(x => frequentItemsMap.contains(x._1))
141-
.flatMap { case (prefix, suffix) =>
142-
val frequentItemSet = frequentItemsMap(prefix)
143-
val filteredSuffix = suffix.filter(frequentItemSet.contains(_))
144-
val nextSuffixes = frequentItemSet.map{ item =>
145-
(item, LocalPrefixSpan.getSuffix(item, filteredSuffix))
146-
}.filter(_._2.nonEmpty)
147-
nextSuffixes.map { case (item, suffix) => (prefix :+ item, suffix) }
132+
val patternAndCounts = prefixAndFrequentItemAndCounts
133+
.map { case ((prefix, item), count) => (prefix :+ item, count) }
134+
val prefixToFrequentNextItemsMap = prefixAndFrequentItemAndCounts
135+
.keys
136+
.groupByKey()
137+
.mapValues(_.toSet)
138+
.collect()
139+
.toMap
140+
val nextPrefixSuffixPairs = prefixSuffixPairs
141+
.filter(x => prefixToFrequentNextItemsMap.contains(x._1))
142+
.flatMap { case (prefix, suffix) =>
143+
val frequentNextItems = prefixToFrequentNextItemsMap(prefix)
144+
val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
145+
frequentNextItems.flatMap { item =>
146+
val suffix = LocalPrefixSpan.getSuffix(item, filteredSuffix)
147+
if (suffix.isEmpty) None
148+
else Some(prefix :+ item, suffix)
148149
}
149-
(patternAndCounts, nextPrefixSuffixPairs)
150150
}
151+
(patternAndCounts, nextPrefixSuffixPairs)
151152
}
152153

153154
/**
@@ -181,14 +182,14 @@ class PrefixSpan private (
181182
*/
182183
private def getPrefixSuffixPairs(
183184
frequentPrefixes: Array[Int],
184-
sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
185+
sequences: RDD[Array[Int]]): RDD[(ArrayBuffer[Int], Array[Int])] = {
185186
val filteredSequences = sequences.map { p =>
186187
p.filter (frequentPrefixes.contains(_) )
187188
}
188189
filteredSequences.flatMap { x =>
189190
frequentPrefixes.map { y =>
190191
val sub = LocalPrefixSpan.getSuffix(y, x)
191-
(Array(y), sub)
192+
(ArrayBuffer(y), sub)
192193
}.filter(_._2.nonEmpty)
193194
}
194195
}
@@ -201,9 +202,9 @@ class PrefixSpan private (
201202
*/
202203
private def getPatternsInLocal(
203204
minCount: Long,
204-
data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
205-
data.flatMap { x =>
206-
LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2)
207-
}
205+
data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(ArrayBuffer[Int], Long)] = {
206+
data
207+
.flatMap { x => LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2) }
208+
.map { case (pattern, count) => (pattern.to[ArrayBuffer], count) }
208209
}
209210
}

0 commit comments

Comments (0)