Skip to content

Commit d75496b

Browse files
committed
[SPARK-3701][MLLIB] update python linalg api and small fixes
1. doc updates 2. simple checks on vector dimensions 3. use column major for matrices davies jkbradley Author: Xiangrui Meng <[email protected]> Closes apache#2548 from mengxr/mllib-py-clean and squashes the following commits: 6dce2df [Xiangrui Meng] address comments 116b5db [Xiangrui Meng] use np.dot instead of array.dot 75f2fcc [Xiangrui Meng] fix python style fefce00 [Xiangrui Meng] better check of vector size with more tests 067ef71 [Xiangrui Meng] majored -> major ef853f9 [Xiangrui Meng] update python linalg api and small fixes
1 parent 6c696d7 commit d75496b

File tree

2 files changed

+125
-33
lines changed

2 files changed

+125
-33
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ sealed trait Matrix extends Serializable {
8585
}
8686

8787
/**
88-
* Column-majored dense matrix.
88+
* Column-major dense matrix.
8989
* The entry values are stored in a single array of doubles with columns listed in sequence.
9090
* For example, the following matrix
9191
* {{{
@@ -128,7 +128,7 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double])
128128
}
129129

130130
/**
131-
* Column-majored sparse matrix.
131+
* Column-major sparse matrix.
132132
* The entry values are stored in Compressed Sparse Column (CSC) format.
133133
* For example, the following matrix
134134
* {{{
@@ -207,7 +207,7 @@ class SparseMatrix(
207207
object Matrices {
208208

209209
/**
210-
* Creates a column-majored dense matrix.
210+
* Creates a column-major dense matrix.
211211
*
212212
* @param numRows number of rows
213213
* @param numCols number of columns
@@ -218,7 +218,7 @@ object Matrices {
218218
}
219219

220220
/**
221-
* Creates a column-majored sparse matrix in Compressed Sparse Column (CSC) format.
221+
* Creates a column-major sparse matrix in Compressed Sparse Column (CSC) format.
222222
*
223223
* @param numRows number of rows
224224
* @param numCols number of columns

python/pyspark/mllib/linalg.py

Lines changed: 121 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,41 @@ def _convert_to_vector(l):
6363
raise TypeError("Cannot convert type %s into Vector" % type(l))
6464

6565

66+
def _vector_size(v):
67+
"""
68+
Returns the size of the vector.
69+
70+
>>> _vector_size([1., 2., 3.])
71+
3
72+
>>> _vector_size((1., 2., 3.))
73+
3
74+
>>> _vector_size(array.array('d', [1., 2., 3.]))
75+
3
76+
>>> _vector_size(np.zeros(3))
77+
3
78+
>>> _vector_size(np.zeros((3, 1)))
79+
3
80+
>>> _vector_size(np.zeros((1, 3)))
81+
Traceback (most recent call last):
82+
...
83+
ValueError: Cannot treat an ndarray of shape (1, 3) as a vector
84+
"""
85+
if isinstance(v, Vector):
86+
return len(v)
87+
elif type(v) in (array.array, list, tuple):
88+
return len(v)
89+
elif type(v) == np.ndarray:
90+
if v.ndim == 1 or (v.ndim == 2 and v.shape[1] == 1):
91+
return len(v)
92+
else:
93+
raise ValueError("Cannot treat an ndarray of shape %s as a vector" % str(v.shape))
94+
elif _have_scipy and scipy.sparse.issparse(v):
95+
assert v.shape[1] == 1, "Expected column vector"
96+
return v.shape[0]
97+
else:
98+
raise TypeError("Cannot treat type %s as a vector" % type(v))
99+
100+
66101
class Vector(object):
67102
"""
68103
Abstract class for DenseVector and SparseVector
@@ -76,6 +111,9 @@ def toArray(self):
76111

77112

78113
class DenseVector(Vector):
114+
"""
115+
A dense vector represented by a value array.
116+
"""
79117
def __init__(self, ar):
80118
if not isinstance(ar, array.array):
81119
ar = array.array('d', ar)
@@ -100,15 +138,31 @@ def dot(self, other):
100138
5.0
101139
>>> dense.dot(np.array(range(1, 3)))
102140
5.0
141+
>>> dense.dot([1.,])
142+
Traceback (most recent call last):
143+
...
144+
AssertionError: dimension mismatch
145+
>>> dense.dot(np.reshape([1., 2., 3., 4.], (2, 2), order='F'))
146+
array([ 5., 11.])
147+
>>> dense.dot(np.reshape([1., 2., 3.], (3, 1), order='F'))
148+
Traceback (most recent call last):
149+
...
150+
AssertionError: dimension mismatch
103151
"""
104-
if isinstance(other, SparseVector):
105-
return other.dot(self)
152+
if type(other) == np.ndarray and other.ndim > 1:
153+
assert len(self) == other.shape[0], "dimension mismatch"
154+
return np.dot(self.toArray(), other)
106155
elif _have_scipy and scipy.sparse.issparse(other):
107-
return other.transpose().dot(self.toArray())[0]
108-
elif isinstance(other, Vector):
109-
return np.dot(self.toArray(), other.toArray())
156+
assert len(self) == other.shape[0], "dimension mismatch"
157+
return other.transpose().dot(self.toArray())
110158
else:
111-
return np.dot(self.toArray(), other)
159+
assert len(self) == _vector_size(other), "dimension mismatch"
160+
if isinstance(other, SparseVector):
161+
return other.dot(self)
162+
elif isinstance(other, Vector):
163+
return np.dot(self.toArray(), other.toArray())
164+
else:
165+
return np.dot(self.toArray(), other)
112166

113167
def squared_distance(self, other):
114168
"""
@@ -126,7 +180,16 @@ def squared_distance(self, other):
126180
>>> sparse1 = SparseVector(2, [0, 1], [2., 1.])
127181
>>> dense1.squared_distance(sparse1)
128182
2.0
183+
>>> dense1.squared_distance([1.,])
184+
Traceback (most recent call last):
185+
...
186+
AssertionError: dimension mismatch
187+
>>> dense1.squared_distance(SparseVector(1, [0,], [1.,]))
188+
Traceback (most recent call last):
189+
...
190+
AssertionError: dimension mismatch
129191
"""
192+
assert len(self) == _vector_size(other), "dimension mismatch"
130193
if isinstance(other, SparseVector):
131194
return other.squared_distance(self)
132195
elif _have_scipy and scipy.sparse.issparse(other):
@@ -165,12 +228,10 @@ def __getattr__(self, item):
165228

166229

167230
class SparseVector(Vector):
168-
169231
"""
170232
A simple sparse vector class for passing data to MLlib. Users may
171233
alternatively pass SciPy's {scipy.sparse} data types.
172234
"""
173-
174235
def __init__(self, size, *args):
175236
"""
176237
Create a sparse vector, using either a dictionary, a list of
@@ -222,20 +283,33 @@ def dot(self, other):
222283
0.0
223284
>>> a.dot(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]))
224285
array([ 22., 22.])
286+
>>> a.dot([1., 2., 3.])
287+
Traceback (most recent call last):
288+
...
289+
AssertionError: dimension mismatch
290+
>>> a.dot(np.array([1., 2.]))
291+
Traceback (most recent call last):
292+
...
293+
AssertionError: dimension mismatch
294+
>>> a.dot(DenseVector([1., 2.]))
295+
Traceback (most recent call last):
296+
...
297+
AssertionError: dimension mismatch
298+
>>> a.dot(np.zeros((3, 2)))
299+
Traceback (most recent call last):
300+
...
301+
AssertionError: dimension mismatch
225302
"""
226303
if type(other) == np.ndarray:
227-
if other.ndim == 1:
228-
result = 0.0
229-
for i in xrange(len(self.indices)):
230-
result += self.values[i] * other[self.indices[i]]
231-
return result
232-
elif other.ndim == 2:
304+
if other.ndim == 2:
233305
results = [self.dot(other[:, i]) for i in xrange(other.shape[1])]
234306
return np.array(results)
235-
else:
236-
raise Exception("Cannot call dot with %d-dimensional array" % other.ndim)
307+
elif other.ndim > 2:
308+
raise ValueError("Cannot call dot with %d-dimensional array" % other.ndim)
309+
310+
assert len(self) == _vector_size(other), "dimension mismatch"
237311

238-
elif type(other) in (array.array, DenseVector):
312+
if type(other) in (np.ndarray, array.array, DenseVector):
239313
result = 0.0
240314
for i in xrange(len(self.indices)):
241315
result += self.values[i] * other[self.indices[i]]
@@ -254,6 +328,7 @@ def dot(self, other):
254328
else:
255329
j += 1
256330
return result
331+
257332
else:
258333
return self.dot(_convert_to_vector(other))
259334

@@ -273,7 +348,16 @@ def squared_distance(self, other):
273348
30.0
274349
>>> b.squared_distance(a)
275350
30.0
351+
>>> b.squared_distance([1., 2.])
352+
Traceback (most recent call last):
353+
...
354+
AssertionError: dimension mismatch
355+
>>> b.squared_distance(SparseVector(3, [1,], [1.0,]))
356+
Traceback (most recent call last):
357+
...
358+
AssertionError: dimension mismatch
276359
"""
360+
assert len(self) == _vector_size(other), "dimension mismatch"
277361
if type(other) in (list, array.array, DenseVector, np.array, np.ndarray):
278362
if type(other) is np.array and other.ndim != 1:
279363
raise Exception("Cannot call squared_distance with %d-dimensional array" %
@@ -348,7 +432,6 @@ def __eq__(self, other):
348432
>>> v1 != v2
349433
False
350434
"""
351-
352435
return (isinstance(other, self.__class__)
353436
and other.size == self.size
354437
and other.indices == self.indices
@@ -414,23 +497,32 @@ def stringify(vector):
414497

415498

416499
class Matrix(object):
417-
""" the Matrix """
418-
def __init__(self, nRow, nCol):
419-
self.nRow = nRow
420-
self.nCol = nCol
500+
"""
501+
Represents a local matrix.
502+
"""
503+
504+
def __init__(self, numRows, numCols):
505+
self.numRows = numRows
506+
self.numCols = numCols
421507

422508
def toArray(self):
509+
"""
510+
Returns its elements in a NumPy ndarray.
511+
"""
423512
raise NotImplementedError
424513

425514

426515
class DenseMatrix(Matrix):
427-
def __init__(self, nRow, nCol, values):
428-
Matrix.__init__(self, nRow, nCol)
429-
assert len(values) == nRow * nCol
516+
"""
517+
Column-major dense matrix.
518+
"""
519+
def __init__(self, numRows, numCols, values):
520+
Matrix.__init__(self, numRows, numCols)
521+
assert len(values) == numRows * numCols
430522
self.values = values
431523

432524
def __reduce__(self):
433-
return DenseMatrix, (self.nRow, self.nCol, self.values)
525+
return DenseMatrix, (self.numRows, self.numCols, self.values)
434526

435527
def toArray(self):
436528
"""
@@ -439,10 +531,10 @@ def toArray(self):
439531
>>> arr = array.array('d', [float(i) for i in range(4)])
440532
>>> m = DenseMatrix(2, 2, arr)
441533
>>> m.toArray()
442-
array([[ 0., 1.],
443-
[ 2., 3.]])
534+
array([[ 0., 2.],
535+
[ 1., 3.]])
444536
"""
445-
return np.ndarray((self.nRow, self.nCol), np.float64, buffer=self.values.tostring())
537+
return np.reshape(self.values, (self.numRows, self.numCols), order='F')
446538

447539

448540
def _test():

0 commit comments

Comments
 (0)