Skip to content

Commit df9538a

Browse files
GeorgeGeorge
authored andcommitted
Added argmax to sparse vector and added unit test
1 parent 3cffed4 commit df9538a

File tree

2 files changed

+5
-8
lines changed

2 files changed

+5
-8
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -728,18 +728,11 @@ class SparseVector(
728728
var maxValue = values(0)
729729

730730
foreachActive { (i, v) =>
731-
if(values(i) > maxValue){
731+
if(v > maxValue){
732732
maxIdx = i
733733
maxValue = v
734734
}
735735
}
736-
// while(i < this.indices.size){
737-
// if(values(i) > maxValue){
738-
// maxIdx = indices(i)
739-
// maxValue = values(i)
740-
// }
741-
// i += 1
742-
// }
743736
maxIdx
744737
}
745738
}

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ class VectorsSuite extends FunSuite {
8686
val vec2 = Vectors.sparse(n,indices,values).asInstanceOf[SparseVector]
8787
val max = vec2.argmax
8888
assert(max === 3)
89+
90+
val vec3 = Vectors.sparse(5,Array(1,3,4),Array(1.0,.5,.7))
91+
val max2 = vec3.argmax
92+
assert(max2 === 1)
8993
}
9094

9195
test("vector equals") {

0 commit comments

Comments
 (0)