@@ -155,13 +155,11 @@ class ColumnBasedPartitioner(
155
155
* @param numRowBlocks Number of blocks that form the rows of this matrix
156
156
* @param numColBlocks Number of blocks that form the columns of this matrix
157
157
* @param rdd The RDD of SubMatrices (local matrices) that form this matrix
158
- * @param partitioner A partitioner that specifies how SubMatrices are stored in the cluster
159
158
*/
160
159
class BlockMatrix (
161
160
val numRowBlocks : Int ,
162
161
val numColBlocks : Int ,
163
- val rdd : RDD [SubMatrix ],
164
- val partitioner : BlockMatrixPartitioner ) extends DistributedMatrix with Logging {
162
+ val rdd : RDD [SubMatrix ]) extends DistributedMatrix with Logging {
165
163
166
164
/**
167
165
* Alternate constructor for BlockMatrix without the input of a partitioner. Will use a Grid
@@ -170,11 +168,31 @@ class BlockMatrix(
170
168
* @param numRowBlocks Number of blocks that form the rows of this matrix
171
169
* @param numColBlocks Number of blocks that form the columns of this matrix
172
170
* @param rdd The RDD of SubMatrices (local matrices) that form this matrix
171
+ * @param partitioner A partitioner that specifies how SubMatrices are stored in the cluster
173
172
*/
174
def this(
    numRowBlocks: Int,
    numColBlocks: Int,
    rdd: RDD[SubMatrix],
    partitioner: BlockMatrixPartitioner) = {
  this(numRowBlocks, numColBlocks, rdd)
  // The primary constructor has already initialised `matrixRDD` through `keyBy()` before the
  // caller's partitioner was known (presumably keyed by the default partitioner -- confirm
  // against `keyBy`). Re-key with the supplied partitioner before storing it, performing the
  // same two steps as `repartition`, so the keyed RDD and `partitioner` cannot disagree.
  matrixRDD = keyBy(partitioner)
  setPartitioner(partitioner)
}
181
+
182
// Partitioner currently associated with this matrix. The initial value is a GridPartitioner
// built from the block counts and the dimensions of the first SubMatrix in `rdd`.
// NOTE(review): `rdd.first()` is evaluated as part of construction -- confirm that is intended.
private[mllib] var partitioner: BlockMatrixPartitioner = {
  val sampleBlock = rdd.first().mat
  val blockRows = sampleBlock.numRows
  val blockCols = sampleBlock.numCols
  new GridPartitioner(numRowBlocks, numColBlocks, blockRows, blockCols)
}
187
+
188
/**
 * Replaces the partitioner recorded for this matrix. Internal helper only; external callers
 * are expected to go through `repartition` instead.
 * @param part The partitioner describing how SubMatrices are stored in the cluster
 */
private def setPartitioner(part: BlockMatrixPartitioner): Unit = {
  this.partitioner = part
}
195
+
178
196
// A key-value pair RDD is required to partition properly
179
197
private var matrixRDD : RDD [(Int , SubMatrix )] = keyBy()
180
198
@@ -259,8 +277,9 @@ class BlockMatrix(
259
277
* @param part The partitioner to partition by
260
278
* @return The repartitioned BlockMatrix
261
279
*/
262
def repartition(part: BlockMatrixPartitioner): DistributedMatrix = {
  // Re-key the blocks for the requested partitioner first ...
  val rekeyed = keyBy(part)
  matrixRDD = rekeyed
  // ... then record that partitioner as the current one (what `setPartitioner` does).
  partitioner = part
  // The matrix is updated in place, so the same instance is handed back.
  this
}
266
285
0 commit comments