Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 3ee8711

Browse files
committed
Fixing corner case issue with zeros in the active values of the sparse vector. Updated unit tests
1 parent b1f059f commit 3ee8711

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,6 @@ class SparseVector(
725725
-1
726726
} else {
727727

728-
//grab first active index and value by default
729728
var maxIdx = indices(0)
730729
var maxValue = values(0)
731730

@@ -736,9 +735,14 @@ class SparseVector(
736735
}
737736
}
738737

739-
// look for inactive values incase all active node values are negative
738+
// check the inactive (implicit zero) entries in case no active value is positive
740739
if(size != values.size && maxValue <= 0){
741-
maxIdx = calcInactiveIdx(0)
740+
val firstInactiveIdx = calcFirstInactiveIdx(0)
741+
if(maxValue == 0){
742+
if(firstInactiveIdx >= maxIdx) maxIdx else maxIdx = firstInactiveIdx
743+
}else{
744+
maxIdx = firstInactiveIdx
745+
}
742746
maxValue = 0
743747
}
744748
maxIdx
@@ -751,12 +755,12 @@ class SparseVector(
751755
* @param idx starting index of computation
752756
* @return index of the first inactive node at or after idx, or -1 if there is none
753757
*/
754-
private[SparseVector] def calcInactiveIdx(idx: Int): Int = {
758+
private[SparseVector] def calcFirstInactiveIdx(idx: Int): Int = {
755759
if (idx < size) {
756760
if (!indices.contains(idx)) {
757761
idx
758762
} else {
759-
calcInactiveIdx(idx + 1)
763+
calcFirstInactiveIdx(idx + 1)
760764
}
761765
} else {
762766
-1

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,24 @@ class VectorsSuite extends FunSuite {
9595
val max2 = vec3.argmax
9696
assert(max2 === 2)
9797

98-
// check for case that sparse vector is created with only negative vaues {0.0, 0.0,-1.0, -0.7, 0.0}
99-
val vec4 = Vectors.sparse(5,Array(2, 3),Array(-1.0,-.7))
98+
// check the case where the active values are all non-positive and include explicit zeros {0.0, 0.0, -1.0, -0.7, 0.0}
99+
val vec4 = Vectors.sparse(5,Array(0, 1, 2, 3),Array(0.0, 0.0, -1.0,-.7))
100100
val max3 = vec4.argmax
101101
assert(max3 === 0)
102102

103-
// check for case that sparse vector is created with only negative vaues {-1.0, 0.0, -0.7, 0.0, 0.0}
104-
val vec5 = Vectors.sparse(5,Array(0, 3),Array(-1.0,-.7))
103+
val vec5 = Vectors.sparse(11,Array(0, 3, 10),Array(-1.0,-.7,0.0))
105104
val max4 = vec5.argmax
106105
assert(max4 === 1)
107106

108-
val vec6 = Vectors.sparse(2,Array(0, 1),Array(-1.0, 0.0))
107+
val vec6 = Vectors.sparse(5,Array(0, 1, 3),Array(-1.0, 0.0, -.7))
109108
val max5 = vec6.argmax
110109
assert(max5 === 1)
110+
111+
// test that argmax still returns the correct index after the sparse vector is converted with toSparse
112+
var vec8 = Vectors.sparse(5,Array(0, 1),Array(0.0, -1.0))
113+
vec8 = vec8.toSparse
114+
val max7 = vec8.argmax
115+
assert(max7 === 0)
111116
}
112117

113118
test("vector equals") {

0 commit comments

Comments
 (0)