Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 22b0ef4

Browse files
committed
Add feature: Collect enough frequent prefixes before projection in PrefixSpan.
1 parent ca9c4c8 commit 22b0ef4

File tree

1 file changed

+65
-10
lines changed

1 file changed

+65
-10
lines changed

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class PrefixSpan private (
4343
private var minSupport: Double,
4444
private var maxPatternLength: Int) extends Logging with Serializable {
4545

46+
private val minPatternsBeforeShuffle: Int = 20
47+
4648
/**
4749
* Constructs a default instance with default parameters
4850
* {minSupport: `0.1`, maxPatternLength: `10`}.
@@ -86,16 +88,69 @@ class PrefixSpan private (
8688
getFreqItemAndCounts(minCount, sequences).collect()
8789
val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
8890
lengthOnePatternsAndCounts.map(_._1), sequences)
89-
val groupedProjectedDatabase = prefixAndProjectedDatabase
90-
.map(x => (x._1.toSeq, x._2))
91-
.groupByKey()
92-
.map(x => (x._1.toArray, x._2.toArray))
93-
val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
94-
val lengthOnePatternsAndCountsRdd =
95-
sequences.sparkContext.parallelize(
96-
lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
97-
val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
98-
allPatterns
91+
92+
var patternsCount = lengthOnePatternsAndCounts.length
93+
var allPatternAndCounts = sequences.sparkContext.parallelize(
94+
lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
95+
var currentProjectedDatabase = prefixAndProjectedDatabase
96+
while (patternsCount <= minPatternsBeforeShuffle &&
97+
currentProjectedDatabase.count() != 0) {
98+
val (nextPatternAndCounts, nextProjectedDatabase) =
99+
getPatternCountsAndProjectedDatabase(minCount, currentProjectedDatabase)
100+
patternsCount = nextPatternAndCounts.count().toInt
101+
currentProjectedDatabase = nextProjectedDatabase
102+
allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
103+
}
104+
if (patternsCount > 0) {
105+
val groupedProjectedDatabase = currentProjectedDatabase
106+
.map(x => (x._1.toSeq, x._2))
107+
.groupByKey()
108+
.map(x => (x._1.toArray, x._2.toArray))
109+
val nextPatternAndCounts = getPatternsInLocal(minCount, groupedProjectedDatabase)
110+
allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
111+
}
112+
allPatternAndCounts
113+
}
114+
115+
/**
116+
* Get the pattern and counts, and projected database
117+
* @param minCount minimum count
118+
* @param prefixAndProjectedDatabase prefix and projected database,
119+
* @return pattern and counts, and projected database
120+
* (Array[pattern, count], RDD[prefix, projected database ])
121+
*/
122+
private def getPatternCountsAndProjectedDatabase(
123+
minCount: Long,
124+
prefixAndProjectedDatabase: RDD[(Array[Int], Array[Int])]):
125+
(RDD[(Array[Int], Long)], RDD[(Array[Int], Array[Int])]) = {
126+
val prefixAndFreqentItemAndCounts = prefixAndProjectedDatabase.flatMap{ x =>
127+
x._2.distinct.map(y => ((x._1.toSeq, y), 1L))
128+
}.reduceByKey(_+_)
129+
.filter(_._2 >= minCount)
130+
val patternAndCounts = prefixAndFreqentItemAndCounts
131+
.map(x => (x._1._1.toArray ++ Array(x._1._2), x._2))
132+
val prefixlength = prefixAndProjectedDatabase.take(1)(0)._1.length
133+
if (prefixlength + 1 >= maxPatternLength) {
134+
(patternAndCounts, prefixAndProjectedDatabase.filter(x => false))
135+
} else {
136+
val frequentItemsMap = prefixAndFreqentItemAndCounts
137+
.keys.map(x => (x._1, x._2))
138+
.groupByKey()
139+
.mapValues(_.toSet)
140+
.collect
141+
.toMap
142+
val nextPrefixAndProjectedDatabase = prefixAndProjectedDatabase
143+
.filter(x => frequentItemsMap.contains(x._1))
144+
.flatMap { x =>
145+
val frequentItemSet = frequentItemsMap(x._1)
146+
val filteredSequence = x._2.filter(frequentItemSet.contains(_))
147+
val subProjectedDabase = frequentItemSet.map{ y =>
148+
(y, LocalPrefixSpan.getSuffix(y, filteredSequence))
149+
}.filter(_._2.nonEmpty)
150+
subProjectedDabase.map(y => (x._1 ++ Array(y._1), y._2))
151+
}
152+
(patternAndCounts, nextPrefixAndProjectedDatabase)
153+
}
99154
}
100155

101156
/**

0 commit comments

Comments
 (0)