Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 04677af

Browse files
GeorgeGeorge
authored andcommitted
initial work on adding argmax to Vector and SparseVector
1 parent 8c07c75 commit 04677af

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,12 @@ sealed trait Vector extends Serializable {
150150
toDense
151151
}
152152
}
153+
154+
/**
155+
* Find the index of a maximal element. Returns the first maximal element in case of a tie.
156+
* Returns -1 if vector has length 0.
157+
*/
158+
def argmax: Int
153159
}
154160

155161
/**
@@ -588,11 +594,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
588594
new SparseVector(size, ii, vv)
589595
}
590596

591-
/**
592-
* Find the index of a maximal element. Returns the first maximal element in case of a tie.
593-
* Returns -1 if vector has length 0.
594-
*/
595-
private[spark] def argmax: Int = {
597+
def argmax: Int = {
596598
if (size == 0) {
597599
-1
598600
} else {
@@ -717,6 +719,23 @@ class SparseVector(
717719
new SparseVector(size, ii, vv)
718720
}
719721
}
722+
723+
override def argmax: Int = {
724+
if (size == 0) {
725+
-1
726+
} else {
727+
var maxIdx = 0
728+
var maxValue = values(0)
729+
var i = 1
730+
foreachActive{ (i, v) =>
731+
if(v > maxValue) {
732+
maxIdx = i
733+
maxValue = v
734+
}
735+
}
736+
maxIdx
737+
}
738+
}
720739
}
721740

722741
object SparseVector {

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class VectorsSuite extends FunSuite {
4242
val vec = Vectors.dense(arr).asInstanceOf[DenseVector]
4343
assert(vec.size === arr.length)
4444
assert(vec.values.eq(arr))
45+
vec.argmax
4546
}
4647

4748
test("sparse vector construction") {
@@ -56,6 +57,8 @@ class VectorsSuite extends FunSuite {
5657
assert(vec.size === n)
5758
assert(vec.indices === indices)
5859
assert(vec.values === values)
60+
val vec2 = Vectors.sparse(5,Array(0,3),values).asInstanceOf[SparseVector]
61+
vec2.foreachActive( (i, v) => println(i,v))
5962
}
6063

6164
test("dense to array") {

0 commit comments

Comments
 (0)