Skip to content

Commit 85c48de

Browse files
committed
minor updates
1 parent a048d0c commit 85c48de

File tree

1 file changed

+14
-22
lines changed

1 file changed

+14
-22
lines changed

mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,57 +46,49 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging {
4646
* correlation between column i and j.
4747
*/
4848
override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = {
49-
// ((columnIndex, value), rowId)
49+
// ((columnIndex, value), rowUid)
5050
val colBased = X.zipWithUniqueId().flatMap { case (vec, uid) =>
5151
vec.toArray.view.zipWithIndex.map { case (v, j) =>
5252
((j, v), uid)
5353
}
5454
}
5555
// global sort by (columnIndex, value)
5656
val sorted = colBased.sortByKey()
57-
// Assign global ranks (using average ranks for tied values)
57+
// assign global ranks (using average ranks for tied values)
5858
val globalRanks = sorted.zipWithIndex().mapPartitions { iter =>
5959
var preCol = -1
6060
var preVal = Double.NaN
6161
var startRank = -1.0
62-
var cachedIds = ArrayBuffer.empty[Long]
63-
def flush: () => Iterable[(Long, (Int, Double))] = () => {
64-
val averageRank = startRank + (cachedIds.size - 1) / 2.0
65-
val output = cachedIds.map { i =>
66-
(i, (preCol, averageRank))
62+
var cachedUids = ArrayBuffer.empty[Long]
63+
val flush: () => Iterable[(Long, (Int, Double))] = () => {
64+
val averageRank = startRank + (cachedUids.size - 1) / 2.0
65+
val output = cachedUids.map { uid =>
66+
(uid, (preCol, averageRank))
6767
}
68-
cachedIds.clear()
68+
cachedUids.clear()
6969
output
7070
}
7171
iter.flatMap { case (((j, v), uid), rank) =>
72-
// If we see a new value or cachedIds is too big, we flush ids with their average rank.
73-
if (j != preCol || v != preVal || cachedIds.size >= 10000000) {
72+
// If we see a new value or cachedUids is too big, we flush ids with their average rank.
73+
if (j != preCol || v != preVal || cachedUids.size >= 10000000) {
7474
val output = flush()
7575
preCol = j
7676
preVal = v
7777
startRank = rank
78-
cachedIds += uid
78+
cachedUids += uid
7979
output
8080
} else {
81-
cachedIds += uid
81+
cachedUids += uid
8282
Iterator.empty
8383
}
84-
} ++ {
85-
flush()
86-
}
84+
} ++ flush()
8785
}
8886
// Replace values in the input matrix by their ranks compared with values in the same column.
8987
// Note that shifting all ranks in a column by a constant value doesn't affect result.
9088
val groupedRanks = globalRanks.groupByKey().map { case (uid, iter) =>
9189
// sort by column index and then convert values to a vector
9290
Vectors.dense(iter.toSeq.sortBy(_._1).map(_._2).toArray)
9391
}
94-
val corrMatrix = PearsonCorrelation.computeCorrelationMatrix(groupedRanks)
95-
96-
colBased.unpersist(blocking = false)
97-
sorted.unpersist(blocking = false)
98-
99-
corrMatrix
92+
PearsonCorrelation.computeCorrelationMatrix(groupedRanks)
10093
}
10194
}
102-

0 commit comments

Comments
 (0)