Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit ca9c4c8

Browse files
committed
Modified the code according to the review comments.
1 parent 574e56c commit ca9c4c8

File tree

2 files changed

+27
-65
lines changed

2 files changed

+27
-65
lines changed

mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala

Lines changed: 17 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,35 @@ import org.apache.spark.annotation.Experimental
3030
private[fpm] object LocalPrefixSpan extends Logging with Serializable {
3131

3232
/**
33-
* Calculate all patterns of a projected database in local.
33+
* Calculate all patterns of a projected database.
3434
* @param minCount minimum count
3535
* @param maxPatternLength maximum pattern length
3636
* @param prefix prefix
3737
* @param projectedDatabase the projected database
3838
* @return a set of sequential pattern pairs,
39-
* the key of pair is pattern (a list of elements),
39+
* the key of pair is sequential pattern (a list of items),
4040
* the value of pair is the pattern's count.
4141
*/
4242
def run(
4343
minCount: Long,
4444
maxPatternLength: Int,
4545
prefix: Array[Int],
4646
projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
47-
getPatternsWithPrefix(minCount, maxPatternLength, prefix, projectedDatabase)
47+
val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
48+
val frequentPatternAndCounts = frequentPrefixAndCounts
49+
.map(x => (prefix ++ Array(x._1), x._2))
50+
val prefixProjectedDatabases = getPatternAndProjectedDatabase(
51+
prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
52+
53+
val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
54+
if (continueProcess) {
55+
val nextPatterns = prefixProjectedDatabases
56+
.map(x => run(minCount, maxPatternLength, x._1, x._2))
57+
.reduce(_ ++ _)
58+
frequentPatternAndCounts ++ nextPatterns
59+
} else {
60+
frequentPatternAndCounts
61+
}
4862
}
4963

5064
/**
@@ -96,34 +110,4 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
96110
(prePrefix ++ Array(x), sub)
97111
}.filter(x => x._2.nonEmpty)
98112
}
99-
100-
/**
101-
* Calculate all patterns of a projected database in local.
102-
* @param minCount the minimum count
103-
* @param maxPatternLength maximum pattern length
104-
* @param prefix prefix
105-
* @param projectedDatabase projected database
106-
* @return patterns
107-
*/
108-
private def getPatternsWithPrefix(
109-
minCount: Long,
110-
maxPatternLength: Int,
111-
prefix: Array[Int],
112-
projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
113-
val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
114-
val frequentPatternAndCounts = frequentPrefixAndCounts
115-
.map(x => (prefix ++ Array(x._1), x._2))
116-
val prefixProjectedDatabases = getPatternAndProjectedDatabase(
117-
prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
118-
119-
val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
120-
if (continueProcess) {
121-
val nextPatterns = prefixProjectedDatabases
122-
.map(x => getPatternsWithPrefix(minCount, maxPatternLength, x._1, x._2))
123-
.reduce(_ ++ _)
124-
frequentPatternAndCounts ++ nextPatterns
125-
} else {
126-
frequentPatternAndCounts
127-
}
128-
}
129113
}

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,15 @@ class PrefixSpan private (
8282
logWarning("Input data is not cached.")
8383
}
8484
val minCount = getMinCount(sequences)
85-
val (lengthOnePatternsAndCounts, prefixAndCandidates) =
86-
findLengthOnePatterns(minCount, sequences)
87-
val projectedDatabase = makePrefixProjectedDatabases(prefixAndCandidates)
88-
val nextPatterns = getPatternsInLocal(minCount, projectedDatabase)
85+
val lengthOnePatternsAndCounts =
86+
getFreqItemAndCounts(minCount, sequences).collect()
87+
val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
88+
lengthOnePatternsAndCounts.map(_._1), sequences)
89+
val groupedProjectedDatabase = prefixAndProjectedDatabase
90+
.map(x => (x._1.toSeq, x._2))
91+
.groupByKey()
92+
.map(x => (x._1.toArray, x._2.toArray))
93+
val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
8994
val lengthOnePatternsAndCountsRdd =
9095
sequences.sparkContext.parallelize(
9196
lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
@@ -122,7 +127,7 @@ class PrefixSpan private (
122127
* @param sequences sequences data
123128
* @return prefixes and projected database
124129
*/
125-
private def getPatternAndProjectedDatabase(
130+
private def getPrefixAndProjectedDatabase(
126131
frequentPrefixes: Array[Int],
127132
sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
128133
val filteredSequences = sequences.map { p =>
@@ -136,33 +141,6 @@ class PrefixSpan private (
136141
}
137142
}
138143

139-
/**
140-
* Find the patterns whose length is one
141-
* @param minCount the minimum count
142-
* @param sequences original sequences data
143-
* @return length-one patterns and projection table
144-
*/
145-
private def findLengthOnePatterns(
146-
minCount: Long,
147-
sequences: RDD[Array[Int]]): (Array[(Int, Long)], RDD[(Array[Int], Array[Int])]) = {
148-
val frequentLengthOnePatternAndCounts = getFreqItemAndCounts(minCount, sequences)
149-
val prefixAndProjectedDatabase = getPatternAndProjectedDatabase(
150-
frequentLengthOnePatternAndCounts.keys.collect(), sequences)
151-
(frequentLengthOnePatternAndCounts.collect(), prefixAndProjectedDatabase)
152-
}
153-
154-
/**
155-
* Constructs prefix-projected databases from (prefix, suffix) pairs.
156-
* @param data patterns and projected sequences data before re-partition
157-
* @return patterns and projected sequences data after re-partition
158-
*/
159-
private def makePrefixProjectedDatabases(
160-
data: RDD[(Array[Int], Array[Int])]): RDD[(Array[Int], Array[Array[Int]])] = {
161-
data.map(x => (x._1.toSeq, x._2))
162-
.groupByKey()
163-
.map(x => (x._1.toArray, x._2.toArray))
164-
}
165-
166144
/**
167145
* Calculate the patterns locally.
168146
* @param minCount the absolute minimum count

0 commit comments

Comments
 (0)