Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 3cffed4

Browse files
GeorgeGeorge
authored andcommitted
Adding unit tests for argmax functions for Dense and Sparse vectors
1 parent 04677af commit 3cffed4

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -724,15 +724,22 @@ class SparseVector(
724724
if (size == 0) {
725725
-1
726726
} else {
727-
var maxIdx = 0
727+
var maxIdx = indices(0)
728728
var maxValue = values(0)
729-
var i = 1
730-
foreachActive{ (i, v) =>
731-
if(v > maxValue) {
729+
730+
foreachActive { (i, v) =>
731+
if(values(i) > maxValue){
732732
maxIdx = i
733733
maxValue = v
734734
}
735735
}
736+
// while(i < this.indices.size){
737+
// if(values(i) > maxValue){
738+
// maxIdx = indices(i)
739+
// maxValue = values(i)
740+
// }
741+
// i += 1
742+
// }
736743
maxIdx
737744
}
738745
}

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ class VectorsSuite extends FunSuite {
4242
val vec = Vectors.dense(arr).asInstanceOf[DenseVector]
4343
assert(vec.size === arr.length)
4444
assert(vec.values.eq(arr))
45-
vec.argmax
4645
}
4746

4847
test("sparse vector construction") {
@@ -57,20 +56,38 @@ class VectorsSuite extends FunSuite {
5756
assert(vec.size === n)
5857
assert(vec.indices === indices)
5958
assert(vec.values === values)
60-
val vec2 = Vectors.sparse(5,Array(0,3),values).asInstanceOf[SparseVector]
61-
vec2.foreachActive( (i, v) => println(i,v))
6259
}
6360

6461
test("dense to array") {
6562
val vec = Vectors.dense(arr).asInstanceOf[DenseVector]
6663
assert(vec.toArray.eq(arr))
6764
}
6865

66+
test("dense argmax"){
67+
val vec = Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]
68+
val noMax = vec.argmax
69+
assert(noMax === -1)
70+
71+
val vec2 = Vectors.dense(arr).asInstanceOf[DenseVector]
72+
val max = vec2.argmax
73+
assert(max === 3)
74+
}
75+
6976
test("sparse to array") {
7077
val vec = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
7178
assert(vec.toArray === arr)
7279
}
7380

81+
test("sparse argmax"){
82+
val vec = Vectors.sparse(0,Array.empty[Int],Array.empty[Double]).asInstanceOf[SparseVector]
83+
val noMax = vec.argmax
84+
assert(noMax === -1)
85+
86+
val vec2 = Vectors.sparse(n,indices,values).asInstanceOf[SparseVector]
87+
val max = vec2.argmax
88+
assert(max === 3)
89+
}
90+
7491
test("vector equals") {
7592
val dv1 = Vectors.dense(arr.clone())
7693
val dv2 = Vectors.dense(arr.clone())

0 commit comments

Comments
 (0)