@@ -46,57 +46,49 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging {
46
46
* correlation between column i and j.
47
47
*/
48
48
override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = {
  // Tag every row with a unique id and explode it into one record per cell:
  // ((columnIndex, value), rowUid)
  val colBased = X.zipWithUniqueId().flatMap { case (vec, uid) =>
    vec.toArray.view.zipWithIndex.map { case (v, j) =>
      ((j, v), uid)
    }
  }
  // global sort by (columnIndex, value)
  val sorted = colBased.sortByKey()
  // assign global ranks (using average ranks for tied values)
  val globalRanks = sorted.zipWithIndex().mapPartitions { iter =>
    // Cap on how many tied uids are buffered before they are flushed, which bounds the
    // buffer size for a very long run of identical values.
    val maxCachedUids = 10000000
    // State of the run of equal (column, value) records currently being buffered.
    var preCol = -1
    var preVal = Double.NaN
    var startRank = -1.0
    // The buffer is mutated in place and never reassigned, hence a `val`.
    val cachedUids = ArrayBuffer.empty[Long]

    // Emits every buffered uid paired with (column, average rank of the run), then empties
    // the buffer. `map` builds a new collection, so the returned records are not affected
    // by the subsequent `clear()`.
    def flush(): Iterable[(Long, (Int, Double))] = {
      val averageRank = startRank + (cachedUids.size - 1) / 2.0
      val output = cachedUids.map { uid =>
        (uid, (preCol, averageRank))
      }
      cachedUids.clear()
      output
    }

    iter.flatMap { case (((j, v), uid), rank) =>
      // If we see a new value or cachedUids is too big, we flush uids with their average rank.
      if (j != preCol || v != preVal || cachedUids.size >= maxCachedUids) {
        val output = flush()
        // Start a new run at the current record.
        preCol = j
        preVal = v
        startRank = rank.toDouble
        cachedUids += uid
        output
      } else {
        cachedUids += uid
        Iterator.empty
      }
      // `++` takes its argument by name, so this final flush only runs once the iterator
      // above has been fully consumed, emitting the last buffered run.
    } ++ flush()
  }
  // Replace values in the input matrix by their ranks compared with values in the same column.
  // Note that shifting all ranks in a column by a constant value doesn't affect result.
  val groupedRanks = globalRanks.groupByKey().map { case (_, ranks) =>
    // sort by column index and then convert values to a vector
    Vectors.dense(ranks.toSeq.sortBy(_._1).map(_._2).toArray)
  }
  PearsonCorrelation.computeCorrelationMatrix(groupedRanks)
}
101
94
}
102
-
0 commit comments