@@ -603,7 +603,7 @@ object Matrices {
603
603
/**
604
604
* Horizontally concatenate a sequence of matrices. The returned matrix will be in the format
605
605
* the matrices are supplied in. Supplying a mix of dense and sparse matrices will result in
606
- * a dense matrix.
606
+ * a sparse matrix.
607
607
* @param matrices array of matrices
608
608
* @return a single `Matrix` composed of the matrices that were horizontally concatenated
609
609
*/
@@ -621,17 +621,42 @@ object Matrices {
621
621
mat match {
622
622
case sparse : SparseMatrix => isSparse = true
623
623
case dense : DenseMatrix => isDense = true
624
+ case _ => throw new IllegalArgumentException (" Unsupported matrix format. Expected " +
625
+ s " SparseMatrix or DenseMatrix. Instead got: ${mat.getClass}" )
624
626
}
625
627
numCols += mat.numCols
626
628
}
627
629
require(rowsMatch, " The number of rows of the matrices in this sequence, don't match!" )
628
630
629
- if (isSparse && ! isDense) {
631
+ if (! isSparse && isDense) {
632
+ new DenseMatrix (numRows, numCols, matrices.flatMap(_.toArray).toArray)
633
+ } else {
630
634
val allColPtrs : Array [(Int , Int )] = Array ((0 , 0 )) ++
631
635
matrices.zipWithIndex.flatMap { case (mat, ind) =>
632
- val ptr = mat.asInstanceOf [SparseMatrix ].colPtrs
633
- ptr.slice(1 , ptr.length).map(p => (ind, p))
634
- }
636
+ mat match {
637
+ case spMat : SparseMatrix =>
638
+ val ptr = spMat.colPtrs
639
+ ptr.slice(1 , ptr.length).map(p => (ind, p))
640
+ case dnMat : DenseMatrix =>
641
+ val colSize = dnMat.numCols
642
+ var j = 0
643
+ val rowSize = dnMat.numRows
644
+ val ptr = new ArrayBuffer [(Int , Int )](colSize)
645
+ var nnz = 0
646
+ val vals = dnMat.values
647
+ while (j < colSize) {
648
+ var i = j * rowSize
649
+ val indEnd = (j + 1 ) * rowSize
650
+ while (i < indEnd) {
651
+ if (vals(i) != 0.0 ) nnz += 1
652
+ i += 1
653
+ }
654
+ j += 1
655
+ ptr.append((ind, nnz))
656
+ }
657
+ ptr
658
+ }
659
+ }
635
660
var counter = 0
636
661
var lastIndex = 0
637
662
var lastPtr = 0
@@ -643,21 +668,36 @@ object Matrices {
643
668
lastPtr = p
644
669
counter + p
645
670
}
671
+ val valsAndIndices : Array [(Int , Double )] = matrices.flatMap {
672
+ case spMat : SparseMatrix =>
673
+ spMat.rowIndices.zip(spMat.values)
674
+ case dnMat : DenseMatrix =>
675
+ val colSize = dnMat.numCols
676
+ var j = 0
677
+ val rowSize = dnMat.numRows
678
+ val data = new ArrayBuffer [(Int , Double )]()
679
+ val vals = dnMat.values
680
+ while (j < colSize) {
681
+ val indStart = j * rowSize
682
+ var i = 0
683
+ while (i < rowSize) {
684
+ val index = indStart + i
685
+ if (vals(index) != 0.0 ) data.append((i, vals(index)))
686
+ i += 1
687
+ }
688
+ j += 1
689
+ }
690
+ data
691
+ }
646
692
new SparseMatrix (numRows, numCols, adjustedPtrs,
647
- matrices.flatMap(_.asInstanceOf [SparseMatrix ].rowIndices).toArray,
648
- matrices.flatMap(_.asInstanceOf [SparseMatrix ].values).toArray)
649
- } else if (! isSparse && ! isDense) {
650
- throw new IllegalArgumentException (" The supplied matrices are neither in SparseMatrix or" +
651
- " DenseMatrix format!" )
652
- }else {
653
- new DenseMatrix (numRows, numCols, matrices.flatMap(_.toArray).toArray)
693
+ valsAndIndices.map(_._1), valsAndIndices.map(_._2))
654
694
}
655
695
}
656
696
657
697
/**
658
698
* Vertically concatenate a sequence of matrices. The returned matrix will be in the format
659
699
* the matrices are supplied in. Supplying a mix of dense and sparse matrices will result in
660
- * a dense matrix.
700
+ * a sparse matrix.
661
701
* @param matrices array of matrices
662
702
* @return a single `Matrix` composed of the matrices that were vertically concatenated
663
703
*/
@@ -680,27 +720,58 @@ object Matrices {
680
720
case dense : DenseMatrix =>
681
721
isDense = true
682
722
valsLength += dense.values.length
723
+ case _ => throw new IllegalArgumentException (" Unsupported matrix format. Expected " +
724
+ s " SparseMatrix or DenseMatrix. Instead got: ${mat.getClass}" )
683
725
}
684
726
numRows += mat.numRows
685
727
686
728
}
687
729
require(colsMatch, " The number of rows of the matrices in this sequence, don't match!" )
688
730
689
- if (isSparse && ! isDense) {
690
- val matMap = matrices.zipWithIndex.map(d => (d._2, d._1.asInstanceOf [SparseMatrix ])).toMap
691
- // (matrixInd, colInd, colStart, colEnd, numRows)
692
- val allColPtrs : Seq [(Int , Int , Int , Int , Int )] =
693
- matMap.flatMap { case (ind, mat) =>
694
- val ptr = mat.colPtrs
695
- var colStart = 0
696
- var j = 0
697
- ptr.slice(1 , ptr.length).map { p =>
698
- j += 1
699
- val oldColStart = colStart
700
- colStart = p
701
- (j - 1 , ind, oldColStart, p, mat.numRows)
702
- }
703
- }.toSeq
731
+ if (! isSparse && isDense) {
732
+ val matData = matrices.zipWithIndex.flatMap { case (mat, ind) =>
733
+ val values = mat.toArray
734
+ for (j <- 0 until numCols) yield (j, ind,
735
+ values.slice(j * mat.numRows, (j + 1 ) * mat.numRows))
736
+ }.sortBy(x => (x._1, x._2))
737
+ new DenseMatrix (numRows, numCols, matData.flatMap(_._3).toArray)
738
+ } else {
739
+ val matMap = matrices.zipWithIndex.map(d => (d._2, d._1)).toMap
740
+ // (colInd, matrixInd, colStart, colEnd, numRows)
741
+ val allColPtrs : Seq [(Int , Int , Int , Int , Int )] = matMap.flatMap { case (ind, mat) =>
742
+ mat match {
743
+ case spMat : SparseMatrix =>
744
+ val ptr = spMat.colPtrs
745
+ var colStart = 0
746
+ var j = 0
747
+ ptr.slice(1 , ptr.length).map { p =>
748
+ j += 1
749
+ val oldColStart = colStart
750
+ colStart = p
751
+ (j - 1 , ind, oldColStart, p, spMat.numRows)
752
+ }
753
+ case dnMat : DenseMatrix =>
754
+ val colSize = dnMat.numCols
755
+ var j = 0
756
+ val rowSize = dnMat.numRows
757
+ val ptr = new ArrayBuffer [(Int , Int , Int , Int , Int )](colSize)
758
+ var nnz = 0
759
+ val vals = dnMat.values
760
+ var colStart = 0
761
+ while (j < colSize) {
762
+ var i = j * rowSize
763
+ val indEnd = (j + 1 ) * rowSize
764
+ while (i < indEnd) {
765
+ if (vals(i) != 0.0 ) nnz += 1
766
+ i += 1
767
+ }
768
+ ptr.append((j, ind, colStart, nnz, dnMat.numRows))
769
+ j += 1
770
+ colStart = nnz
771
+ }
772
+ ptr
773
+ }
774
+ }.toSeq
704
775
val values = new ArrayBuffer [Double ](valsLength)
705
776
val rowInd = new ArrayBuffer [Int ](valsLength)
706
777
val newColPtrs = new Array [Int ](numCols)
@@ -712,31 +783,38 @@ object Matrices {
712
783
var startRow = 0
713
784
sortedPtrs.foreach { case (colIdx, matrixInd, colStart, colEnd, nRows) =>
714
785
val selectedMatrix = matMap(matrixInd)
715
- val selectedValues = selectedMatrix.values.slice(colStart, colEnd)
716
- val selectedRowIdx = selectedMatrix.rowIndices.slice(colStart, colEnd)
717
- val len = selectedValues.length
718
- newColPtrs(colIdx) += len
719
- var i = 0
720
- while (i < len) {
721
- values.append(selectedValues(i))
722
- rowInd.append(selectedRowIdx(i) + startRow)
723
- i += 1
786
+ selectedMatrix match {
787
+ case spMat : SparseMatrix =>
788
+ val selectedValues = spMat.values
789
+ val selectedRowIdx = spMat.rowIndices
790
+ val len = colEnd - colStart
791
+ newColPtrs(colIdx) += len
792
+ var i = colStart
793
+ while (i < colEnd) {
794
+ values.append(selectedValues(i))
795
+ rowInd.append(selectedRowIdx(i) + startRow)
796
+ i += 1
797
+ }
798
+ case dnMat : DenseMatrix =>
799
+ val selectedValues = dnMat.values
800
+ val len = colEnd - colStart
801
+ newColPtrs(colIdx) += len
802
+ val indStart = colIdx * nRows
803
+ var i = 0
804
+ while (i < nRows) {
805
+ val v = selectedValues(indStart + i)
806
+ if (v != 0 ) {
807
+ values.append(v)
808
+ rowInd.append(i + startRow)
809
+ }
810
+ i += 1
811
+ }
724
812
}
725
813
startRow += nRows
726
814
}
727
815
}
728
816
val adjustedPtrs = newColPtrs.scanLeft(0 )(_ + _)
729
817
new SparseMatrix (numRows, numCols, adjustedPtrs, rowInd.toArray, values.toArray)
730
- } else if (! isSparse && ! isDense) {
731
- throw new IllegalArgumentException (" The supplied matrices are neither in SparseMatrix or" +
732
- " DenseMatrix format!" )
733
- }else {
734
- val matData = matrices.zipWithIndex.flatMap { case (mat, ind) =>
735
- val values = mat.toArray
736
- for (j <- 0 until numCols) yield (j, ind,
737
- values.slice(j * mat.numRows, (j + 1 ) * mat.numRows))
738
- }.sortBy(x => (x._1, x._2))
739
- new DenseMatrix (numRows, numCols, matData.flatMap(_._3).toArray)
740
818
}
741
819
}
742
820
}
0 commit comments