Skip to content

Commit e1d3ee8

Browse files
committed
minor updates
1 parent 24ec7b8 commit e1d3ee8

File tree

1 file changed

+48
-43
lines changed

1 file changed

+48
-43
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala

Lines changed: 48 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,17 @@ private[mllib] class GridPartitioner(
5252
* Returns the index of the partition the input coordinate belongs to.
5353
*
5454
* @param key The coordinate (i, j) or a tuple (i, j, k), where k is the inner index used in
55-
* multiplication.
55+
* multiplication. k is ignored in computing partitions.
5656
* @return The index of the partition to which the coordinate belongs.
5757
*/
5858
override def getPartition(key: Any): Int = {
5959
key match {
6060
case (i: Int, j: Int) =>
6161
getPartitionId(i, j)
62-
case (i: Int, j: Int, _) =>
62+
case (i: Int, j: Int, _: Int) =>
6363
getPartitionId(i, j)
6464
case _ =>
65-
throw new IllegalArgumentException(s"Unrecognized key: $key")
65+
throw new IllegalArgumentException(s"Unrecognized key: $key.")
6666
}
6767
}
6868

@@ -73,7 +73,6 @@ private[mllib] class GridPartitioner(
7373
i / rowsPerPart + j / colsPerPart * rowPartitions
7474
}
7575

76-
/** Checks whether the partitioners have the same characteristics. */
7776
override def equals(obj: Any): Boolean = {
7877
obj match {
7978
case r: GridPartitioner =>
@@ -87,10 +86,12 @@ private[mllib] class GridPartitioner(
8786

8887
private[mllib] object GridPartitioner {
8988

89+
/** Creates a new [[GridPartitioner]] instance. */
9090
def apply(rows: Int, cols: Int, rowsPerPart: Int, colsPerPart: Int): GridPartitioner = {
9191
new GridPartitioner(rows, cols, rowsPerPart, colsPerPart)
9292
}
9393

94+
/** Creates a new [[GridPartitioner]] instance with the input suggested number of partitions. */
9495
def apply(rows: Int, cols: Int, suggestedNumPartitions: Int): GridPartitioner = {
9596
require(suggestedNumPartitions > 0)
9697
val scale = 1.0 / math.sqrt(suggestedNumPartitions)
@@ -103,24 +104,25 @@ private[mllib] object GridPartitioner {
103104
/**
104105
* Represents a distributed matrix in blocks of local matrices.
105106
*
106-
* @param rdd The RDD of SubMatrices (local matrices) that form this matrix
107-
* @param nRows Number of rows of this matrix. If the supplied value is less than or equal to zero,
108-
* the number of rows will be calculated when `numRows` is invoked.
109-
* @param nCols Number of columns of this matrix. If the supplied value is less than or equal to
110-
* zero, the number of columns will be calculated when `numCols` is invoked.
107+
* @param blocks The RDD of sub-matrix blocks (blockRowIndex, blockColIndex, sub-matrix) that form
108+
* this distributed matrix.
111109
* @param rowsPerBlock Number of rows that make up each block. The blocks forming the final
112110
rows are not required to have the given number of rows.
113111
* @param colsPerBlock Number of columns that make up each block. The blocks forming the final
114112
columns are not required to have the given number of columns.
113+
* @param nRows Number of rows of this matrix. If the supplied value is less than or equal to zero,
114+
* the number of rows will be calculated when `numRows` is invoked.
115+
* @param nCols Number of columns of this matrix. If the supplied value is less than or equal to
116+
* zero, the number of columns will be calculated when `numCols` is invoked.
115117
*/
116118
class BlockMatrix(
117-
val rdd: RDD[((Int, Int), Matrix)],
118-
private var nRows: Long,
119-
private var nCols: Long,
119+
val blocks: RDD[((Int, Int), Matrix)],
120120
val rowsPerBlock: Int,
121-
val colsPerBlock: Int) extends DistributedMatrix with Logging {
121+
val colsPerBlock: Int,
122+
private var nRows: Long,
123+
private var nCols: Long) extends DistributedMatrix with Logging {
122124

123-
private type SubMatrix = ((Int, Int), Matrix) // ((blockRowIndex, blockColIndex), matrix)
125+
private type MatrixBlock = ((Int, Int), Matrix) // ((blockRowIndex, blockColIndex), sub-matrix)
124126

125127
/**
126128
* Alternate constructor for BlockMatrix without the input of the number of rows and columns.
@@ -135,45 +137,48 @@ class BlockMatrix(
135137
rdd: RDD[((Int, Int), Matrix)],
136138
rowsPerBlock: Int,
137139
colsPerBlock: Int) = {
138-
this(rdd, 0L, 0L, rowsPerBlock, colsPerBlock)
140+
this(rdd, rowsPerBlock, colsPerBlock, 0L, 0L)
139141
}
140142

141-
private lazy val dims: (Long, Long) = getDim
142-
143143
override def numRows(): Long = {
144-
if (nRows <= 0L) nRows = dims._1
144+
if (nRows <= 0L) estimateDim()
145145
nRows
146146
}
147147

148148
override def numCols(): Long = {
149-
if (nCols <= 0L) nCols = dims._2
149+
if (nCols <= 0L) estimateDim()
150150
nCols
151151
}
152152

153153
val numRowBlocks = math.ceil(numRows() * 1.0 / rowsPerBlock).toInt
154154
val numColBlocks = math.ceil(numCols() * 1.0 / colsPerBlock).toInt
155155

156156
private[mllib] var partitioner: GridPartitioner =
157-
GridPartitioner(numRowBlocks, numColBlocks, suggestedNumPartitions = rdd.partitions.size)
158-
159-
/** Returns the dimensions of the matrix. */
160-
private def getDim: (Long, Long) = {
161-
val (rows, cols) = rdd.map { case ((blockRowIndex, blockColIndex), mat) =>
162-
(blockRowIndex * rowsPerBlock + mat.numRows, blockColIndex * colsPerBlock + mat.numCols)
163-
}.reduce((x0, x1) => (math.max(x0._1, x1._1), math.max(x0._2, x1._2)))
164-
165-
(math.max(rows, nRows), math.max(cols, nCols))
157+
GridPartitioner(numRowBlocks, numColBlocks, suggestedNumPartitions = blocks.partitions.size)
158+
159+
/** Estimates the dimensions of the matrix. */
160+
private def estimateDim(): Unit = {
161+
val (rows, cols) = blocks.map { case ((blockRowIndex, blockColIndex), mat) =>
162+
(blockRowIndex.toLong * rowsPerBlock + mat.numRows,
163+
blockColIndex.toLong * colsPerBlock + mat.numCols)
164+
}.reduce { (x0, x1) =>
165+
(math.max(x0._1, x1._1), math.max(x0._2, x1._2))
166+
}
167+
if (nRows <= 0L) nRows = rows
168+
assert(rows <= nRows, s"The number of rows $rows is more than claimed $nRows.")
169+
if (nCols <= 0L) nCols = cols
170+
assert(cols <= nCols, s"The number of columns $cols is more than claimed $nCols.")
166171
}
167172

168-
/** Cache the underlying RDD. */
169-
def cache(): BlockMatrix = {
170-
rdd.cache()
173+
/** Caches the underlying RDD. */
174+
def cache(): this.type = {
175+
blocks.cache()
171176
this
172177
}
173178

174-
/** Set the storage level for the underlying RDD. */
175-
def persist(storageLevel: StorageLevel): BlockMatrix = {
176-
rdd.persist(storageLevel)
179+
/** Persists the underlying RDD with the specified storage level. */
180+
def persist(storageLevel: StorageLevel): this.type = {
181+
blocks.persist(storageLevel)
177182
this
178183
}
179184

@@ -185,22 +190,22 @@ class BlockMatrix(
185190
s"Int.MaxValue. Currently numCols: ${numCols()}")
186191
require(numRows() * numCols() < Int.MaxValue, "The length of the values array must be " +
187192
s"less than Int.MaxValue. Currently numRows * numCols: ${numRows() * numCols()}")
188-
val nRows = numRows().toInt
189-
val nCols = numCols().toInt
190-
val mem = nRows * nCols / 125000
193+
val m = numRows().toInt
194+
val n = numCols().toInt
195+
val mem = m * n / 125000
191196
if (mem > 500) logWarning(s"Storing this matrix will require $mem MB of memory!")
192197

193-
val parts = rdd.collect()
194-
val values = new Array[Double](nRows * nCols)
195-
parts.foreach { case ((blockRowIndex, blockColIndex), block) =>
198+
val localBlocks = blocks.collect()
199+
val values = new Array[Double](m * n)
200+
localBlocks.foreach { case ((blockRowIndex, blockColIndex), submat) =>
196201
val rowOffset = blockRowIndex * rowsPerBlock
197202
val colOffset = blockColIndex * colsPerBlock
198-
block.foreachActive { (i, j, v) =>
199-
val indexOffset = (j + colOffset) * nRows + rowOffset + i
203+
submat.foreachActive { (i, j, v) =>
204+
val indexOffset = (j + colOffset) * m + rowOffset + i
200205
values(indexOffset) = v
201206
}
202207
}
203-
new DenseMatrix(nRows, nCols, values)
208+
new DenseMatrix(m, n, values)
204209
}
205210

206211
/** Collects data and assembles a local dense breeze matrix (for test only). */

0 commit comments

Comments
 (0)