Skip to content

Commit 9ae85aa

Browse files
committed
[SPARK-3974] Made partitioner a variable inside BlockMatrix instead of a constructor variable
1 parent d033861 commit 9ae85aa

File tree

2 files changed

+40
-7
lines changed

2 files changed

+40
-7
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,11 @@ class ColumnBasedPartitioner(
155155
* @param numRowBlocks Number of blocks that form the rows of this matrix
156156
* @param numColBlocks Number of blocks that form the columns of this matrix
157157
* @param rdd The RDD of SubMatrices (local matrices) that form this matrix
158-
* @param partitioner A partitioner that specifies how SubMatrices are stored in the cluster
159158
*/
160159
class BlockMatrix(
161160
val numRowBlocks: Int,
162161
val numColBlocks: Int,
163-
val rdd: RDD[SubMatrix],
164-
val partitioner: BlockMatrixPartitioner) extends DistributedMatrix with Logging {
162+
val rdd: RDD[SubMatrix]) extends DistributedMatrix with Logging {
165163

166164
/**
167165
* Alternate constructor for BlockMatrix that takes an explicit partitioner. The primary constructor will use a Grid
@@ -170,11 +168,31 @@ class BlockMatrix(
170168
* @param numRowBlocks Number of blocks that form the rows of this matrix
171169
* @param numColBlocks Number of blocks that form the columns of this matrix
172170
* @param rdd The RDD of SubMatrices (local matrices) that form this matrix
171+
* @param partitioner A partitioner that specifies how SubMatrices are stored in the cluster
173172
*/
174-
def this(numRowBlocks: Int, numColBlocks: Int, rdd: RDD[SubMatrix]) = {
175-
this(numRowBlocks, numColBlocks, rdd, new GridPartitioner(numRowBlocks, numColBlocks,
176-
rdd.first().mat.numRows, rdd.first().mat.numCols))
173+
def this(
174+
numRowBlocks: Int,
175+
numColBlocks: Int,
176+
rdd: RDD[SubMatrix],
177+
partitioner: BlockMatrixPartitioner) = {
178+
this(numRowBlocks, numColBlocks, rdd)
179+
setPartitioner(partitioner)
177180
}
181+
182+
private[mllib] var partitioner: BlockMatrixPartitioner = {
183+
val firstSubMatrix = rdd.first().mat
184+
new GridPartitioner(numRowBlocks, numColBlocks,
185+
firstSubMatrix.numRows, firstSubMatrix.numCols)
186+
}
187+
188+
/**
189+
* Set the partitioner for the matrix. For internal use only. Users should use `repartition`.
190+
* @param part A partitioner that specifies how SubMatrices are stored in the cluster
191+
*/
192+
private def setPartitioner(part: BlockMatrixPartitioner): Unit = {
193+
partitioner = part
194+
}
195+
178196
// A key-value pair RDD is required to partition properly
179197
private var matrixRDD: RDD[(Int, SubMatrix)] = keyBy()
180198

@@ -259,8 +277,9 @@ class BlockMatrix(
259277
* @param part The partitioner to partition by
260278
* @return The repartitioned BlockMatrix
261279
*/
262-
def repartition(part: BlockMatrixPartitioner = partitioner): DistributedMatrix = {
280+
def repartition(part: BlockMatrixPartitioner): DistributedMatrix = {
263281
matrixRDD = keyBy(part)
282+
setPartitioner(part)
264283
this
265284
}
266285

mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,20 @@ class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext {
6666
assert(gridBasedMat.numCols() === n)
6767
}
6868

69+
test("partitioner and repartition") {
70+
assert(colBasedMat.partitioner.name === "column")
71+
assert(rowBasedMat.partitioner.name === "row")
72+
assert(gridBasedMat.partitioner.name === "grid")
73+
74+
val colPart = new ColumnBasedPartitioner(numColBlocks, rowPerPart, colPerPart)
75+
val rowPart = new RowBasedPartitioner(numRowBlocks, rowPerPart, colPerPart)
76+
gridBasedMat.repartition(rowPart).asInstanceOf[BlockMatrix]
77+
assert(gridBasedMat.partitioner.name === "row")
78+
79+
gridBasedMat.repartition(colPart).asInstanceOf[BlockMatrix]
80+
assert(gridBasedMat.partitioner.name === "column")
81+
}
82+
6983
test("toBreeze and collect") {
7084
val expected = BDM(
7185
(1.0, 0.0, 0.0, 0.0),

0 commit comments

Comments
 (0)