Skip to content

Commit e4bd0c0

Browse files
committed
[SPARK-4409] Modified horzcat and vertcat
1 parent 65c562e commit e4bd0c0

File tree

3 files changed

+145
-67
lines changed

3 files changed

+145
-67
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala

Lines changed: 125 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ object Matrices {
603603
/**
604604
* Horizontally concatenate a sequence of matrices. The returned matrix will be in the format
605605
* the matrices are supplied in. Supplying a mix of dense and sparse matrices will result in
606-
* a dense matrix.
606+
* a sparse matrix.
607607
* @param matrices array of matrices
608608
* @return a single `Matrix` composed of the matrices that were horizontally concatenated
609609
*/
@@ -621,17 +621,42 @@ object Matrices {
621621
mat match {
622622
case sparse: SparseMatrix => isSparse = true
623623
case dense: DenseMatrix => isDense = true
624+
case _ => throw new IllegalArgumentException("Unsupported matrix format. Expected " +
625+
s"SparseMatrix or DenseMatrix. Instead got: ${mat.getClass}")
624626
}
625627
numCols += mat.numCols
626628
}
627629
require(rowsMatch, "The number of rows of the matrices in this sequence, don't match!")
628630

629-
if (isSparse && !isDense) {
631+
if (!isSparse && isDense) {
632+
new DenseMatrix(numRows, numCols, matrices.flatMap(_.toArray).toArray)
633+
} else {
630634
val allColPtrs: Array[(Int, Int)] = Array((0, 0)) ++
631635
matrices.zipWithIndex.flatMap { case (mat, ind) =>
632-
val ptr = mat.asInstanceOf[SparseMatrix].colPtrs
633-
ptr.slice(1, ptr.length).map(p => (ind, p))
634-
}
636+
mat match {
637+
case spMat: SparseMatrix =>
638+
val ptr = spMat.colPtrs
639+
ptr.slice(1, ptr.length).map(p => (ind, p))
640+
case dnMat: DenseMatrix =>
641+
val colSize = dnMat.numCols
642+
var j = 0
643+
val rowSize = dnMat.numRows
644+
val ptr = new ArrayBuffer[(Int, Int)](colSize)
645+
var nnz = 0
646+
val vals = dnMat.values
647+
while (j < colSize) {
648+
var i = j * rowSize
649+
val indEnd = (j + 1) * rowSize
650+
while (i < indEnd) {
651+
if (vals(i) != 0.0) nnz += 1
652+
i += 1
653+
}
654+
j += 1
655+
ptr.append((ind, nnz))
656+
}
657+
ptr
658+
}
659+
}
635660
var counter = 0
636661
var lastIndex = 0
637662
var lastPtr = 0
@@ -643,21 +668,36 @@ object Matrices {
643668
lastPtr = p
644669
counter + p
645670
}
671+
val valsAndIndices: Array[(Int, Double)] = matrices.flatMap {
672+
case spMat: SparseMatrix =>
673+
spMat.rowIndices.zip(spMat.values)
674+
case dnMat: DenseMatrix =>
675+
val colSize = dnMat.numCols
676+
var j = 0
677+
val rowSize = dnMat.numRows
678+
val data = new ArrayBuffer[(Int, Double)]()
679+
val vals = dnMat.values
680+
while (j < colSize) {
681+
val indStart = j * rowSize
682+
var i = 0
683+
while (i < rowSize) {
684+
val index = indStart + i
685+
if (vals(index) != 0.0) data.append((i, vals(index)))
686+
i += 1
687+
}
688+
j += 1
689+
}
690+
data
691+
}
646692
new SparseMatrix(numRows, numCols, adjustedPtrs,
647-
matrices.flatMap(_.asInstanceOf[SparseMatrix].rowIndices).toArray,
648-
matrices.flatMap(_.asInstanceOf[SparseMatrix].values).toArray)
649-
} else if (!isSparse && !isDense) {
650-
throw new IllegalArgumentException("The supplied matrices are neither in SparseMatrix or" +
651-
" DenseMatrix format!")
652-
}else {
653-
new DenseMatrix(numRows, numCols, matrices.flatMap(_.toArray).toArray)
693+
valsAndIndices.map(_._1), valsAndIndices.map(_._2))
654694
}
655695
}
656696

657697
/**
658698
* Vertically concatenate a sequence of matrices. The returned matrix will be in the format
659699
* the matrices are supplied in. Supplying a mix of dense and sparse matrices will result in
660-
* a dense matrix.
700+
* a sparse matrix.
661701
* @param matrices array of matrices
662702
* @return a single `Matrix` composed of the matrices that were vertically concatenated
663703
*/
@@ -680,27 +720,58 @@ object Matrices {
680720
case dense: DenseMatrix =>
681721
isDense = true
682722
valsLength += dense.values.length
723+
case _ => throw new IllegalArgumentException("Unsupported matrix format. Expected " +
724+
s"SparseMatrix or DenseMatrix. Instead got: ${mat.getClass}")
683725
}
684726
numRows += mat.numRows
685727

686728
}
687729
require(colsMatch, "The number of rows of the matrices in this sequence, don't match!")
688730

689-
if (isSparse && !isDense) {
690-
val matMap = matrices.zipWithIndex.map(d => (d._2, d._1.asInstanceOf[SparseMatrix])).toMap
691-
// (matrixInd, colInd, colStart, colEnd, numRows)
692-
val allColPtrs: Seq[(Int, Int, Int, Int, Int)] =
693-
matMap.flatMap { case (ind, mat) =>
694-
val ptr = mat.colPtrs
695-
var colStart = 0
696-
var j = 0
697-
ptr.slice(1, ptr.length).map { p =>
698-
j += 1
699-
val oldColStart = colStart
700-
colStart = p
701-
(j - 1, ind, oldColStart, p, mat.numRows)
702-
}
703-
}.toSeq
731+
if (!isSparse && isDense) {
732+
val matData = matrices.zipWithIndex.flatMap { case (mat, ind) =>
733+
val values = mat.toArray
734+
for (j <- 0 until numCols) yield (j, ind,
735+
values.slice(j * mat.numRows, (j + 1) * mat.numRows))
736+
}.sortBy(x => (x._1, x._2))
737+
new DenseMatrix(numRows, numCols, matData.flatMap(_._3).toArray)
738+
} else {
739+
val matMap = matrices.zipWithIndex.map(d => (d._2, d._1)).toMap
740+
// (colInd, matrixInd, colStart, colEnd, numRows)
741+
val allColPtrs: Seq[(Int, Int, Int, Int, Int)] = matMap.flatMap { case (ind, mat) =>
742+
mat match {
743+
case spMat: SparseMatrix =>
744+
val ptr = spMat.colPtrs
745+
var colStart = 0
746+
var j = 0
747+
ptr.slice(1, ptr.length).map { p =>
748+
j += 1
749+
val oldColStart = colStart
750+
colStart = p
751+
(j - 1, ind, oldColStart, p, spMat.numRows)
752+
}
753+
case dnMat: DenseMatrix =>
754+
val colSize = dnMat.numCols
755+
var j = 0
756+
val rowSize = dnMat.numRows
757+
val ptr = new ArrayBuffer[(Int, Int, Int, Int, Int)](colSize)
758+
var nnz = 0
759+
val vals = dnMat.values
760+
var colStart = 0
761+
while (j < colSize) {
762+
var i = j * rowSize
763+
val indEnd = (j + 1) * rowSize
764+
while (i < indEnd) {
765+
if (vals(i) != 0.0) nnz += 1
766+
i += 1
767+
}
768+
ptr.append((j, ind, colStart, nnz, dnMat.numRows))
769+
j += 1
770+
colStart = nnz
771+
}
772+
ptr
773+
}
774+
}.toSeq
704775
val values = new ArrayBuffer[Double](valsLength)
705776
val rowInd = new ArrayBuffer[Int](valsLength)
706777
val newColPtrs = new Array[Int](numCols)
@@ -712,31 +783,38 @@ object Matrices {
712783
var startRow = 0
713784
sortedPtrs.foreach { case (colIdx, matrixInd, colStart, colEnd, nRows) =>
714785
val selectedMatrix = matMap(matrixInd)
715-
val selectedValues = selectedMatrix.values.slice(colStart, colEnd)
716-
val selectedRowIdx = selectedMatrix.rowIndices.slice(colStart, colEnd)
717-
val len = selectedValues.length
718-
newColPtrs(colIdx) += len
719-
var i = 0
720-
while (i < len) {
721-
values.append(selectedValues(i))
722-
rowInd.append(selectedRowIdx(i) + startRow)
723-
i += 1
786+
selectedMatrix match {
787+
case spMat: SparseMatrix =>
788+
val selectedValues = spMat.values
789+
val selectedRowIdx = spMat.rowIndices
790+
val len = colEnd - colStart
791+
newColPtrs(colIdx) += len
792+
var i = colStart
793+
while (i < colEnd) {
794+
values.append(selectedValues(i))
795+
rowInd.append(selectedRowIdx(i) + startRow)
796+
i += 1
797+
}
798+
case dnMat: DenseMatrix =>
799+
val selectedValues = dnMat.values
800+
val len = colEnd - colStart
801+
newColPtrs(colIdx) += len
802+
val indStart = colIdx * nRows
803+
var i = 0
804+
while (i < nRows) {
805+
val v = selectedValues(indStart + i)
806+
if (v != 0) {
807+
values.append(v)
808+
rowInd.append(i + startRow)
809+
}
810+
i += 1
811+
}
724812
}
725813
startRow += nRows
726814
}
727815
}
728816
val adjustedPtrs = newColPtrs.scanLeft(0)(_ + _)
729817
new SparseMatrix(numRows, numCols, adjustedPtrs, rowInd.toArray, values.toArray)
730-
} else if (!isSparse && !isDense) {
731-
throw new IllegalArgumentException("The supplied matrices are neither in SparseMatrix or" +
732-
" DenseMatrix format!")
733-
}else {
734-
val matData = matrices.zipWithIndex.flatMap { case (mat, ind) =>
735-
val values = mat.toArray
736-
for (j <- 0 until numCols) yield (j, ind,
737-
values.slice(j * mat.numRows, (j + 1) * mat.numRows))
738-
}.sortBy(x => (x._1, x._2))
739-
new DenseMatrix(numRows, numCols, matData.flatMap(_._3).toArray)
740818
}
741819
}
742820
}

mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -152,21 +152,21 @@ class MatricesSuite extends FunSuite {
152152
val spMat3 = Matrices.speye(2)
153153

154154
val spHorz = Matrices.horzcat(Array(spMat1, spMat2))
155+
val spHorz2 = Matrices.horzcat(Array(spMat1, deMat2))
156+
val spHorz3 = Matrices.horzcat(Array(deMat1, spMat2))
155157
val deHorz1 = Matrices.horzcat(Array(deMat1, deMat2))
156-
val deHorz2 = Matrices.horzcat(Array(spMat1, deMat2))
157-
val deHorz3 = Matrices.horzcat(Array(deMat1, spMat2))
158158

159159
assert(deHorz1.numRows === 3)
160-
assert(deHorz2.numRows === 3)
161-
assert(deHorz3.numRows === 3)
160+
assert(spHorz2.numRows === 3)
161+
assert(spHorz3.numRows === 3)
162162
assert(spHorz.numRows === 3)
163163
assert(deHorz1.numCols === 5)
164-
assert(deHorz2.numCols === 5)
165-
assert(deHorz3.numCols === 5)
164+
assert(spHorz2.numCols === 5)
165+
assert(spHorz3.numCols === 5)
166166
assert(spHorz.numCols === 5)
167167

168-
assert(deHorz1 === deHorz2)
169-
assert(deHorz2 === deHorz3)
168+
assert(deHorz1.toBreeze.toDenseMatrix === spHorz2.toBreeze.toDenseMatrix)
169+
assert(spHorz2.toBreeze === spHorz3.toBreeze)
170170
assert(spHorz(0, 0) === 1.0)
171171
assert(spHorz(2, 1) === 5.0)
172172
assert(spHorz(0, 2) === 1.0)
@@ -177,7 +177,7 @@ class MatricesSuite extends FunSuite {
177177
assert(deHorz1(0, 0) === 1.0)
178178
assert(deHorz1(2, 1) === 5.0)
179179
assert(deHorz1(0, 2) === 1.0)
180-
assert(deHorz1(1, 2) === 0.0)
180+
assert(deHorz1(1, 2) == 0.0)
181181
assert(deHorz1(1, 3) === 1.0)
182182
assert(deHorz1(2, 4) === 1.0)
183183
assert(deHorz1(1, 4) === 0.0)
@@ -192,20 +192,20 @@ class MatricesSuite extends FunSuite {
192192

193193
val spVert = Matrices.vertcat(Array(spMat1, spMat3))
194194
val deVert1 = Matrices.vertcat(Array(deMat1, deMat3))
195-
val deVert2 = Matrices.vertcat(Array(spMat1, deMat3))
196-
val deVert3 = Matrices.vertcat(Array(deMat1, spMat3))
195+
val spVert2 = Matrices.vertcat(Array(spMat1, deMat3))
196+
val spVert3 = Matrices.vertcat(Array(deMat1, spMat3))
197197

198198
assert(deVert1.numRows === 5)
199-
assert(deVert2.numRows === 5)
200-
assert(deVert3.numRows === 5)
199+
assert(spVert2.numRows === 5)
200+
assert(spVert3.numRows === 5)
201201
assert(spVert.numRows === 5)
202202
assert(deVert1.numCols === 2)
203-
assert(deVert2.numCols === 2)
204-
assert(deVert3.numCols === 2)
203+
assert(spVert2.numCols === 2)
204+
assert(spVert3.numCols === 2)
205205
assert(spVert.numCols === 2)
206206

207-
assert(deVert1 === deVert2)
208-
assert(deVert2 === deVert3)
207+
assert(deVert1.toBreeze.toDenseMatrix === spVert2.toBreeze.toDenseMatrix)
208+
assert(spVert2.toBreeze === spVert3.toBreeze)
209209
assert(spVert(0, 0) === 1.0)
210210
assert(spVert(2, 1) === 5.0)
211211
assert(spVert(3, 0) === 1.0)

mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,17 +178,17 @@ object TestingUtils {
178178
implicit class MatrixWithAlmostEquals(val x: Matrix) {
179179

180180
/**
181-
* When the difference of two vectors are within eps, returns true; otherwise, returns false.
181+
* When the difference of two matrices are within eps, returns true; otherwise, returns false.
182182
*/
183183
def ~=(r: CompareMatrixRightSide): Boolean = r.fun(x, r.y, r.eps)
184184

185185
/**
186-
* When the difference of two vectors are within eps, returns false; otherwise, returns true.
186+
* When the difference of two matrices are within eps, returns false; otherwise, returns true.
187187
*/
188188
def !~=(r: CompareMatrixRightSide): Boolean = !r.fun(x, r.y, r.eps)
189189

190190
/**
191-
* Throws exception when the difference of two vectors are NOT within eps;
191+
* Throws exception when the difference of two matrices are NOT within eps;
192192
* otherwise, returns true.
193193
*/
194194
def ~==(r: CompareMatrixRightSide): Boolean = {

0 commit comments

Comments
 (0)