Skip to content

Commit 24ec7b8

Browse files
committed
update grid partitioner
1 parent 5eecd48 commit 24ec7b8

File tree

2 files changed

+109
-122
lines changed

2 files changed

+109
-122
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala

Lines changed: 50 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -18,125 +18,88 @@
1818
package org.apache.spark.mllib.linalg.distributed
1919

2020
import breeze.linalg.{DenseMatrix => BDM}
21-
import org.apache.spark.util.Utils
2221

2322
import org.apache.spark.{Logging, Partitioner}
24-
import org.apache.spark.mllib.linalg._
25-
import org.apache.spark.mllib.rdd.RDDFunctions._
23+
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix}
2624
import org.apache.spark.rdd.RDD
2725
import org.apache.spark.storage.StorageLevel
2826

2927
/**
 * A grid partitioner, which uses a regular grid to partition coordinates.
 *
 * Grid cells are numbered in column-major order: the cell in grid row `r` and grid column `c`
 * gets partition id `r + c * (number of grid rows)`.
 *
 * @param rows Number of rows.
 * @param cols Number of columns.
 * @param rowsPerPart Number of rows per partition, which may be less at the bottom edge.
 * @param colsPerPart Number of columns per partition, which may be less at the right edge.
 * @throws IllegalArgumentException if any of the four arguments is not positive.
 */
private[mllib] class GridPartitioner(
    val rows: Int,
    val cols: Int,
    val rowsPerPart: Int,
    val colsPerPart: Int) extends Partitioner {

  require(rows > 0)
  require(cols > 0)
  require(rowsPerPart > 0)
  require(colsPerPart > 0)

  // Divide in floating point. `rows / rowsPerPart` on two Ints truncates before `ceil` sees the
  // quotient, which silently drops the partial partition at the bottom/right edge and makes
  // `numPartitions` smaller than the largest id returned by `getPartition`.
  private val rowPartitions = math.ceil(rows * 1.0 / rowsPerPart).toInt
  private val colPartitions = math.ceil(cols * 1.0 / colsPerPart).toInt

  override val numPartitions: Int = rowPartitions * colPartitions

  /**
   * Returns the index of the partition the input coordinate belongs to.
   *
   * @param key The coordinate (i, j) or a tuple (i, j, k), where k is the inner index used in
   *            multiplication.
   * @return The index of the partition, which the coordinate belongs to.
   * @throws IllegalArgumentException if the key has an unrecognized shape or the coordinate is
   *                                  out of range.
   */
  override def getPartition(key: Any): Int = {
    key match {
      case (i: Int, j: Int) =>
        getPartitionId(i, j)
      case (i: Int, j: Int, _) =>
        // The third component does not influence the partition.
        getPartitionId(i, j)
      case _ =>
        throw new IllegalArgumentException(s"Unrecognized key: $key")
    }
  }

  /** Partitions sub-matrices as blocks with neighboring sub-matrices. */
  private def getPartitionId(i: Int, j: Int): Int = {
    require(0 <= i && i < rows, s"Row index $i out of range [0, $rows).")
    require(0 <= j && j < cols, s"Column index $j out of range [0, $cols).")
    // Column-major numbering of grid cells: grid row + grid column * number of grid rows.
    i / rowsPerPart + j / colsPerPart * rowPartitions
  }

  /** Checks whether the partitioners have the same characteristics */
  override def equals(obj: Any): Boolean = {
    obj match {
      case r: GridPartitioner =>
        (this.rows == r.rows) && (this.cols == r.cols) &&
          (this.rowsPerPart == r.rowsPerPart) && (this.colsPerPart == r.colsPerPart)
      case _ =>
        false
    }
  }

  /** Hashes exactly the fields compared in `equals`, so equal partitioners hash alike. */
  override def hashCode: Int = {
    (rows, cols, rowsPerPart, colsPerPart).hashCode()
  }
}
13987

88+
/** Factory methods for [[GridPartitioner]]. */
private[mllib] object GridPartitioner {

  /** Creates a partitioner from an explicit partition shape; see the class constructor. */
  def apply(rows: Int, cols: Int, rowsPerPart: Int, colsPerPart: Int): GridPartitioner =
    new GridPartitioner(rows, cols, rowsPerPart, colsPerPart)

  /**
   * Creates a partitioner whose partition shape is derived from a suggested partition count.
   *
   * Both dimensions are scaled by `1 / sqrt(suggestedNumPartitions)`, floored at 1 and rounded
   * to the nearest integer, so the resulting number of partitions only approximates the
   * suggestion.
   *
   * @param rows Number of rows.
   * @param cols Number of columns.
   * @param suggestedNumPartitions Desired number of partitions; must be positive.
   */
  def apply(rows: Int, cols: Int, suggestedNumPartitions: Int): GridPartitioner = {
    require(suggestedNumPartitions > 0)
    val shrink = 1.0 / math.sqrt(suggestedNumPartitions)
    // Partition side length along one dimension of the given extent.
    def partSize(extent: Int): Int = math.round(math.max(shrink * extent, 1.0)).toInt
    new GridPartitioner(rows, cols, partSize(rows), partSize(cols))
  }
}
102+
140103
/**
141104
* Represents a distributed matrix in blocks of local matrices.
142105
*
@@ -191,7 +154,7 @@ class BlockMatrix(
191154
val numColBlocks = math.ceil(numCols() * 1.0 / colsPerBlock).toInt
192155

193156
private[mllib] var partitioner: GridPartitioner =
194-
new GridPartitioner(numRowBlocks, numColBlocks, rdd.partitions.length)
157+
GridPartitioner(numRowBlocks, numColBlocks, suggestedNumPartitions = rdd.partitions.size)
195158

196159
/** Returns the dimensions of the matrix. */
197160
private def getDim: (Long, Long) = {

mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala

Lines changed: 59 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717

1818
package org.apache.spark.mllib.linalg.distributed
1919

20-
import org.scalatest.FunSuite
20+
import scala.util.Random
21+
2122
import breeze.linalg.{DenseMatrix => BDM}
23+
import org.scalatest.FunSuite
2224

2325
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix}
2426
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -51,50 +53,72 @@ class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext {
5153
assert(gridBasedMat.numCols() === n)
5254
}
5355

54-
test("grid partitioner") {
  val rng = new Random()

  // Checks every coordinate of `expected` against `part`, once as an (i, j) key and once as an
  // (i, j, k) key carrying an arbitrary inner index k.
  def checkGrid(part: GridPartitioner, expected: Array[Array[Int]]): Unit = {
    for (i <- expected.indices; j <- expected(i).indices) {
      assert(part.getPartition((i, j)) === expected(i)(j))
      assert(part.getPartition((i, j, rng.nextInt())) === expected(i)(j))
    }
  }

  // This should generate a 4x4 grid of 1x2 blocks.
  val part0 = GridPartitioner(4, 7, suggestedNumPartitions = 12)
  checkGrid(part0, Array(
    Array(0, 0, 4, 4, 8, 8, 12),
    Array(1, 1, 5, 5, 9, 9, 13),
    Array(2, 2, 6, 6, 10, 10, 14),
    Array(3, 3, 7, 7, 11, 11, 15)))

  // Coordinates just outside each edge of the 4x7 range must be rejected.
  Seq((-1, 0), (4, 0), (0, -1), (0, 7)).foreach { outOfRange =>
    intercept[IllegalArgumentException] {
      part0.getPartition(outOfRange)
    }
  }

  val part1 = GridPartitioner(2, 2, suggestedNumPartitions = 5)
  checkGrid(part1, Array(
    Array(0, 2),
    Array(1, 3)))

  // Equality depends on the grid characteristics, not on object identity.
  val part2 = GridPartitioner(2, 2, suggestedNumPartitions = 5)
  assert(part0 !== part2)
  assert(part1 === part2)

  val part3 = new GridPartitioner(2, 3, rowsPerPart = 1, colsPerPart = 2)
  checkGrid(part3, Array(
    Array(0, 0, 2),
    Array(1, 1, 3)))

  val part4 = GridPartitioner(2, 3, rowsPerPart = 1, colsPerPart = 2)
  assert(part3 === part4)

  // Non-positive partition sizes and partition counts are invalid.
  intercept[IllegalArgumentException] {
    new GridPartitioner(2, 2, rowsPerPart = 0, colsPerPart = 1)
  }

  intercept[IllegalArgumentException] {
    GridPartitioner(2, 2, rowsPerPart = 1, colsPerPart = 0)
  }

  intercept[IllegalArgumentException] {
    GridPartitioner(2, 2, suggestedNumPartitions = 0)
  }
}
99123

100124
test("toBreeze and toLocalMatrix") {

0 commit comments

Comments
 (0)