Skip to content

Commit f21dcce

Browse files
committed
commit
1 parent af17981 commit f21dcce

File tree

2 files changed

+44
-6
lines changed

2 files changed

+44
-6
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
594594
new SparseVector(size, ii, vv)
595595
}
596596

597-
def argmax: Int = {
597+
override def argmax: Int = {
598598
if (size == 0) {
599599
-1
600600
} else {
@@ -726,17 +726,42 @@ class SparseVector(
726726
} else {
727727

728728
var maxIdx = 0
729-
var maxValue = if(indices(0) != 0) 0 else values(0)
729+
var maxValue = if(indices(0) != 0) 0.0 else values(0)
730730

731731
foreachActive { (i, v) =>
732-
if(v > maxValue){
732+
if (v > maxValue) {
733733
maxIdx = i
734734
maxValue = v
735735
}
736736
}
737+
738+
// look for inactive values incase all active node values are negative
739+
if(size != values.size && maxValue < 0){
740+
maxIdx = calcInactiveIdx(indices(0))
741+
maxValue = 0
742+
}
737743
maxIdx
738744
}
739745
}
746+
747+
/**
748+
* Calculates the first instance of an inactive node in a sparse vector and returns the Idx
749+
* of the element.
750+
* @param idx starting index of computation
751+
* @return index of first inactive node or -1 if it cannot find one
752+
*/
753+
private[SparseVector] def calcInactiveIdx(idx: Int): Int ={
754+
if(idx < size){
755+
if(!indices.contains(idx)){
756+
idx
757+
}else{
758+
calcInactiveIdx(idx+1)
759+
}
760+
}else{
761+
-1
762+
}
763+
}
764+
740765
}
741766

742767
object SparseVector {

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,23 @@ class VectorsSuite extends FunSuite {
9191
val max = vec2.argmax
9292
assert(max === 3)
9393

94-
// check for case that sparse vector is created with only negative vaues {0.0,0.0,-1.0,0.0,-0.7}
95-
val vec3 = Vectors.sparse(5,Array(2, 4),Array(-1.0,-.7))
94+
val vec3 = Vectors.sparse(5,Array(2, 4),Array(1.0,-.7))
9695
val max2 = vec3.argmax
97-
assert(max2 === 0)
96+
assert(max2 === 2)
97+
98+
// check for case that sparse vector is created with only negative vaues {0.0, 0.0,-1.0, -0.7, 0.0}
99+
val vec4 = Vectors.sparse(5,Array(2, 3),Array(-1.0,-.7))
100+
val max3 = vec4.argmax
101+
assert(max3 === 0)
102+
103+
// check for case that sparse vector is created with only negative vaues {-1.0, 0.0, -0.7, 0.0, 0.0}
104+
val vec5 = Vectors.sparse(5,Array(0, 3),Array(-1.0,-.7))
105+
val max4 = vec5.argmax
106+
assert(max4 === 1)
107+
108+
val vec6 = Vectors.sparse(5,Array(0, 1, 2),Array(-1.0, -.025, -.7))
109+
val max5 = vec6.argmax
110+
assert(max5 === 3)
98111
}
99112

100113
test("vector equals") {

0 commit comments

Comments
 (0)