@@ -43,6 +43,8 @@ class PrefixSpan private (
43
43
private var minSupport : Double ,
44
44
private var maxPatternLength : Int ) extends Logging with Serializable {
45
45
46
+ private val minPatternsBeforeShuffle : Int = 20
47
+
46
48
/**
47
49
* Constructs a default instance with default parameters
48
50
* {minSupport: `0.1`, maxPatternLength: `10`}.
@@ -86,16 +88,69 @@ class PrefixSpan private (
86
88
getFreqItemAndCounts(minCount, sequences).collect()
87
89
val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
88
90
lengthOnePatternsAndCounts.map(_._1), sequences)
89
- val groupedProjectedDatabase = prefixAndProjectedDatabase
90
- .map(x => (x._1.toSeq, x._2))
91
- .groupByKey()
92
- .map(x => (x._1.toArray, x._2.toArray))
93
- val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
94
- val lengthOnePatternsAndCountsRdd =
95
- sequences.sparkContext.parallelize(
96
- lengthOnePatternsAndCounts.map(x => (Array (x._1), x._2)))
97
- val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
98
- allPatterns
91
+
92
+ var patternsCount = lengthOnePatternsAndCounts.length
93
+ var allPatternAndCounts = sequences.sparkContext.parallelize(
94
+ lengthOnePatternsAndCounts.map(x => (Array (x._1), x._2)))
95
+ var currentProjectedDatabase = prefixAndProjectedDatabase
96
+ while (patternsCount <= minPatternsBeforeShuffle &&
97
+ currentProjectedDatabase.count() != 0 ) {
98
+ val (nextPatternAndCounts, nextProjectedDatabase) =
99
+ getPatternCountsAndProjectedDatabase(minCount, currentProjectedDatabase)
100
+ patternsCount = nextPatternAndCounts.count().toInt
101
+ currentProjectedDatabase = nextProjectedDatabase
102
+ allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
103
+ }
104
+ if (patternsCount > 0 ) {
105
+ val groupedProjectedDatabase = currentProjectedDatabase
106
+ .map(x => (x._1.toSeq, x._2))
107
+ .groupByKey()
108
+ .map(x => (x._1.toArray, x._2.toArray))
109
+ val nextPatternAndCounts = getPatternsInLocal(minCount, groupedProjectedDatabase)
110
+ allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
111
+ }
112
+ allPatternAndCounts
113
+ }
114
+
115
+ /**
116
+ * Get the pattern and counts, and projected database
117
+ * @param minCount minimum count
118
+ * @param prefixAndProjectedDatabase prefix and projected database,
119
+ * @return pattern and counts, and projected database
120
+ * (Array[pattern, count], RDD[prefix, projected database ])
121
+ */
122
+ private def getPatternCountsAndProjectedDatabase (
123
+ minCount : Long ,
124
+ prefixAndProjectedDatabase : RDD [(Array [Int ], Array [Int ])]):
125
+ (RDD [(Array [Int ], Long )], RDD [(Array [Int ], Array [Int ])]) = {
126
+ val prefixAndFreqentItemAndCounts = prefixAndProjectedDatabase.flatMap{ x =>
127
+ x._2.distinct.map(y => ((x._1.toSeq, y), 1L ))
128
+ }.reduceByKey(_+_)
129
+ .filter(_._2 >= minCount)
130
+ val patternAndCounts = prefixAndFreqentItemAndCounts
131
+ .map(x => (x._1._1.toArray ++ Array (x._1._2), x._2))
132
+ val prefixlength = prefixAndProjectedDatabase.take(1 )(0 )._1.length
133
+ if (prefixlength + 1 >= maxPatternLength) {
134
+ (patternAndCounts, prefixAndProjectedDatabase.filter(x => false ))
135
+ } else {
136
+ val frequentItemsMap = prefixAndFreqentItemAndCounts
137
+ .keys.map(x => (x._1, x._2))
138
+ .groupByKey()
139
+ .mapValues(_.toSet)
140
+ .collect
141
+ .toMap
142
+ val nextPrefixAndProjectedDatabase = prefixAndProjectedDatabase
143
+ .filter(x => frequentItemsMap.contains(x._1))
144
+ .flatMap { x =>
145
+ val frequentItemSet = frequentItemsMap(x._1)
146
+ val filteredSequence = x._2.filter(frequentItemSet.contains(_))
147
+ val subProjectedDabase = frequentItemSet.map{ y =>
148
+ (y, LocalPrefixSpan .getSuffix(y, filteredSequence))
149
+ }.filter(_._2.nonEmpty)
150
+ subProjectedDabase.map(y => (x._1 ++ Array (y._1), y._2))
151
+ }
152
+ (patternAndCounts, nextPrefixAndProjectedDatabase)
153
+ }
99
154
}
100
155
101
156
/**
0 commit comments