Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit b1f059f

Browse files
committed
Added comment before we start arg max calculation. Updated unit tests to cover corner cases
1 parent f21dcce commit b1f059f

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -725,8 +725,9 @@ class SparseVector(
725725
-1
726726
} else {
727727

728-
var maxIdx = 0
729-
var maxValue = if(indices(0) != 0) 0.0 else values(0)
728+
//grab first active index and value by default
729+
var maxIdx = indices(0)
730+
var maxValue = values(0)
730731

731732
foreachActive { (i, v) =>
732733
if (v > maxValue) {
@@ -736,8 +737,8 @@ class SparseVector(
736737
}
737738

738739
// look for inactive values incase all active node values are negative
739-
if(size != values.size && maxValue < 0){
740-
maxIdx = calcInactiveIdx(indices(0))
740+
if(size != values.size && maxValue <= 0){
741+
maxIdx = calcInactiveIdx(0)
741742
maxValue = 0
742743
}
743744
maxIdx
@@ -748,20 +749,19 @@ class SparseVector(
748749
* Calculates the first instance of an inactive node in a sparse vector and returns the Idx
749750
* of the element.
750751
* @param idx starting index of computation
751-
* @return index of first inactive node or -1 if it cannot find one
752+
* @return index of first inactive node
752753
*/
753-
private[SparseVector] def calcInactiveIdx(idx: Int): Int ={
754-
if(idx < size){
755-
if(!indices.contains(idx)){
754+
private[SparseVector] def calcInactiveIdx(idx: Int): Int = {
755+
if (idx < size) {
756+
if (!indices.contains(idx)) {
756757
idx
757-
}else{
758-
calcInactiveIdx(idx+1)
758+
} else {
759+
calcInactiveIdx(idx + 1)
759760
}
760-
}else{
761+
} else {
761762
-1
762763
}
763764
}
764-
765765
}
766766

767767
object SparseVector {

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,9 @@ class VectorsSuite extends FunSuite {
105105
val max4 = vec5.argmax
106106
assert(max4 === 1)
107107

108-
val vec6 = Vectors.sparse(5,Array(0, 1, 2),Array(-1.0, -.025, -.7))
108+
val vec6 = Vectors.sparse(2,Array(0, 1),Array(-1.0, 0.0))
109109
val max5 = vec6.argmax
110-
assert(max5 === 3)
110+
assert(max5 === 1)
111111
}
112112

113113
test("vector equals") {

0 commit comments

Comments
 (0)