@@ -52,17 +52,17 @@ private[mllib] class GridPartitioner(
52
52
* Returns the index of the partition the input coordinate belongs to.
53
53
*
54
54
* @param key The coordinate (i, j) or a tuple (i, j, k), where k is the inner index used in
55
- * multiplication.
55
+ * multiplication. k is ignored in computing partitions.
56
56
* @return The index of the partition, which the coordinate belongs to.
57
57
*/
58
58
override def getPartition (key : Any ): Int = {
59
59
key match {
// Two-element key: a plain (i, j) coordinate.
60
60
case (i : Int , j : Int ) =>
61
61
getPartitionId(i, j)
// Three-element key: the third component is matched but never passed to getPartitionId,
// so it has no influence on the partition that is returned.
62
- case (i : Int , j : Int , _) =>
62
+ case (i : Int , j : Int , _ : Int ) =>
63
63
getPartitionId(i, j)
// Any other key shape is rejected with an IllegalArgumentException.
64
64
case _ =>
65
- throw new IllegalArgumentException (s " Unrecognized key: $key" )
65
+ throw new IllegalArgumentException (s " Unrecognized key: $key. " )
66
66
}
67
67
}
68
68
@@ -73,7 +73,6 @@ private[mllib] class GridPartitioner(
73
73
i / rowsPerPart + j / colsPerPart * rowPartitions
74
74
}
75
75
76
- /** Checks whether the partitioners have the same characteristics */
77
76
override def equals (obj : Any ): Boolean = {
78
77
obj match {
79
78
case r : GridPartitioner =>
@@ -87,10 +86,12 @@ private[mllib] class GridPartitioner(
87
86
88
87
private [mllib] object GridPartitioner {
89
88
89
+ /** Creates a new [[GridPartitioner]] by forwarding all four arguments to the constructor. */
90
90
def apply (rows : Int , cols : Int , rowsPerPart : Int , colsPerPart : Int ): GridPartitioner = {
91
91
new GridPartitioner (rows, cols, rowsPerPart, colsPerPart)
92
92
}
93
93
94
+ /** Creates a new [[GridPartitioner ]] instance with the input suggested number of partitions. */
94
95
def apply (rows : Int , cols : Int , suggestedNumPartitions : Int ): GridPartitioner = {
95
96
require(suggestedNumPartitions > 0 )
96
97
val scale = 1.0 / math.sqrt(suggestedNumPartitions)
@@ -103,24 +104,25 @@ private[mllib] object GridPartitioner {
103
104
/**
104
105
* Represents a distributed matrix in blocks of local matrices.
105
106
*
106
- * @param rdd The RDD of SubMatrices (local matrices) that form this matrix
107
- * @param nRows Number of rows of this matrix. If the supplied value is less than or equal to zero,
108
- * the number of rows will be calculated when `numRows` is invoked.
109
- * @param nCols Number of columns of this matrix. If the supplied value is less than or equal to
110
- * zero, the number of columns will be calculated when `numCols` is invoked.
107
+ * @param blocks The RDD of sub-matrix blocks (blockRowIndex, blockColIndex, sub-matrix) that form
108
+ * this distributed matrix.
111
109
* @param rowsPerBlock Number of rows that make up each block. The blocks forming the final
112
110
* rows are not required to have the given number of rows
113
111
* @param colsPerBlock Number of columns that make up each block. The blocks forming the final
114
112
* columns are not required to have the given number of columns
113
+ * @param nRows Number of rows of this matrix. If the supplied value is less than or equal to zero,
114
+ * the number of rows will be calculated when `numRows` is invoked.
115
+ * @param nCols Number of columns of this matrix. If the supplied value is less than or equal to
116
+ * zero, the number of columns will be calculated when `numCols` is invoked.
115
117
*/
116
118
class BlockMatrix (
117
- val rdd : RDD [((Int , Int ), Matrix )],
118
- private var nRows : Long ,
119
- private var nCols : Long ,
119
+ val blocks : RDD [((Int , Int ), Matrix )],
120
120
val rowsPerBlock : Int ,
121
- val colsPerBlock : Int ) extends DistributedMatrix with Logging {
121
+ val colsPerBlock : Int ,
122
+ private var nRows : Long ,
123
+ private var nCols : Long ) extends DistributedMatrix with Logging {
122
124
123
- private type SubMatrix = ((Int , Int ), Matrix ) // ((blockRowIndex, blockColIndex), matrix)
125
+ private type MatrixBlock = ((Int , Int ), Matrix ) // ((blockRowIndex, blockColIndex), sub-matrix)
124
126
125
127
/**
126
128
* Alternate constructor for BlockMatrix without the input of the number of rows and columns.
@@ -135,45 +137,48 @@ class BlockMatrix(
135
137
rdd : RDD [((Int , Int ), Matrix )],
136
138
rowsPerBlock : Int ,
137
139
colsPerBlock : Int ) = {
138
- this (rdd, 0L , 0L , rowsPerBlock, colsPerBlock )
140
+ this (rdd, rowsPerBlock, colsPerBlock, 0L , 0L )
139
141
}
140
142
141
- private lazy val dims : (Long , Long ) = getDim
142
-
143
143
/**
 * Returns the number of rows of this matrix. When the stored value is not positive, the
 * dimensions are derived from the blocks before the value is returned.
 */
override def numRows (): Long = {
144
- if (nRows <= 0L ) nRows = dims._1
144
+ if (nRows <= 0L ) estimateDim()
145
145
nRows
146
146
}
147
147
148
148
/**
 * Returns the number of columns of this matrix. When the stored value is not positive, the
 * dimensions are derived from the blocks before the value is returned.
 */
override def numCols (): Long = {
149
- if (nCols <= 0L ) nCols = dims._2
149
+ if (nCols <= 0L ) estimateDim()
150
150
nCols
151
151
}
152
152
153
153
/**
 * Number of blocks along the row dimension: numRows divided by rowsPerBlock, rounded up.
 * Initialising this val calls numRows(), so a non-positive row count is resolved at this point.
 */
val numRowBlocks = math.ceil(numRows() * 1.0 / rowsPerBlock).toInt
/** Number of blocks along the column dimension: numCols divided by colsPerBlock, rounded up. */
154
154
val numColBlocks = math.ceil(numCols() * 1.0 / colsPerBlock).toInt
155
155
156
156
// Partitioner laid out over the numRowBlocks x numColBlocks grid; the current partition count of
// the underlying RDD is passed as the suggested number of partitions.
private [mllib] var partitioner : GridPartitioner =
157
- GridPartitioner (numRowBlocks, numColBlocks, suggestedNumPartitions = rdd.partitions.size)
158
-
159
- /** Returns the dimensions of the matrix. */
160
- private def getDim : (Long , Long ) = {
161
- val (rows, cols) = rdd.map { case ((blockRowIndex, blockColIndex), mat) =>
162
- (blockRowIndex * rowsPerBlock + mat.numRows, blockColIndex * colsPerBlock + mat.numCols)
163
- }.reduce((x0, x1) => (math.max(x0._1, x1._1), math.max(x0._2, x1._2)))
164
-
165
- (math.max(rows, nRows), math.max(cols, nCols))
157
+ GridPartitioner (numRowBlocks, numColBlocks, suggestedNumPartitions = blocks.partitions.size)
158
+
159
+ /** Estimates the dimensions of the matrix. */
160
+ private def estimateDim (): Unit = {
// For each block, compute where it ends: its offset (block index times block size) plus its own
// size. The index is widened with toLong before the multiplication, so the offset is computed
// in Long arithmetic.
161
+ val (rows, cols) = blocks.map { case ((blockRowIndex, blockColIndex), mat) =>
162
+ (blockRowIndex.toLong * rowsPerBlock + mat.numRows,
163
+ blockColIndex.toLong * colsPerBlock + mat.numCols)
// Component-wise maximum over all blocks gives the overall extent in each dimension.
164
+ }.reduce { (x0, x1) =>
165
+ (math.max(x0._1, x1._1), math.max(x0._2, x1._2))
166
+ }
// A non-positive stored dimension is replaced by the computed one; a positive stored dimension
// is kept, and the assertion fails if the blocks extend beyond it.
167
+ if (nRows <= 0L ) nRows = rows
168
+ assert(rows <= nRows, s " The number of rows $rows is more than claimed $nRows. " )
169
+ if (nCols <= 0L ) nCols = cols
170
+ assert(cols <= nCols, s " The number of columns $cols is more than claimed $nCols. " )
166
171
}
167
172
168
- /** Cache the underlying RDD. */
169
- def cache (): BlockMatrix = {
170
- rdd .cache()
173
+ /** Caches the underlying RDD of blocks and returns this matrix so calls can be chained. */
174
+ def cache (): this . type = {
175
+ blocks .cache()
171
176
this
172
177
}
173
178
174
- /** Set the storage level for the underlying RDD . */
175
- def persist (storageLevel : StorageLevel ): BlockMatrix = {
176
- rdd .persist(storageLevel)
179
+ /** Persists the underlying RDD with the specified storage level and returns this matrix. */
180
+ def persist (storageLevel : StorageLevel ): this . type = {
181
+ blocks .persist(storageLevel)
177
182
this
178
183
}
179
184
@@ -185,22 +190,22 @@ class BlockMatrix(
185
190
s " Int.MaxValue. Currently numCols: ${numCols()}" )
186
191
require(numRows() * numCols() < Int .MaxValue , " The length of the values array must be " +
187
192
s " less than Int.MaxValue. Currently numRows * numCols: ${numRows() * numCols()}" )
188
- val nRows = numRows().toInt
189
- val nCols = numCols().toInt
190
- val mem = nRows * nCols / 125000
193
+ val m = numRows().toInt
194
+ val n = numCols().toInt
195
+ val mem = m * n / 125000
191
196
if (mem > 500 ) logWarning(s " Storing this matrix will require $mem MB of memory! " )
192
197
193
- val parts = rdd .collect()
194
- val values = new Array [Double ](nRows * nCols )
195
- parts .foreach { case ((blockRowIndex, blockColIndex), block ) =>
198
+ val localBlocks = blocks .collect()
199
+ val values = new Array [Double ](m * n )
200
+ localBlocks .foreach { case ((blockRowIndex, blockColIndex), submat ) =>
196
201
val rowOffset = blockRowIndex * rowsPerBlock
197
202
val colOffset = blockColIndex * colsPerBlock
198
- block .foreachActive { (i, j, v) =>
199
- val indexOffset = (j + colOffset) * nRows + rowOffset + i
203
+ submat .foreachActive { (i, j, v) =>
204
+ val indexOffset = (j + colOffset) * m + rowOffset + i
200
205
values(indexOffset) = v
201
206
}
202
207
}
203
- new DenseMatrix (nRows, nCols , values)
208
+ new DenseMatrix (m, n , values)
204
209
}
205
210
206
211
/** Collects data and assembles a local dense breeze matrix (for test only). */
0 commit comments