Skip to content

Commit 5eecd48

Browse files
committed
fixed gridPartitioner and added tests
1 parent 140f20e commit 5eecd48

File tree

2 files changed

+110
-40
lines changed

2 files changed

+110
-40
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.mllib.linalg.distributed
1919

2020
import breeze.linalg.{DenseMatrix => BDM}
21+
import org.apache.spark.util.Utils
2122

2223
import org.apache.spark.{Logging, Partitioner}
2324
import org.apache.spark.mllib.linalg._
@@ -39,16 +40,17 @@ private[mllib] class GridPartitioner(
3940
val numRowBlocks: Int,
4041
val numColBlocks: Int,
4142
suggestedNumPartitions: Int) extends Partitioner {
43+
private val totalBlocks = numRowBlocks.toLong * numColBlocks
4244
// Having the number of partitions greater than the number of sub matrices does not help
43-
override val numPartitions = math.min(suggestedNumPartitions, numRowBlocks * numColBlocks)
45+
override val numPartitions = math.min(suggestedNumPartitions, totalBlocks).toInt
4446

45-
val totalBlocks = numRowBlocks.toLong * numColBlocks
46-
// Gives the number of blocks that need to be in each partition
47-
val targetNumBlocksPerPartition = math.ceil(totalBlocks * 1.0 / numPartitions).toInt
47+
private val blockLengthsPerPartition = findOptimalBlockLengths
4848
// Number of neighboring blocks to take in each row
49-
val numRowBlocksPerPartition = math.ceil(numRowBlocks * 1.0 / targetNumBlocksPerPartition).toInt
49+
private val numRowBlocksPerPartition = blockLengthsPerPartition._1
5050
// Number of neighboring blocks to take in each column
51-
val numColBlocksPerPartition = math.ceil(numColBlocks * 1.0 / targetNumBlocksPerPartition).toInt
51+
private val numColBlocksPerPartition = blockLengthsPerPartition._2
52+
// Number of rows of partitions
53+
private val blocksPerRow = math.ceil(numRowBlocks * 1.0 / numRowBlocksPerPartition).toInt
5254

5355
/**
5456
* Returns the index of the partition the SubMatrix belongs to. Tries to achieve block wise
@@ -73,11 +75,54 @@ private[mllib] class GridPartitioner(
7375

7476
/** Partitions sub-matrices as blocks with neighboring sub-matrices. */
7577
private def getPartitionId(blockRowIndex: Int, blockColIndex: Int): Int = {
78+
require(0 <= blockRowIndex && blockRowIndex < numRowBlocks, "The blockRowIndex in the key " +
79+
s"must be in the range 0 <= blockRowIndex < numRowBlocks. blockRowIndex: $blockRowIndex," +
80+
s"numRowBlocks: $numRowBlocks")
81+
require(0 <= blockRowIndex && blockColIndex < numColBlocks, "The blockColIndex in the key " +
82+
s"must be in the range 0 <= blockRowIndex < numColBlocks. blockColIndex: $blockColIndex, " +
83+
s"numColBlocks: $numColBlocks")
7684
// Coordinates of the block
7785
val i = blockRowIndex / numRowBlocksPerPartition
7886
val j = blockColIndex / numColBlocksPerPartition
79-
val blocksPerRow = math.ceil(numRowBlocks * 1.0 / numRowBlocksPerPartition).toInt
80-
j * blocksPerRow + i
87+
// The mod shouldn't be required but is added as a guarantee for possible corner cases
88+
Utils.nonNegativeMod(j * blocksPerRow + i, numPartitions)
89+
}
90+
91+
/** Tries to calculate the optimal number of blocks that should be in each partition. */
92+
private def findOptimalBlockLengths: (Int, Int) = {
93+
// Gives the optimal number of blocks that need to be in each partition
94+
val targetNumBlocksPerPartition = math.ceil(totalBlocks * 1.0 / numPartitions).toInt
95+
// Number of neighboring blocks to take in each row
96+
var m = math.ceil(math.sqrt(targetNumBlocksPerPartition)).toInt
97+
// Number of neighboring blocks to take in each column
98+
var n = math.ceil(targetNumBlocksPerPartition * 1.0 / m).toInt
99+
// Try to make m and n close to each other while making sure that we don't exceed the number
100+
// of partitions
101+
var numBlocksForRows = math.ceil(numRowBlocks * 1.0 / m)
102+
var numBlocksForCols = math.ceil(numColBlocks * 1.0 / n)
103+
while ((numBlocksForRows * numBlocksForCols > numPartitions) && (m * n != 0)) {
104+
if (numRowBlocks <= numColBlocks) {
105+
m += 1
106+
n = math.ceil(targetNumBlocksPerPartition * 1.0 / m).toInt
107+
} else {
108+
n += 1
109+
m = math.ceil(targetNumBlocksPerPartition * 1.0 / n).toInt
110+
}
111+
numBlocksForRows = math.ceil(numRowBlocks * 1.0 / m)
112+
numBlocksForCols = math.ceil(numColBlocks * 1.0 / n)
113+
}
114+
// If a good partitioning scheme couldn't be found, set the side with the smaller dimension to
115+
// 1 and the other to the number of targetNumBlocksPerPartition
116+
if (m * n == 0) {
117+
if (numRowBlocks <= numColBlocks) {
118+
m = 1
119+
n = targetNumBlocksPerPartition
120+
} else {
121+
n = 1
122+
m = targetNumBlocksPerPartition
123+
}
124+
}
125+
(m, n)
81126
}
82127

83128
/** Checks whether the partitioners have the same characteristics */
@@ -148,8 +193,6 @@ class BlockMatrix(
148193
private[mllib] var partitioner: GridPartitioner =
149194
new GridPartitioner(numRowBlocks, numColBlocks, rdd.partitions.length)
150195

151-
152-
153196
/** Returns the dimensions of the matrix. */
154197
private def getDim: (Long, Long) = {
155198
val (rows, cols) = rdd.map { case ((blockRowIndex, blockColIndex), mat) =>
@@ -177,27 +220,21 @@ class BlockMatrix(
177220
s"Int.MaxValue. Currently numRows: ${numRows()}")
178221
require(numCols() < Int.MaxValue, "The number of columns of this matrix should be less than " +
179222
s"Int.MaxValue. Currently numCols: ${numCols()}")
223+
require(numRows() * numCols() < Int.MaxValue, "The length of the values array must be " +
224+
s"less than Int.MaxValue. Currently numRows * numCols: ${numRows() * numCols()}")
180225
val nRows = numRows().toInt
181226
val nCols = numCols().toInt
182-
val mem = nRows.toLong * nCols / 125000
227+
val mem = nRows * nCols / 125000
183228
if (mem > 500) logWarning(s"Storing this matrix will require $mem MB of memory!")
184229

185230
val parts = rdd.collect()
186231
val values = new Array[Double](nRows * nCols)
187232
parts.foreach { case ((blockRowIndex, blockColIndex), block) =>
188233
val rowOffset = blockRowIndex * rowsPerBlock
189234
val colOffset = blockColIndex * colsPerBlock
190-
var j = 0
191-
val mat = block.toArray
192-
while (j < block.numCols) {
193-
var i = 0
194-
val indStart = (j + colOffset) * nRows + rowOffset
195-
val matStart = j * block.numRows
196-
while (i < block.numRows) {
197-
values(indStart + i) = mat(matStart + i)
198-
i += 1
199-
}
200-
j += 1
235+
block.foreachActive { (i, j, v) =>
236+
val indexOffset = (j + colOffset) * nRows + rowOffset + i
237+
values(indexOffset) = v
201238
}
202239
}
203240
new DenseMatrix(nRows, nCols, values)

mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,13 @@ import breeze.linalg.{DenseMatrix => BDM}
2323
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix}
2424
import org.apache.spark.mllib.util.MLlibTestSparkContext
2525

26-
// Input values for the tests
27-
private object BlockMatrixSuite {
26+
class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext {
27+
2828
val m = 5
2929
val n = 4
3030
val rowPerPart = 2
3131
val colPerPart = 2
32-
val numRowBlocks = 3
33-
val numColBlocks = 2
34-
}
35-
36-
class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext {
37-
38-
val m = BlockMatrixSuite.m
39-
val n = BlockMatrixSuite.n
40-
val rowPerPart = BlockMatrixSuite.rowPerPart
41-
val colPerPart = BlockMatrixSuite.colPerPart
42-
val numRowBlocks = BlockMatrixSuite.numRowBlocks
43-
val numColBlocks = BlockMatrixSuite.numColBlocks
32+
val numPartitions = 3
4433
var gridBasedMat: BlockMatrix = _
4534
type SubMatrix = ((Int, Int), Matrix)
4635

@@ -54,14 +43,58 @@ class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext {
5443
new SubMatrix((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))),
5544
new SubMatrix((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0))))
5645

57-
gridBasedMat = new BlockMatrix(sc.parallelize(entries, 2), numRowBlocks, numColBlocks,
58-
rowPerPart, colPerPart)
46+
gridBasedMat = new BlockMatrix(sc.parallelize(entries, numPartitions), rowPerPart, colPerPart)
5947
}
6048

61-
test("size and frobenius norm") {
49+
test("size") {
6250
assert(gridBasedMat.numRows() === m)
6351
assert(gridBasedMat.numCols() === n)
64-
assert(gridBasedMat.normFro() === 7.0)
52+
}
53+
54+
test("grid partitioner partitioning") {
55+
val partitioner = gridBasedMat.partitioner
56+
assert(partitioner.getPartition((0, 0)) === 0)
57+
assert(partitioner.getPartition((0, 1)) === 0)
58+
assert(partitioner.getPartition((1, 0)) === 1)
59+
assert(partitioner.getPartition((1, 1)) === 1)
60+
assert(partitioner.getPartition((2, 0)) === 2)
61+
assert(partitioner.getPartition((2, 1)) === 2)
62+
assert(partitioner.getPartition((1, 0, 1)) === 1)
63+
assert(partitioner.getPartition((2, 0, 0)) === 2)
64+
65+
val part2 = new GridPartitioner(10, 20, 10)
66+
assert(part2.getPartition((0, 0)) === 0)
67+
assert(part2.getPartition((0, 1)) === 0)
68+
assert(part2.getPartition((0, 6)) === 2)
69+
assert(part2.getPartition((3, 7)) === 2)
70+
assert(part2.getPartition((3, 8)) === 4)
71+
assert(part2.getPartition((3, 13)) === 6)
72+
assert(part2.getPartition((9, 14)) === 7)
73+
assert(part2.getPartition((9, 15)) === 7)
74+
assert(part2.getPartition((9, 19)) === 9)
75+
76+
intercept[IllegalArgumentException] {
77+
part2.getPartition((-1, 0))
78+
}
79+
80+
intercept[IllegalArgumentException] {
81+
part2.getPartition((10, 0))
82+
}
83+
84+
intercept[IllegalArgumentException] {
85+
part2.getPartition((9, 20))
86+
}
87+
88+
val part3 = new GridPartitioner(20, 10, 10)
89+
assert(part3.getPartition((0, 0)) === 0)
90+
assert(part3.getPartition((1, 0)) === 0)
91+
assert(part3.getPartition((6, 0)) === 1)
92+
assert(part3.getPartition((7, 3)) === 1)
93+
assert(part3.getPartition((8, 3)) === 2)
94+
assert(part3.getPartition((13, 3)) === 3)
95+
assert(part3.getPartition((14, 9)) === 8)
96+
assert(part3.getPartition((15, 9)) === 8)
97+
assert(part3.getPartition((19, 9)) === 9)
6598
}
6699

67100
test("toBreeze and toLocalMatrix") {

0 commit comments

Comments
 (0)