Skip to content

Commit 1bb8acc

Browse files
Feynman Liang authored and
mengxr committed
[SPARK-8997] [MLLIB] Performance improvements in LocalPrefixSpan
Improves the performance of LocalPrefixSpan by implementing optimizations proposed in [SPARK-8997](https://issues.apache.org/jira/browse/SPARK-8997).

Author: Feynman Liang <[email protected]>
Author: Feynman Liang <[email protected]>
Author: Xiangrui Meng <[email protected]>

Closes apache#7360 from feynmanliang/SPARK-8997-improve-prefixspan and squashes the following commits:

59db2f5 [Feynman Liang] Merge pull request #1 from mengxr/SPARK-8997
91e4357 [Xiangrui Meng] update LocalPrefixSpan impl
9212256 [Feynman Liang] MengXR code review comments
f055d82 [Feynman Liang] Fix failing scalatest
2e00cba [Feynman Liang] Depth first projections
70b93e3 [Feynman Liang] Performance improvements in LocalPrefixSpan, fix tests
1 parent f0e1297 commit 1bb8acc

File tree

3 files changed

+44
-70
lines changed

3 files changed

+44
-70
lines changed

mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala

Lines changed: 38 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -17,97 +17,78 @@
1717

1818
package org.apache.spark.mllib.fpm
1919

20+
import scala.collection.mutable
21+
2022
import org.apache.spark.Logging
21-
import org.apache.spark.annotation.Experimental
2223

2324
/**
24-
*
25-
* :: Experimental ::
26-
*
2725
* Calculate all patterns of a projected database locally.
2826
*/
29-
@Experimental
3027
private[fpm] object LocalPrefixSpan extends Logging with Serializable {
3128

3229
/**
3330
* Calculate all patterns of a projected database.
3431
* @param minCount minimum count
3532
* @param maxPatternLength maximum pattern length
36-
* @param prefix prefix
37-
* @param projectedDatabase the projected dabase
33+
* @param prefixes prefixes in reversed order
34+
* @param database the projected database
3835
* @return a set of sequential pattern pairs,
39-
* the key of pair is sequential pattern (a list of items),
36+
* the key of pair is sequential pattern (a list of items in reversed order),
4037
* the value of pair is the pattern's count.
4138
*/
4239
def run(
4340
minCount: Long,
4441
maxPatternLength: Int,
45-
prefix: Array[Int],
46-
projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
47-
val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
48-
val frequentPatternAndCounts = frequentPrefixAndCounts
49-
.map(x => (prefix ++ Array(x._1), x._2))
50-
val prefixProjectedDatabases = getPatternAndProjectedDatabase(
51-
prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
52-
53-
val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
54-
if (continueProcess) {
55-
val nextPatterns = prefixProjectedDatabases
56-
.map(x => run(minCount, maxPatternLength, x._1, x._2))
57-
.reduce(_ ++ _)
58-
frequentPatternAndCounts ++ nextPatterns
59-
} else {
60-
frequentPatternAndCounts
42+
prefixes: List[Int],
43+
database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
44+
if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
45+
val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
46+
val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
47+
frequentItemAndCounts.iterator.flatMap { case (item, count) =>
48+
val newPrefixes = item :: prefixes
49+
val newProjected = project(filteredDatabase, item)
50+
Iterator.single((newPrefixes, count)) ++
51+
run(minCount, maxPatternLength, newPrefixes, newProjected)
6152
}
6253
}
6354

6455
/**
65-
* calculate suffix sequence following a prefix in a sequence
66-
* @param prefix prefix
67-
* @param sequence sequence
56+
* Calculate suffix sequence immediately after the first occurrence of an item.
57+
* @param item item to get suffix after
58+
* @param sequence sequence to extract suffix from
6859
* @return suffix sequence
6960
*/
70-
def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
71-
val index = sequence.indexOf(prefix)
61+
def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
62+
val index = sequence.indexOf(item)
7263
if (index == -1) {
7364
Array()
7465
} else {
7566
sequence.drop(index + 1)
7667
}
7768
}
7869

70+
def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
71+
database
72+
.map(getSuffix(prefix, _))
73+
.filter(_.nonEmpty)
74+
}
75+
7976
/**
8077
* Generates frequent items by filtering the input data using minimal count level.
81-
* @param minCount the absolute minimum count
82-
* @param sequences sequences data
83-
* @return array of item and count pair
78+
* @param minCount the minimum count for an item to be frequent
79+
* @param database database of sequences
80+
* @return map from each frequent item to its count
8481
*/
8582
private def getFreqItemAndCounts(
8683
minCount: Long,
87-
sequences: Array[Array[Int]]): Array[(Int, Long)] = {
88-
sequences.flatMap(_.distinct)
89-
.groupBy(x => x)
90-
.mapValues(_.length.toLong)
91-
.filter(_._2 >= minCount)
92-
.toArray
93-
}
94-
95-
/**
96-
* Get the frequent prefixes' projected database.
97-
* @param prePrefix the frequent prefixes' prefix
98-
* @param frequentPrefixes frequent prefixes
99-
* @param sequences sequences data
100-
* @return prefixes and projected database
101-
*/
102-
private def getPatternAndProjectedDatabase(
103-
prePrefix: Array[Int],
104-
frequentPrefixes: Array[Int],
105-
sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
106-
val filteredProjectedDatabase = sequences
107-
.map(x => x.filter(frequentPrefixes.contains(_)))
108-
frequentPrefixes.map { x =>
109-
val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
110-
(prePrefix ++ Array(x), sub)
111-
}.filter(x => x._2.nonEmpty)
84+
database: Array[Array[Int]]): mutable.Map[Int, Long] = {
85+
// TODO: use PrimitiveKeyOpenHashMap
86+
val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
87+
database.foreach { sequence =>
88+
sequence.distinct.foreach { item =>
89+
counts(item) += 1L
90+
}
91+
}
92+
counts.filter(_._2 >= minCount)
11293
}
11394
}

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,9 @@ class PrefixSpan private (
150150
private def getPatternsInLocal(
151151
minCount: Long,
152152
data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
153-
data.flatMap { x =>
154-
LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2)
153+
data.flatMap { case (prefix, projDB) =>
154+
LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
155+
.map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
155156
}
156157
}
157158
}

mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@ package org.apache.spark.mllib.fpm
1818

1919
import org.apache.spark.SparkFunSuite
2020
import org.apache.spark.mllib.util.MLlibTestSparkContext
21-
import org.apache.spark.rdd.RDD
2221

23-
class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
22+
class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
2423

2524
test("PrefixSpan using Integer type") {
2625

@@ -48,15 +47,8 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
4847
def compareResult(
4948
expectedValue: Array[(Array[Int], Long)],
5049
actualValue: Array[(Array[Int], Long)]): Boolean = {
51-
val sortedExpectedValue = expectedValue.sortWith{ (x, y) =>
52-
x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
53-
}
54-
val sortedActualValue = actualValue.sortWith{ (x, y) =>
55-
x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
56-
}
57-
sortedExpectedValue.zip(sortedActualValue)
58-
.map(x => x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2)
59-
.reduce(_&&_)
50+
expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
51+
actualValue.map(x => (x._1.toSeq, x._2)).toSet
6052
}
6153

6254
val prefixspan = new PrefixSpan()

0 commit comments

Comments
 (0)