Skip to content

Commit 88faede

Browse files
author
Andrew Or
committed
Add unit test for StaticMemoryManager
1 parent 52e2014 commit 88faede

File tree

1 file changed

+169
-0
lines changed

1 file changed

+169
-0
lines changed
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark
19+
20+
import scala.collection.mutable.ArrayBuffer
21+
22+
import org.mockito.Mockito.{mock, reset, verify, when}
23+
import org.mockito.Matchers.{any, eq => meq}
24+
25+
import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId}
26+
27+
28+
class StaticMemoryManagerSuite extends SparkFunSuite {
29+
private val conf = new SparkConf().set("spark.storage.unrollFraction", "0.4")
30+
31+
test("basic execution memory") {
32+
val maxExecutionMem = 1000L
33+
val (mm, _) = makeThings(maxExecutionMem, Long.MaxValue)
34+
assert(mm.executionMemoryUsed === 0L)
35+
assert(mm.acquireExecutionMemory(10L) === 10L)
36+
assert(mm.executionMemoryUsed === 10L)
37+
assert(mm.acquireExecutionMemory(100L) === 100L)
38+
// Acquire up to the max
39+
assert(mm.acquireExecutionMemory(1000L) === 890L)
40+
assert(mm.executionMemoryUsed === 1000L)
41+
assert(mm.acquireExecutionMemory(1L) === 0L)
42+
assert(mm.executionMemoryUsed === 1000L)
43+
mm.releaseExecutionMemory(800L)
44+
assert(mm.executionMemoryUsed === 200L)
45+
// Acquire after release
46+
assert(mm.acquireExecutionMemory(1L) === 1L)
47+
assert(mm.executionMemoryUsed === 201L)
48+
// Release beyond what was acquired
49+
mm.releaseExecutionMemory(maxExecutionMem)
50+
assert(mm.executionMemoryUsed === 0L)
51+
}
52+
53+
test("basic storage memory") {
54+
val maxStorageMem = 1000L
55+
val dummyBlock = TestBlockId("you can see the world you brought to live")
56+
val dummyBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
57+
val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem)
58+
assert(mm.storageMemoryUsed === 0L)
59+
assert(mm.acquireStorageMemory(dummyBlock, 10L, dummyBlocks) === 10L)
60+
// `ensureFreeSpace` should be called with the number of bytes requested
61+
assertEnsureFreeSpaceCalled(ms, dummyBlock, 10L)
62+
assert(mm.storageMemoryUsed === 10L)
63+
assert(dummyBlocks.isEmpty)
64+
assert(mm.acquireStorageMemory(dummyBlock, 100L, dummyBlocks) === 100L)
65+
assertEnsureFreeSpaceCalled(ms, dummyBlock, 100L)
66+
// Acquire up to the max, not granted
67+
assert(mm.acquireStorageMemory(dummyBlock, 1000L, dummyBlocks) === 0L)
68+
assertEnsureFreeSpaceCalled(ms, dummyBlock, 1000L)
69+
assert(mm.storageMemoryUsed === 110L)
70+
assert(mm.acquireStorageMemory(dummyBlock, 890L, dummyBlocks) === 890L)
71+
assertEnsureFreeSpaceCalled(ms, dummyBlock, 890L)
72+
assert(mm.storageMemoryUsed === 1000L)
73+
assert(mm.acquireStorageMemory(dummyBlock, 1L, dummyBlocks) === 0L)
74+
assertEnsureFreeSpaceCalled(ms, dummyBlock, 1L)
75+
assert(mm.storageMemoryUsed === 1000L)
76+
mm.releaseStorageMemory(800L)
77+
assert(mm.storageMemoryUsed === 200L)
78+
// Acquire after release
79+
assert(mm.acquireStorageMemory(dummyBlock, 1L, dummyBlocks) === 1L)
80+
assertEnsureFreeSpaceCalled(ms, dummyBlock, 1L)
81+
assert(mm.storageMemoryUsed === 201L)
82+
mm.releaseStorageMemory()
83+
assert(mm.storageMemoryUsed === 0L)
84+
assert(mm.acquireStorageMemory(dummyBlock, 1L, dummyBlocks) === 1L)
85+
assertEnsureFreeSpaceCalled(ms, dummyBlock, 1L)
86+
assert(mm.storageMemoryUsed === 1L)
87+
// Release beyond what was acquired
88+
mm.releaseStorageMemory(100L)
89+
assert(mm.storageMemoryUsed === 0L)
90+
}
91+
92+
test("execution and storage isolation") {
93+
val maxExecutionMem = 200L
94+
val maxStorageMem = 1000L
95+
val dummyBlock = TestBlockId("ain't nobody love like you do")
96+
val dummyBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
97+
val (mm, ms) = makeThings(maxExecutionMem, maxStorageMem)
98+
// Only execution memory should increase
99+
assert(mm.acquireExecutionMemory(100L) === 100L)
100+
assert(mm.storageMemoryUsed === 0L)
101+
assert(mm.executionMemoryUsed === 100L)
102+
assert(mm.acquireExecutionMemory(1000L) === 100L)
103+
assert(mm.storageMemoryUsed === 0L)
104+
assert(mm.executionMemoryUsed === 200L)
105+
// Only storage memory should increase
106+
assert(mm.acquireStorageMemory(dummyBlock, 50L, dummyBlocks) === 50L)
107+
assertEnsureFreeSpaceCalled(ms, dummyBlock, 50L)
108+
assert(mm.storageMemoryUsed === 50L)
109+
assert(mm.executionMemoryUsed === 200L)
110+
// Only execution memory should be released
111+
mm.releaseExecutionMemory(133L)
112+
assert(mm.storageMemoryUsed === 50L)
113+
assert(mm.executionMemoryUsed === 67L)
114+
// Only storage memory should be released
115+
mm.releaseStorageMemory()
116+
assert(mm.storageMemoryUsed === 0L)
117+
assert(mm.executionMemoryUsed === 67L)
118+
}
119+
120+
test("unroll memory") {
121+
val maxStorageMem = 1000L
122+
val dummyBlock = TestBlockId("lonely water")
123+
val dummyBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
124+
val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem)
125+
assert(mm.acquireUnrollMemory(dummyBlock, 100L, dummyBlocks) === 100L)
126+
assertEnsureFreeSpaceCalled(ms, dummyBlock, 100L)
127+
assert(mm.storageMemoryUsed === 100L)
128+
mm.releaseUnrollMemory(40L)
129+
assert(mm.storageMemoryUsed === 60L)
130+
when(ms.currentUnrollMemory).thenReturn(60L)
131+
assert(mm.acquireUnrollMemory(dummyBlock, 500L, dummyBlocks) === 500L)
132+
// `spark.storage.unrollFraction` is 0.4, so the max unroll space is 400 bytes.
133+
// Since we already occupy 60 bytes, we will try to ensure only 400 - 60 = 340 bytes.
134+
assertEnsureFreeSpaceCalled(ms, dummyBlock, 340L)
135+
assert(mm.storageMemoryUsed === 560L)
136+
when(ms.currentUnrollMemory).thenReturn(560L)
137+
assert(mm.acquireUnrollMemory(dummyBlock, 800L, dummyBlocks) === 0L)
138+
assert(mm.storageMemoryUsed === 560L)
139+
// We already have 560 bytes > the max unroll space of 400 bytes, so no bytes are freed
140+
assertEnsureFreeSpaceCalled(ms, dummyBlock, 0L)
141+
// Release beyond what was acquired
142+
mm.releaseUnrollMemory(maxStorageMem)
143+
assert(mm.storageMemoryUsed === 0L)
144+
}
145+
146+
/**
147+
* Make a [[StaticMemoryManager]] and a [[MemoryStore]] with limited class dependencies.
148+
*/
149+
private def makeThings(
150+
maxExecutionMem: Long,
151+
maxStorageMem: Long): (StaticMemoryManager, MemoryStore) = {
152+
val mm = new StaticMemoryManager(conf, maxExecutionMem, maxStorageMem)
153+
val ms = mock(classOf[MemoryStore])
154+
mm.setMemoryStore(ms)
155+
(mm, ms)
156+
}
157+
158+
/**
159+
* Assert that [[MemoryStore.ensureFreeSpace]] is called with the given parameters.
160+
*/
161+
private def assertEnsureFreeSpaceCalled(
162+
ms: MemoryStore,
163+
blockId: BlockId,
164+
numBytes: Long): Unit = {
165+
verify(ms).ensureFreeSpace(meq(blockId), meq(numBytes: java.lang.Long), any())
166+
reset(ms)
167+
}
168+
169+
}

0 commit comments

Comments
 (0)