Skip to content

Commit baa2885

Browse files
committed
Modified the code according to the review comments.
1 parent 6560c69 commit baa2885

File tree

1 file changed

+37
-40
lines changed

1 file changed

+37
-40
lines changed

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -84,72 +84,69 @@ class PrefixSpan private (
8484
logWarning("Input data is not cached.")
8585
}
8686
val minCount = getMinCount(sequences)
87-
val lengthOnePatternsAndCounts =
88-
getFreqItemAndCounts(minCount, sequences).collect()
89-
val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
90-
lengthOnePatternsAndCounts.map(_._1), sequences)
91-
92-
var patternsCount = lengthOnePatternsAndCounts.length
93-
var allPatternAndCounts = sequences.sparkContext.parallelize(
94-
lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
95-
var currentProjectedDatabase = prefixAndProjectedDatabase
96-
while (patternsCount <= minPatternsBeforeShuffle &&
97-
currentProjectedDatabase.count() != 0) {
98-
val (nextPatternAndCounts, nextProjectedDatabase) =
99-
getPatternCountsAndProjectedDatabase(minCount, currentProjectedDatabase)
87+
val lengthOnePatternsAndCounts = getFreqItemAndCounts(minCount, sequences)
88+
val prefixSuffixPairs = getPrefixSuffixPairs(
89+
lengthOnePatternsAndCounts.map(_._1).collect(), sequences)
90+
var patternsCount: Long = lengthOnePatternsAndCounts.count()
91+
var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2))
92+
var currentPrefixSuffixPairs = prefixSuffixPairs
93+
while (patternsCount <= minPatternsBeforeShuffle && currentPrefixSuffixPairs.count() != 0) {
94+
val (nextPatternAndCounts, nextPrefixSuffixPairs) =
95+
getPatternCountsAndPrefixSuffixPairs(minCount, currentPrefixSuffixPairs)
10096
patternsCount = nextPatternAndCounts.count().toInt
101-
currentProjectedDatabase = nextProjectedDatabase
97+
currentPrefixSuffixPairs = nextPrefixSuffixPairs
10298
allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
10399
}
104100
if (patternsCount > 0) {
105-
val groupedProjectedDatabase = currentProjectedDatabase
101+
val projectedDatabase = currentPrefixSuffixPairs
106102
.map(x => (x._1.toSeq, x._2))
107103
.groupByKey()
108104
.map(x => (x._1.toArray, x._2.toArray))
109-
val nextPatternAndCounts = getPatternsInLocal(minCount, groupedProjectedDatabase)
105+
val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase)
110106
allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
111107
}
112108
allPatternAndCounts
113109
}
114110

115111
/**
116-
* Get the pattern and counts, and projected database
112+
* Get the pattern and counts, and prefix suffix pairs
117113
* @param minCount minimum count
118-
* @param prefixAndProjectedDatabase prefix and projected database,
119-
* @return pattern and counts, and projected database
120-
* (Array[pattern, count], RDD[prefix, projected database ])
114+
* @param prefixSuffixPairs prefix and suffix pairs,
115+
* @return pattern and counts, and prefix suffix pairs
116+
* (Array[pattern, count], RDD[prefix, suffix ])
121117
*/
122-
private def getPatternCountsAndProjectedDatabase(
118+
private def getPatternCountsAndPrefixSuffixPairs(
123119
minCount: Long,
124-
prefixAndProjectedDatabase: RDD[(Array[Int], Array[Int])]):
120+
prefixSuffixPairs: RDD[(Array[Int], Array[Int])]):
125121
(RDD[(Array[Int], Long)], RDD[(Array[Int], Array[Int])]) = {
126-
val prefixAndFreqentItemAndCounts = prefixAndProjectedDatabase.flatMap{ x =>
127-
x._2.distinct.map(y => ((x._1.toSeq, y), 1L))
122+
val prefixAndFreqentItemAndCounts = prefixSuffixPairs
123+
.flatMap { case (prefix, suffix) =>
124+
suffix.distinct.map(y => ((prefix.toSeq, y), 1L))
128125
}.reduceByKey(_ + _)
129126
.filter(_._2 >= minCount)
130127
val patternAndCounts = prefixAndFreqentItemAndCounts
131-
.map(x => (x._1._1.toArray ++ Array(x._1._2), x._2))
132-
val prefixlength = prefixAndProjectedDatabase.take(1)(0)._1.length
128+
.map{ case ((prefix, item), count) => (prefix.toArray :+ item, count) }
129+
val prefixlength = prefixSuffixPairs.first()._1.length
133130
if (prefixlength + 1 >= maxPatternLength) {
134-
(patternAndCounts, prefixAndProjectedDatabase.filter(x => false))
131+
(patternAndCounts, prefixSuffixPairs.filter(x => false))
135132
} else {
136133
val frequentItemsMap = prefixAndFreqentItemAndCounts
137-
.keys.map(x => (x._1, x._2))
134+
.keys
138135
.groupByKey()
139136
.mapValues(_.toSet)
140137
.collect
141138
.toMap
142-
val nextPrefixAndProjectedDatabase = prefixAndProjectedDatabase
139+
val nextPrefixSuffixPairs = prefixSuffixPairs
143140
.filter(x => frequentItemsMap.contains(x._1))
144-
.flatMap { x =>
145-
val frequentItemSet = frequentItemsMap(x._1)
146-
val filteredSequence = x._2.filter(frequentItemSet.contains(_))
147-
val subProjectedDabase = frequentItemSet.map{ y =>
148-
(y, LocalPrefixSpan.getSuffix(y, filteredSequence))
141+
.flatMap { case (prefix, suffix) =>
142+
val frequentItemSet = frequentItemsMap(prefix)
143+
val filteredSuffix = suffix.filter(frequentItemSet.contains(_))
144+
val nextSuffixes = frequentItemSet.map{ item =>
145+
(item, LocalPrefixSpan.getSuffix(item, filteredSuffix))
149146
}.filter(_._2.nonEmpty)
150-
subProjectedDabase.map(y => (x._1 ++ Array(y._1), y._2))
147+
nextSuffixes.map { case (item, suffix) => (prefix :+ item, suffix) }
151148
}
152-
(patternAndCounts, nextPrefixAndProjectedDatabase)
149+
(patternAndCounts, nextPrefixSuffixPairs)
153150
}
154151
}
155152

@@ -177,12 +174,12 @@ class PrefixSpan private (
177174
}
178175

179176
/**
180-
* Get the frequent prefixes' projected database.
177+
* Get the frequent prefixes and suffix pairs.
181178
* @param frequentPrefixes frequent prefixes
182179
* @param sequences sequences data
183-
* @return prefixes and projected database
180+
* @return prefixes and suffix pairs.
184181
*/
185-
private def getPrefixAndProjectedDatabase(
182+
private def getPrefixSuffixPairs(
186183
frequentPrefixes: Array[Int],
187184
sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
188185
val filteredSequences = sequences.map { p =>
@@ -199,7 +196,7 @@ class PrefixSpan private (
199196
/**
200197
* calculate the patterns in local.
201198
* @param minCount the absolute minimum count
202-
* @param data patterns and projected sequences data data
199+
* @param data prefixes and projected sequences data data
203200
* @return patterns
204201
*/
205202
private def getPatternsInLocal(

0 commit comments

Comments
 (0)