@@ -32,10 +32,10 @@ import org.apache.spark.util.Utils
 * @param blockIdCol The column index of this block
 * @param mat The underlying local matrix
 */
-case class BlockPartition(blockIdRow: Int, blockIdCol: Int, mat: DenseMatrix) extends Serializable
+case class SubMatrix(blockIdRow: Int, blockIdCol: Int, mat: DenseMatrix) extends Serializable

/**
- * Information about the BlockMatrix maintained on the driver
+ * Information about the submatrices of the BlockMatrix, maintained on the driver
 *
 * @param partitionId The id of the partition the block is found in
 * @param blockIdRow The row index of this block
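// Illustrative sketch (not part of the patch): a SubMatrix is just a local
// DenseMatrix tagged with its (row, column) position in the block grid.
// Here, block (0, 1) — first row of blocks, second column — holds a 2 x 2
// matrix (DenseMatrix values are column-major).
val block = SubMatrix(0, 1, new DenseMatrix(2, 2, Array(1.0, 2.0, 3.0, 4.0)))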
@@ -45,7 +45,7 @@ case class BlockPartition(blockIdRow: Int, blockIdCol: Int, mat: DenseMatrix) ex
 * @param startCol The starting column index with respect to the distributed BlockMatrix
 * @param numCols The number of columns in this block
 */
-case class BlockPartitionInfo(
+case class SubMatrixInfo(
    partitionId: Int,
    blockIdRow: Int,
    blockIdCol: Int,
@@ -67,6 +67,13 @@ abstract class BlockMatrixPartitioner(
    val colPerBlock: Int) extends Partitioner {
  val name: String

+  /**
+   * Returns the index of the partition the SubMatrix belongs to.
+   *
+   * @param key The key for the SubMatrix. Can be its row index, column index, or position in
+   *            the grid.
+   * @return The index of the partition that the SubMatrix belongs to.
+   */
  override def getPartition(key: Any): Int = {
    Utils.nonNegativeMod(key.asInstanceOf[Int], numPartitions)
  }
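// Illustrative sketch (not part of the patch): how getPartition maps keys to
// partitions. Spark's Utils.nonNegativeMod is equivalent to
// ((x % mod) + mod) % mod, so any Int key lands in a valid partition index.
def nonNegativeMod(x: Int, mod: Int): Int = ((x % mod) + mod) % mod

nonNegativeMod(7, 3)   // 1: block key 7 goes to partition 1 of 3
nonNegativeMod(-2, 3)  // 1: negative keys still map to a non-negative index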
@@ -91,6 +98,7 @@ class GridPartitioner(

  override val numPartitions = numRowBlocks * numColBlocks

+  /** Checks whether the partitioners have the same characteristics */
  override def equals(obj: Any): Boolean = {
    obj match {
      case r: GridPartitioner =>
@@ -118,6 +126,7 @@ class RowBasedPartitioner(

  override val name = "row"

+  /** Checks whether the partitioners have the same characteristics */
  override def equals(obj: Any): Boolean = {
    obj match {
      case r: RowBasedPartitioner =>
@@ -145,6 +154,7 @@ class ColumnBasedPartitioner(

  override val name = "column"

+  /** Checks whether the partitioners have the same characteristics */
  override def equals(obj: Any): Boolean = {
    obj match {
      case p: ColumnBasedPartitioner =>
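// Illustrative sketch (not part of the patch): Spark compares partitioners
// with equals to decide whether co-partitioned RDDs can be joined without a
// shuffle, which is why each partitioner overrides it. The constructor
// arguments (numPartitions, rowPerBlock, colPerBlock) are an assumption
// based on the fields shown in this patch.
val a = new ColumnBasedPartitioner(4, 1024, 1024)
val b = new ColumnBasedPartitioner(4, 1024, 1024)
assert(a == b)  // same characteristics, so no re-shuffle is needed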
@@ -163,19 +173,19 @@ class ColumnBasedPartitioner(
 *
 * @param numRowBlocks Number of blocks that form the rows of this matrix
 * @param numColBlocks Number of blocks that form the columns of this matrix
- * @param rdd The RDD of BlockPartitions (local matrices) that form this matrix
- * @param partitioner A partitioner that specifies how BlockPartitions are stored in the cluster
+ * @param rdd The RDD of SubMatrices (local matrices) that form this matrix
+ * @param partitioner A partitioner that specifies how SubMatrices are stored in the cluster
 */
class BlockMatrix(
    val numRowBlocks: Int,
    val numColBlocks: Int,
-    val rdd: RDD[BlockPartition],
+    val rdd: RDD[SubMatrix],
    val partitioner: BlockMatrixPartitioner) extends DistributedMatrix with Logging {

  // A key-value pair RDD is required to partition properly
-  private var matrixRDD: RDD[(Int, BlockPartition)] = keyBy()
+  private var matrixRDD: RDD[(Int, SubMatrix)] = keyBy()

-  @transient var blockInfo_ : Map[(Int, Int), BlockPartitionInfo] = null
+  @transient var blockInfo_ : Map[(Int, Int), SubMatrixInfo] = null

  private lazy val dims: (Long, Long) = getDim
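// Illustrative sketch (not part of the patch): assembling a 4 x 4 BlockMatrix
// from a 2 x 2 grid of 2 x 2 dense blocks. Assumes a SparkContext `sc` is in
// scope, and that GridPartitioner takes (numRowBlocks, numColBlocks,
// rowPerBlock, colPerBlock) — an assumption matching the fields in this patch.
val blocks = sc.parallelize(for {
  i <- 0 until 2
  j <- 0 until 2
} yield SubMatrix(i, j, new DenseMatrix(2, 2, Array(1.0, 2.0, 3.0, 4.0))))

val matrix = new BlockMatrix(2, 2, blocks, new GridPartitioner(2, 2, 2, 2))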
@@ -184,40 +194,36 @@ class BlockMatrix(

    if (partitioner.name.equals("column")) {
      require(numColBlocks == partitioner.numPartitions, "The number of column blocks should match" +
-        " the number of partitions of the column partitioner.")
+        s" the number of partitions of the column partitioner. numColBlocks: $numColBlocks, " +
+        s"partitioner.numPartitions: ${partitioner.numPartitions}")
    } else if (partitioner.name.equals("row")) {
      require(numRowBlocks == partitioner.numPartitions, "The number of row blocks should match" +
-        " the number of partitions of the row partitioner.")
+        s" the number of partitions of the row partitioner. numRowBlocks: $numRowBlocks, " +
+        s"partitioner.numPartitions: ${partitioner.numPartitions}")
    } else if (partitioner.name.equals("grid")) {
      require(numRowBlocks * numColBlocks == partitioner.numPartitions, "The number of blocks " +
-        " should match the number of partitions of the grid partitioner.")
+        s" should match the number of partitions of the grid partitioner. numRowBlocks * " +
+        s"numColBlocks: ${numRowBlocks * numColBlocks}, " +
+        s"partitioner.numPartitions: ${partitioner.numPartitions}")
    } else {
      throw new IllegalArgumentException("Unrecognized partitioner.")
    }

-  /* Returns the dimensions of the matrix. */
+  /** Returns the dimensions of the matrix. */
  def getDim: (Long, Long) = {
    val bi = getBlockInfo
    val xDim = bi.map { x =>
      (x._1._1, x._2.numRows.toLong)
-    }.groupBy(x => x._1).values.map { x =>
-      x.head._2.toLong
-    }.reduceLeft {
-      _ + _
-    }
+    }.groupBy(x => x._1).values.map(_.head._2.toLong).reduceLeft(_ + _)

    val yDim = bi.map { x =>
      (x._1._2, x._2.numCols.toLong)
-    }.groupBy(x => x._1).values.map { x =>
-      x.head._2.toLong
-    }.reduceLeft {
-      _ + _
-    }
+    }.groupBy(x => x._1).values.map(_.head._2.toLong).reduceLeft(_ + _)

    (xDim, yDim)
  }

-  /* Calculates the information for each block and collects it on the driver */
+  /** Calculates the information for each block and collects it on the driver */
  private def calculateBlockInfo(): Unit = {
    // collect may cause akka frameSize errors
    val blockStartRowColsParts = matrixRDD.mapPartitionsWithIndex { case (partId, iter) =>
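// Illustrative sketch (not part of the patch): what the getDim aggregation
// computes, on hypothetical per-block dimensions for a 2 x 2 grid of 3-row
// blocks. Keys are (rowBlockId, colBlockId); one representative block is
// kept per row of blocks, then the row counts are summed.
val numRowsPerBlock = Map((0, 0) -> 3L, (0, 1) -> 3L, (1, 0) -> 3L, (1, 1) -> 3L)
val totalRows = numRowsPerBlock.toSeq
  .map { case ((rowId, _), n) => (rowId, n) }
  .groupBy(_._1).values.map(_.head._2).reduceLeft(_ + _)  // 3 + 3 = 6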
@@ -243,38 +249,38 @@ class BlockMatrix(
    }.toMap

    blockInfo_ = blockStartRowCols.map { case ((rowId, colId), (partId, numRow, numCol)) =>
-      ((rowId, colId), new BlockPartitionInfo(partId, rowId, colId, cumulativeRowSum(rowId),
+      ((rowId, colId), new SubMatrixInfo(partId, rowId, colId, cumulativeRowSum(rowId),
        numRow, cumulativeColSum(colId), numCol))
    }.toMap
  }

-  /* Returns a map of the information of the blocks that form the distributed matrix. */
-  def getBlockInfo: Map[(Int, Int), BlockPartitionInfo] = {
+  /** Returns a map of the information of the blocks that form the distributed matrix. */
+  def getBlockInfo: Map[(Int, Int), SubMatrixInfo] = {
    if (blockInfo_ == null) {
      calculateBlockInfo()
    }
    blockInfo_
  }

-  /* Returns the Frobenius norm of the matrix */
+  /** Returns the Frobenius norm of the matrix */
  def normFro(): Double = {
    math.sqrt(rdd.map(lm => lm.mat.values.map(x => math.pow(x, 2)).sum).reduce(_ + _))
  }

-  /* Cache the underlying RDD. */
+  /** Cache the underlying RDD. */
  def cache(): DistributedMatrix = {
    matrixRDD.cache()
    this
  }

-  /* Set the storage level for the underlying RDD. */
+  /** Set the storage level for the underlying RDD. */
  def persist(storageLevel: StorageLevel): DistributedMatrix = {
    matrixRDD.persist(storageLevel)
    this
  }

-  /* Add a key to the underlying rdd for partitioning and joins. */
-  private def keyBy(part: BlockMatrixPartitioner = partitioner): RDD[(Int, BlockPartition)] = {
+  /** Add a key to the underlying rdd for partitioning and joins. */
+  private def keyBy(part: BlockMatrixPartitioner = partitioner): RDD[(Int, SubMatrix)] = {
    rdd.map { block =>
      part match {
        case r: RowBasedPartitioner => (block.blockIdRow, block)
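// Illustrative sketch (not part of the patch): normFro computes
// sqrt(sum of squared entries), i.e. the Frobenius norm. For a single block
// holding [[1, 3], [2, 4]] (values stored column-major):
val values = Array(1.0, 2.0, 3.0, 4.0)
val fro = math.sqrt(values.map(x => x * x).sum)  // sqrt(30) ≈ 5.477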
@@ -296,7 +302,7 @@ class BlockMatrix(
    this
  }

-  /* Collect the distributed matrix on the driver. */
+  /** Collect the distributed matrix on the driver. */
  def collect(): DenseMatrix = {
    val parts = rdd.map(x => ((x.blockIdRow, x.blockIdCol), x.mat)).
      collect().sortBy(x => (x._1._2, x._1._1))
@@ -324,6 +330,7 @@ class BlockMatrix(
      new DenseMatrix(nRows, nCols, values)
  }

+  /** Collects data and assembles a local dense breeze matrix (for test only). */
  private[mllib] def toBreeze(): BDM[Double] = {
    val localMat = collect()
    new BDM[Double](localMat.numRows, localMat.numCols, localMat.values)
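// Illustrative sketch (not part of the patch): toBreeze exists so tests can
// compare the distributed result against a local breeze matrix. Hypothetical
// test usage, with `expected` built as a breeze.linalg.DenseMatrix:
val local: BDM[Double] = matrix.toBreeze()
assert(local.rows == expected.rows && local.cols == expected.cols)
assert(local.data.sameElements(expected.data))  // entry-wise comparison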