@@ -638,9 +638,10 @@ class Matrix(object):
638
638
Represents a local matrix.
639
639
"""
640
640
641
- def __init__ (self , numRows , numCols ):
641
+ def __init__ (self , numRows , numCols , isTransposed = False ):
642
642
self .numRows = numRows
643
643
self .numCols = numCols
644
+ self .isTransposed = isTransposed
644
645
645
646
def toArray (self ):
646
647
"""
@@ -662,14 +663,15 @@ class DenseMatrix(Matrix):
662
663
"""
663
664
Column-major dense matrix.
664
665
"""
665
- def __init__ (self , numRows , numCols , values ):
666
- Matrix .__init__ (self , numRows , numCols )
666
+ def __init__ (self , numRows , numCols , values , isTransposed = False ):
667
+ Matrix .__init__ (self , numRows , numCols , isTransposed )
667
668
values = self ._convert_to_array (values , np .float64 )
668
669
assert len (values ) == numRows * numCols
669
670
self .values = values
670
671
671
672
def __reduce__ (self ):
672
- return DenseMatrix , (self .numRows , self .numCols , self .values .tostring ())
673
+ return DenseMatrix , (
674
+ self .numRows , self .numCols , self .values .tostring (), self .isTransposed )
673
675
674
676
def toArray (self ):
675
677
"""
@@ -680,15 +682,23 @@ def toArray(self):
680
682
array([[ 0., 2.],
681
683
[ 1., 3.]])
682
684
"""
683
- return self .values .reshape ((self .numRows , self .numCols ), order = 'F' )
685
+ if self .isTransposed :
686
+ return np .asfortranarray (
687
+ self .values .reshape ((self .numRows , self .numCols )))
688
+ else :
689
+ return self .values .reshape ((self .numRows , self .numCols ), order = 'F' )
684
690
685
691
def toSparse (self ):
686
692
"""Convert to SparseMatrix"""
687
- indices = np .nonzero (self .values )[0 ]
693
+ if self .isTransposed :
694
+ values = np .ravel (self .toArray (), order = 'F' )
695
+ else :
696
+ values = self .values
697
+ indices = np .nonzero (values )[0 ]
688
698
colCounts = np .bincount (indices // self .numRows )
689
699
colPtrs = np .cumsum (np .hstack (
690
700
(0 , colCounts , np .zeros (self .numCols - colCounts .size ))))
691
- values = self . values [indices ]
701
+ values = values [indices ]
692
702
rowIndices = indices % self .numRows
693
703
694
704
return SparseMatrix (self .numRows , self .numCols , colPtrs , rowIndices , values )
@@ -701,21 +711,28 @@ def __getitem__(self, indices):
701
711
if j >= self .numCols or j < 0 :
702
712
raise ValueError ("Column index %d is out of range [0, %d)"
703
713
% (j , self .numCols ))
704
- return self .values [i + j * self .numRows ]
714
+
715
+ if self .isTransposed :
716
+ return self .values [i * self .numCols + j ]
717
+ else :
718
+ return self .values [i + j * self .numRows ]
705
719
706
720
def __eq__ (self , other ):
707
- return (isinstance (other , DenseMatrix ) and
708
- self .numRows == other .numRows and
709
- self .numCols == other .numCols and
710
- all (self .values == other .values ))
721
+ if (not isinstance (other , DenseMatrix ) or
722
+ self .numRows != other .numRows or
723
+ self .numCols != other .numCols ):
724
+ return False
725
+
726
+ self_values = np .ravel (self .toArray (), order = 'F' )
727
+ other_values = np .ravel (other .toArray (), order = 'F' )
728
+ return all (self_values == other_values )
711
729
712
730
713
731
class SparseMatrix (Matrix ):
714
732
"""Sparse Matrix stored in CSC format."""
715
733
def __init__ (self , numRows , numCols , colPtrs , rowIndices , values ,
716
734
isTransposed = False ):
717
- Matrix .__init__ (self , numRows , numCols )
718
- self .isTransposed = isTransposed
735
+ Matrix .__init__ (self , numRows , numCols , isTransposed )
719
736
self .colPtrs = self ._convert_to_array (colPtrs , np .int32 )
720
737
self .rowIndices = self ._convert_to_array (rowIndices , np .int32 )
721
738
self .values = self ._convert_to_array (values , np .float64 )
@@ -777,8 +794,7 @@ def toArray(self):
777
794
return A
778
795
779
796
def toDense (self ):
780
- densevals = np .reshape (
781
- self .toArray (), (self .numRows * self .numCols ), order = 'F' )
797
+ densevals = np .ravel (self .toArray (), order = 'F' )
782
798
return DenseMatrix (self .numRows , self .numCols , densevals )
783
799
784
800
# TODO: More efficient implementation:
0 commit comments