Skip to content

Commit 589fbb6

Browse files
author
Burak Yavuz
committed
[SPARK-3974] Code review feedback addressed
1 parent aa8f086 commit 589fbb6

File tree

1 file changed

+39
-32
lines changed

1 file changed

+39
-32
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala

Lines changed: 39 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -32,10 +32,10 @@ import org.apache.spark.util.Utils
3232
* @param blockIdCol The column index of this block
3333
* @param mat The underlying local matrix
3434
*/
35-
case class BlockPartition(blockIdRow: Int, blockIdCol: Int, mat: DenseMatrix) extends Serializable
35+
case class SubMatrix(blockIdRow: Int, blockIdCol: Int, mat: DenseMatrix) extends Serializable
3636

3737
/**
38-
* Information about the BlockMatrix maintained on the driver
38+
* Information about the submatrices of the BlockMatrix, maintained on the driver
3939
*
4040
* @param partitionId The id of the partition the block is found in
4141
* @param blockIdRow The row index of this block
@@ -45,7 +45,7 @@ case class BlockPartition(blockIdRow: Int, blockIdCol: Int, mat: DenseMatrix) ex
4545
* @param startCol The starting column index with respect to the distributed BlockMatrix
4646
* @param numCols The number of columns in this block
4747
*/
48-
case class BlockPartitionInfo(
48+
case class SubMatrixInfo(
4949
partitionId: Int,
5050
blockIdRow: Int,
5151
blockIdCol: Int,
@@ -67,6 +67,13 @@ abstract class BlockMatrixPartitioner(
6767
val colPerBlock: Int) extends Partitioner {
6868
val name: String
6969

70+
/**
71+
* Returns the index of the partition the SubMatrix belongs to.
72+
*
73+
* @param key The key for the SubMatrix. Can be its row index, column index or position in the
74+
* grid.
75+
* @return The index of the partition that the SubMatrix belongs to.
76+
*/
7077
override def getPartition(key: Any): Int = {
7178
Utils.nonNegativeMod(key.asInstanceOf[Int], numPartitions)
7279
}
@@ -91,6 +98,7 @@ class GridPartitioner(
9198

9299
override val numPartitions = numRowBlocks * numColBlocks
93100

101+
/** Checks whether the partitioners have the same characteristics */
94102
override def equals(obj: Any): Boolean = {
95103
obj match {
96104
case r: GridPartitioner =>
@@ -118,6 +126,7 @@ class RowBasedPartitioner(
118126

119127
override val name = "row"
120128

129+
/** Checks whether the partitioners have the same characteristics */
121130
override def equals(obj: Any): Boolean = {
122131
obj match {
123132
case r: RowBasedPartitioner =>
@@ -145,6 +154,7 @@ class ColumnBasedPartitioner(
145154

146155
override val name = "column"
147156

157+
/** Checks whether the partitioners have the same characteristics */
148158
override def equals(obj: Any): Boolean = {
149159
obj match {
150160
case p: ColumnBasedPartitioner =>
@@ -163,19 +173,19 @@ class ColumnBasedPartitioner(
163173
*
164174
* @param numRowBlocks Number of blocks that form the rows of this matrix
165175
* @param numColBlocks Number of blocks that form the columns of this matrix
166-
* @param rdd The RDD of BlockPartitions (local matrices) that form this matrix
167-
* @param partitioner A partitioner that specifies how BlockPartitions are stored in the cluster
176+
* @param rdd The RDD of SubMatrices (local matrices) that form this matrix
177+
* @param partitioner A partitioner that specifies how SubMatrices are stored in the cluster
168178
*/
169179
class BlockMatrix(
170180
val numRowBlocks: Int,
171181
val numColBlocks: Int,
172-
val rdd: RDD[BlockPartition],
182+
val rdd: RDD[SubMatrix],
173183
val partitioner: BlockMatrixPartitioner) extends DistributedMatrix with Logging {
174184

175185
// A key-value pair RDD is required to partition properly
176-
private var matrixRDD: RDD[(Int, BlockPartition)] = keyBy()
186+
private var matrixRDD: RDD[(Int, SubMatrix)] = keyBy()
177187

178-
@transient var blockInfo_ : Map[(Int, Int), BlockPartitionInfo] = null
188+
@transient var blockInfo_ : Map[(Int, Int), SubMatrixInfo] = null
179189

180190
private lazy val dims: (Long, Long) = getDim
181191

@@ -184,40 +194,36 @@ class BlockMatrix(
184194

185195
if (partitioner.name.equals("column")) {
186196
require(numColBlocks == partitioner.numPartitions, "The number of column blocks should match" +
187-
" the number of partitions of the column partitioner.")
197+
s" the number of partitions of the column partitioner. numColBlocks: $numColBlocks, " +
198+
s"partitioner.numPartitions: ${partitioner.numPartitions}")
188199
} else if (partitioner.name.equals("row")) {
189200
require(numRowBlocks == partitioner.numPartitions, "The number of row blocks should match" +
190-
" the number of partitions of the row partitioner.")
201+
s" the number of partitions of the row partitioner. numRowBlocks: $numRowBlocks, " +
202+
s"partitioner.numPartitions: ${partitioner.numPartitions}")
191203
} else if (partitioner.name.equals("grid")) {
192204
require(numRowBlocks * numColBlocks == partitioner.numPartitions, "The number of blocks " +
193-
"should match the number of partitions of the grid partitioner.")
205+
s"should match the number of partitions of the grid partitioner. numRowBlocks * " +
206+
s"numColBlocks: ${numRowBlocks * numColBlocks}, " +
207+
s"partitioner.numPartitions: ${partitioner.numPartitions}")
194208
} else {
195209
throw new IllegalArgumentException("Unrecognized partitioner.")
196210
}
197211

198-
/* Returns the dimensions of the matrix. */
212+
/** Returns the dimensions of the matrix. */
199213
def getDim: (Long, Long) = {
200214
val bi = getBlockInfo
201215
val xDim = bi.map { x =>
202216
(x._1._1, x._2.numRows.toLong)
203-
}.groupBy(x => x._1).values.map { x =>
204-
x.head._2.toLong
205-
}.reduceLeft {
206-
_ + _
207-
}
217+
}.groupBy(x => x._1).values.map(_.head._2.toLong).reduceLeft(_ + _)
208218

209219
val yDim = bi.map { x =>
210220
(x._1._2, x._2.numCols.toLong)
211-
}.groupBy(x => x._1).values.map { x =>
212-
x.head._2.toLong
213-
}.reduceLeft {
214-
_ + _
215-
}
221+
}.groupBy(x => x._1).values.map(_.head._2.toLong).reduceLeft(_ + _)
216222

217223
(xDim, yDim)
218224
}
219225

220-
/* Calculates the information for each block and collects it on the driver */
226+
/** Calculates the information for each block and collects it on the driver */
221227
private def calculateBlockInfo(): Unit = {
222228
// collect may cause akka frameSize errors
223229
val blockStartRowColsParts = matrixRDD.mapPartitionsWithIndex { case (partId, iter) =>
@@ -243,38 +249,38 @@ class BlockMatrix(
243249
}.toMap
244250

245251
blockInfo_ = blockStartRowCols.map{ case ((rowId, colId), (partId, numRow, numCol)) =>
246-
((rowId, colId), new BlockPartitionInfo(partId, rowId, colId, cumulativeRowSum(rowId),
252+
((rowId, colId), new SubMatrixInfo(partId, rowId, colId, cumulativeRowSum(rowId),
247253
numRow, cumulativeColSum(colId), numCol))
248254
}.toMap
249255
}
250256

251-
/* Returns a map of the information of the blocks that form the distributed matrix. */
252-
def getBlockInfo: Map[(Int, Int), BlockPartitionInfo] = {
257+
/** Returns a map of the information of the blocks that form the distributed matrix. */
258+
def getBlockInfo: Map[(Int, Int), SubMatrixInfo] = {
253259
if (blockInfo_ == null) {
254260
calculateBlockInfo()
255261
}
256262
blockInfo_
257263
}
258264

259-
/* Returns the Frobenius Norm of the matrix */
265+
/** Returns the Frobenius Norm of the matrix */
260266
def normFro(): Double = {
261267
math.sqrt(rdd.map(lm => lm.mat.values.map(x => math.pow(x, 2)).sum).reduce(_ + _))
262268
}
263269

264-
/* Cache the underlying RDD. */
270+
/** Cache the underlying RDD. */
265271
def cache(): DistributedMatrix = {
266272
matrixRDD.cache()
267273
this
268274
}
269275

270-
/* Set the storage level for the underlying RDD. */
276+
/** Set the storage level for the underlying RDD. */
271277
def persist(storageLevel: StorageLevel): DistributedMatrix = {
272278
matrixRDD.persist(storageLevel)
273279
this
274280
}
275281

276-
/* Add a key to the underlying rdd for partitioning and joins. */
277-
private def keyBy(part: BlockMatrixPartitioner = partitioner): RDD[(Int, BlockPartition)] = {
282+
/** Add a key to the underlying rdd for partitioning and joins. */
283+
private def keyBy(part: BlockMatrixPartitioner = partitioner): RDD[(Int, SubMatrix)] = {
278284
rdd.map { block =>
279285
part match {
280286
case r: RowBasedPartitioner => (block.blockIdRow, block)
@@ -296,7 +302,7 @@ class BlockMatrix(
296302
this
297303
}
298304

299-
/* Collect the distributed matrix on the driver. */
305+
/** Collect the distributed matrix on the driver. */
300306
def collect(): DenseMatrix = {
301307
val parts = rdd.map(x => ((x.blockIdRow, x.blockIdCol), x.mat)).
302308
collect().sortBy(x => (x._1._2, x._1._1))
@@ -324,6 +330,7 @@ class BlockMatrix(
324330
new DenseMatrix(nRows, nCols, values)
325331
}
326332

333+
/** Collects the data and assembles a local dense Breeze matrix (for testing only). */
327334
private[mllib] def toBreeze(): BDM[Double] = {
328335
val localMat = collect()
329336
new BDM[Double](localMat.numRows, localMat.numCols, localMat.values)

0 commit comments

Comments
 (0)