|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | + * contributor license agreements. See the NOTICE file distributed with |
| 4 | + * this work for additional information regarding copyright ownership. |
| 5 | + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | + * (the "License"); you may not use this file except in compliance with |
| 7 | + * the License. You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + */ |
| 17 | + |
| 18 | +package org.apache.spark |
| 19 | + |
| 20 | +import scala.collection.mutable.ArrayBuffer |
| 21 | + |
| 22 | +import org.mockito.Mockito.{mock, reset, verify, when} |
| 23 | +import org.mockito.Matchers.{any, eq => meq} |
| 24 | + |
| 25 | +import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId} |
| 26 | + |
| 27 | + |
| 28 | +class StaticMemoryManagerSuite extends SparkFunSuite { |
| 29 | + private val conf = new SparkConf().set("spark.storage.unrollFraction", "0.4") |
| 30 | + |
| 31 | + test("basic execution memory") { |
| 32 | + val maxExecutionMem = 1000L |
| 33 | + val (mm, _) = makeThings(maxExecutionMem, Long.MaxValue) |
| 34 | + assert(mm.executionMemoryUsed === 0L) |
| 35 | + assert(mm.acquireExecutionMemory(10L) === 10L) |
| 36 | + assert(mm.executionMemoryUsed === 10L) |
| 37 | + assert(mm.acquireExecutionMemory(100L) === 100L) |
| 38 | + // Acquire up to the max |
| 39 | + assert(mm.acquireExecutionMemory(1000L) === 890L) |
| 40 | + assert(mm.executionMemoryUsed === 1000L) |
| 41 | + assert(mm.acquireExecutionMemory(1L) === 0L) |
| 42 | + assert(mm.executionMemoryUsed === 1000L) |
| 43 | + mm.releaseExecutionMemory(800L) |
| 44 | + assert(mm.executionMemoryUsed === 200L) |
| 45 | + // Acquire after release |
| 46 | + assert(mm.acquireExecutionMemory(1L) === 1L) |
| 47 | + assert(mm.executionMemoryUsed === 201L) |
| 48 | + // Release beyond what was acquired |
| 49 | + mm.releaseExecutionMemory(maxExecutionMem) |
| 50 | + assert(mm.executionMemoryUsed === 0L) |
| 51 | + } |
| 52 | + |
| 53 | + test("basic storage memory") { |
| 54 | + val maxStorageMem = 1000L |
| 55 | + val dummyBlock = TestBlockId("you can see the world you brought to live") |
| 56 | + val dummyBlocks = new ArrayBuffer[(BlockId, BlockStatus)] |
| 57 | + val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) |
| 58 | + assert(mm.storageMemoryUsed === 0L) |
| 59 | + assert(mm.acquireStorageMemory(dummyBlock, 10L, dummyBlocks) === 10L) |
| 60 | + // `ensureFreeSpace` should be called with the number of bytes requested |
| 61 | + assertEnsureFreeSpaceCalled(ms, dummyBlock, 10L) |
| 62 | + assert(mm.storageMemoryUsed === 10L) |
| 63 | + assert(dummyBlocks.isEmpty) |
| 64 | + assert(mm.acquireStorageMemory(dummyBlock, 100L, dummyBlocks) === 100L) |
| 65 | + assertEnsureFreeSpaceCalled(ms, dummyBlock, 100L) |
| 66 | + // Acquire up to the max, not granted |
| 67 | + assert(mm.acquireStorageMemory(dummyBlock, 1000L, dummyBlocks) === 0L) |
| 68 | + assertEnsureFreeSpaceCalled(ms, dummyBlock, 1000L) |
| 69 | + assert(mm.storageMemoryUsed === 110L) |
| 70 | + assert(mm.acquireStorageMemory(dummyBlock, 890L, dummyBlocks) === 890L) |
| 71 | + assertEnsureFreeSpaceCalled(ms, dummyBlock, 890L) |
| 72 | + assert(mm.storageMemoryUsed === 1000L) |
| 73 | + assert(mm.acquireStorageMemory(dummyBlock, 1L, dummyBlocks) === 0L) |
| 74 | + assertEnsureFreeSpaceCalled(ms, dummyBlock, 1L) |
| 75 | + assert(mm.storageMemoryUsed === 1000L) |
| 76 | + mm.releaseStorageMemory(800L) |
| 77 | + assert(mm.storageMemoryUsed === 200L) |
| 78 | + // Acquire after release |
| 79 | + assert(mm.acquireStorageMemory(dummyBlock, 1L, dummyBlocks) === 1L) |
| 80 | + assertEnsureFreeSpaceCalled(ms, dummyBlock, 1L) |
| 81 | + assert(mm.storageMemoryUsed === 201L) |
| 82 | + mm.releaseStorageMemory() |
| 83 | + assert(mm.storageMemoryUsed === 0L) |
| 84 | + assert(mm.acquireStorageMemory(dummyBlock, 1L, dummyBlocks) === 1L) |
| 85 | + assertEnsureFreeSpaceCalled(ms, dummyBlock, 1L) |
| 86 | + assert(mm.storageMemoryUsed === 1L) |
| 87 | + // Release beyond what was acquired |
| 88 | + mm.releaseStorageMemory(100L) |
| 89 | + assert(mm.storageMemoryUsed === 0L) |
| 90 | + } |
| 91 | + |
| 92 | + test("execution and storage isolation") { |
| 93 | + val maxExecutionMem = 200L |
| 94 | + val maxStorageMem = 1000L |
| 95 | + val dummyBlock = TestBlockId("ain't nobody love like you do") |
| 96 | + val dummyBlocks = new ArrayBuffer[(BlockId, BlockStatus)] |
| 97 | + val (mm, ms) = makeThings(maxExecutionMem, maxStorageMem) |
| 98 | + // Only execution memory should increase |
| 99 | + assert(mm.acquireExecutionMemory(100L) === 100L) |
| 100 | + assert(mm.storageMemoryUsed === 0L) |
| 101 | + assert(mm.executionMemoryUsed === 100L) |
| 102 | + assert(mm.acquireExecutionMemory(1000L) === 100L) |
| 103 | + assert(mm.storageMemoryUsed === 0L) |
| 104 | + assert(mm.executionMemoryUsed === 200L) |
| 105 | + // Only storage memory should increase |
| 106 | + assert(mm.acquireStorageMemory(dummyBlock, 50L, dummyBlocks) === 50L) |
| 107 | + assertEnsureFreeSpaceCalled(ms, dummyBlock, 50L) |
| 108 | + assert(mm.storageMemoryUsed === 50L) |
| 109 | + assert(mm.executionMemoryUsed === 200L) |
| 110 | + // Only execution memory should be released |
| 111 | + mm.releaseExecutionMemory(133L) |
| 112 | + assert(mm.storageMemoryUsed === 50L) |
| 113 | + assert(mm.executionMemoryUsed === 67L) |
| 114 | + // Only storage memory should be released |
| 115 | + mm.releaseStorageMemory() |
| 116 | + assert(mm.storageMemoryUsed === 0L) |
| 117 | + assert(mm.executionMemoryUsed === 67L) |
| 118 | + } |
| 119 | + |
| 120 | + test("unroll memory") { |
| 121 | + val maxStorageMem = 1000L |
| 122 | + val dummyBlock = TestBlockId("lonely water") |
| 123 | + val dummyBlocks = new ArrayBuffer[(BlockId, BlockStatus)] |
| 124 | + val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) |
| 125 | + assert(mm.acquireUnrollMemory(dummyBlock, 100L, dummyBlocks) === 100L) |
| 126 | + assertEnsureFreeSpaceCalled(ms, dummyBlock, 100L) |
| 127 | + assert(mm.storageMemoryUsed === 100L) |
| 128 | + mm.releaseUnrollMemory(40L) |
| 129 | + assert(mm.storageMemoryUsed === 60L) |
| 130 | + when(ms.currentUnrollMemory).thenReturn(60L) |
| 131 | + assert(mm.acquireUnrollMemory(dummyBlock, 500L, dummyBlocks) === 500L) |
| 132 | + // `spark.storage.unrollFraction` is 0.4, so the max unroll space is 400 bytes. |
| 133 | + // Since we already occupy 60 bytes, we will try to ensure only 400 - 60 = 340 bytes. |
| 134 | + assertEnsureFreeSpaceCalled(ms, dummyBlock, 340L) |
| 135 | + assert(mm.storageMemoryUsed === 560L) |
| 136 | + when(ms.currentUnrollMemory).thenReturn(560L) |
| 137 | + assert(mm.acquireUnrollMemory(dummyBlock, 800L, dummyBlocks) === 0L) |
| 138 | + assert(mm.storageMemoryUsed === 560L) |
| 139 | + // We already have 560 bytes > the max unroll space of 400 bytes, so no bytes are freed |
| 140 | + assertEnsureFreeSpaceCalled(ms, dummyBlock, 0L) |
| 141 | + // Release beyond what was acquired |
| 142 | + mm.releaseUnrollMemory(maxStorageMem) |
| 143 | + assert(mm.storageMemoryUsed === 0L) |
| 144 | + } |
| 145 | + |
| 146 | + /** |
| 147 | + * Make a [[StaticMemoryManager]] and a [[MemoryStore]] with limited class dependencies. |
| 148 | + */ |
| 149 | + private def makeThings( |
| 150 | + maxExecutionMem: Long, |
| 151 | + maxStorageMem: Long): (StaticMemoryManager, MemoryStore) = { |
| 152 | + val mm = new StaticMemoryManager(conf, maxExecutionMem, maxStorageMem) |
| 153 | + val ms = mock(classOf[MemoryStore]) |
| 154 | + mm.setMemoryStore(ms) |
| 155 | + (mm, ms) |
| 156 | + } |
| 157 | + |
| 158 | + /** |
| 159 | + * Assert that [[MemoryStore.ensureFreeSpace]] is called with the given parameters. |
| 160 | + */ |
| 161 | + private def assertEnsureFreeSpaceCalled( |
| 162 | + ms: MemoryStore, |
| 163 | + blockId: BlockId, |
| 164 | + numBytes: Long): Unit = { |
| 165 | + verify(ms).ensureFreeSpace(meq(blockId), meq(numBytes: java.lang.Long), any()) |
| 166 | + reset(ms) |
| 167 | + } |
| 168 | + |
| 169 | +} |
0 commit comments