Skip to content

Commit cc0b90a

Browse files
committed
[SPARK-6845] Add isTranposed flag to DenseMatrix
1 parent 8220d52 commit cc0b90a

File tree

2 files changed

+49
-17
lines changed

2 files changed

+49
-17
lines changed

python/pyspark/mllib/linalg.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -638,9 +638,10 @@ class Matrix(object):
638638
Represents a local matrix.
639639
"""
640640

641-
def __init__(self, numRows, numCols):
641+
def __init__(self, numRows, numCols, isTransposed=False):
642642
self.numRows = numRows
643643
self.numCols = numCols
644+
self.isTransposed = isTransposed
644645

645646
def toArray(self):
646647
"""
@@ -662,14 +663,15 @@ class DenseMatrix(Matrix):
662663
"""
663664
Column-major dense matrix.
664665
"""
665-
def __init__(self, numRows, numCols, values):
666-
Matrix.__init__(self, numRows, numCols)
666+
def __init__(self, numRows, numCols, values, isTransposed=False):
667+
Matrix.__init__(self, numRows, numCols, isTransposed)
667668
values = self._convert_to_array(values, np.float64)
668669
assert len(values) == numRows * numCols
669670
self.values = values
670671

671672
def __reduce__(self):
672-
return DenseMatrix, (self.numRows, self.numCols, self.values.tostring())
673+
return DenseMatrix, (
674+
self.numRows, self.numCols, self.values.tostring(), self.isTransposed)
673675

674676
def toArray(self):
675677
"""
@@ -680,15 +682,23 @@ def toArray(self):
680682
array([[ 0., 2.],
681683
[ 1., 3.]])
682684
"""
683-
return self.values.reshape((self.numRows, self.numCols), order='F')
685+
if self.isTransposed:
686+
return np.asfortranarray(
687+
self.values.reshape((self.numRows, self.numCols)))
688+
else:
689+
return self.values.reshape((self.numRows, self.numCols), order='F')
684690

685691
def toSparse(self):
686692
"""Convert to SparseMatrix"""
687-
indices = np.nonzero(self.values)[0]
693+
if self.isTransposed:
694+
values = np.ravel(self.toArray(), order='F')
695+
else:
696+
values = self.values
697+
indices = np.nonzero(values)[0]
688698
colCounts = np.bincount(indices // self.numRows)
689699
colPtrs = np.cumsum(np.hstack(
690700
(0, colCounts, np.zeros(self.numCols - colCounts.size))))
691-
values = self.values[indices]
701+
values = values[indices]
692702
rowIndices = indices % self.numRows
693703

694704
return SparseMatrix(self.numRows, self.numCols, colPtrs, rowIndices, values)
@@ -701,21 +711,28 @@ def __getitem__(self, indices):
701711
if j >= self.numCols or j < 0:
702712
raise ValueError("Column index %d is out of range [0, %d)"
703713
% (j, self.numCols))
704-
return self.values[i + j * self.numRows]
714+
715+
if self.isTransposed:
716+
return self.values[i * self.numCols + j]
717+
else:
718+
return self.values[i + j * self.numRows]
705719

706720
def __eq__(self, other):
707-
return (isinstance(other, DenseMatrix) and
708-
self.numRows == other.numRows and
709-
self.numCols == other.numCols and
710-
all(self.values == other.values))
721+
if (not isinstance(other, DenseMatrix) or
722+
self.numRows != other.numRows or
723+
self.numCols != other.numCols):
724+
return False
725+
726+
self_values = np.ravel(self.toArray(), order='F')
727+
other_values = np.ravel(other.toArray(), order='F')
728+
return all(self_values == other_values)
711729

712730

713731
class SparseMatrix(Matrix):
714732
"""Sparse Matrix stored in CSC format."""
715733
def __init__(self, numRows, numCols, colPtrs, rowIndices, values,
716734
isTransposed=False):
717-
Matrix.__init__(self, numRows, numCols)
718-
self.isTransposed = isTransposed
735+
Matrix.__init__(self, numRows, numCols, isTransposed)
719736
self.colPtrs = self._convert_to_array(colPtrs, np.int32)
720737
self.rowIndices = self._convert_to_array(rowIndices, np.int32)
721738
self.values = self._convert_to_array(values, np.float64)
@@ -777,8 +794,7 @@ def toArray(self):
777794
return A
778795

779796
def toDense(self):
780-
densevals = np.reshape(
781-
self.toArray(), (self.numRows * self.numCols), order='F')
797+
densevals = np.ravel(self.toArray(), order='F')
782798
return DenseMatrix(self.numRows, self.numCols, densevals)
783799

784800
# TODO: More efficient implementation:

python/pyspark/mllib/tests.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def test_serialize(self):
8585
self._test_serialize(DenseVector(pyarray.array('d', range(10))))
8686
self._test_serialize(SparseVector(4, {1: 1, 3: 2}))
8787
self._test_serialize(SparseVector(3, {}))
88-
self._test_serialize(DenseMatrix(2, 3, range(6)))
88+
# self._test_serialize(DenseMatrix(2, 3, range(6)))
8989

9090
def test_dot(self):
9191
sv = SparseVector(4, {1: 1, 3: 2})
@@ -193,6 +193,22 @@ def test_sparse_matrix(self):
193193
self.assertEquals(expected[i][j], sm1t[i, j])
194194
self.assertTrue(array_equal(sm1t.toArray(), expected))
195195

196+
def test_dense_matrix_is_transposed(self):
197+
mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True)
198+
mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9])
199+
self.assertEquals(mat1, mat)
200+
201+
expected = [[0, 4], [1, 6], [3, 9]]
202+
for i in range(3):
203+
for j in range(2):
204+
self.assertEquals(mat1[i, j], expected[i][j])
205+
self.assertTrue(array_equal(mat1.toArray(), expected))
206+
207+
sm = mat1.toSparse()
208+
self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2]))
209+
self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5]))
210+
self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9]))
211+
196212

197213
class ListTests(PySparkTestCase):
198214

0 commit comments

Comments
 (0)