@@ -70,15 +70,28 @@ private[memory] class ExecutionMemoryPool(
70
70
* active tasks) before it is forced to spill. This can happen if the number of tasks increase
71
71
* but an older task had a lot of memory already.
72
72
*
73
+ * @param numBytes number of bytes to acquire
74
+ * @param taskAttemptId the task attempt acquiring memory
75
+ * @param maybeGrowPool a callback that potentially grows the size of this pool. It takes in
76
+ * one parameter (Long) that represents the desired amount of memory by
77
+ * which this pool should be expanded.
78
+ * @param computeMaxPoolSize a callback that returns the maximum allowable size of this pool
79
+ * at this given moment. This is not a field because the max pool
80
+ * size is variable in certain cases. For instance, in unified
81
+ * memory management, the execution pool can be expanded by evicting
82
+ * cached blocks, thereby shrinking the storage pool.
83
+ *
73
84
* @return the number of bytes granted to the task.
74
85
*/
75
- def acquireMemory (
86
+ private [memory] def acquireMemory (
76
87
numBytes : Long ,
77
88
taskAttemptId : Long ,
78
- maybeResizePool : Long => Unit = (_ : Long ) => Unit ,
79
- computeDaviesThingMax : () => Long = () => poolSize): Long = lock.synchronized {
89
+ maybeGrowPool : Long => Unit = (additionalSpaceNeeded : Long ) => Unit ,
90
+ computeMaxPoolSize : () => Long = () => poolSize): Long = lock.synchronized {
80
91
assert(numBytes > 0 , s " invalid number of bytes requested: $numBytes" )
81
92
93
+ // TODO: clean up this clunky method signature
94
+
82
95
// Add this task to the taskMemory map just so we can keep an accurate count of the number
83
96
// of active tasks, to let other tasks ramp down their memory in calls to `acquireMemory`
84
97
if (! memoryForTask.contains(taskAttemptId)) {
@@ -95,23 +108,30 @@ private[memory] class ExecutionMemoryPool(
95
108
val numActiveTasks = memoryForTask.keys.size
96
109
val curMem = memoryForTask(taskAttemptId)
97
110
98
- // TODO: explain me
99
- maybeResizePool(numBytes - memoryFree)
100
-
101
- // TODO: explain me
102
- val daviesThingMax = computeDaviesThingMax()
103
-
104
- // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks;
105
- // don't let it be negative
106
- val maxToGrant = math.min(numBytes, math.max(0 , (daviesThingMax / numActiveTasks) - curMem))
111
+ // In every iteration of this loop, we should first try to reclaim any borrowed execution
112
+ // space from storage. This is necessary because of the potential race condition where new
113
+ // storage blocks may steal the free execution memory that this task was waiting for.
114
+ maybeGrowPool(numBytes - memoryFree)
115
+
116
+ // Maximum size the pool would have after potentially growing the pool.
117
+ // This is used to compute the upper bound of how much memory each task can occupy. This
118
+ // must take into account potential free memory as well as the amount this pool currently
119
+ // occupies. Otherwise, we may run into SPARK-12155 where, in unified memory management,
120
+ // we did not take into account space that could have been freed by evicting cached blocks.
121
+ val maxPoolSize = computeMaxPoolSize()
122
+ val maxMemoryPerTask = maxPoolSize / numActiveTasks
123
+ val minMemoryPerTask = poolSize / (2 * numActiveTasks)
124
+
125
+ // How much we can grant this task; keep its share within 0 <= X <= 1 / numActiveTasks
126
+ val maxToGrant = math.min(numBytes, math.max(0 , maxMemoryPerTask - curMem))
107
127
// Only give it as much memory as is free, which might be none if it reached 1 / numTasks
108
128
val toGrant = math.min(maxToGrant, memoryFree)
109
129
110
- if (curMem < poolSize / ( 2 * numActiveTasks) ) {
130
+ if (curMem < minMemoryPerTask ) {
111
131
// We want to let each task get at least 1 / (2 * numActiveTasks) before blocking;
112
132
// if we can't give it this much now, wait for other tasks to free up memory
113
133
// (this happens if older tasks allocated lots of memory before N grew)
114
- if (memoryFree >= math.min(maxToGrant, poolSize / ( 2 * numActiveTasks) - curMem )) {
134
+ if (memoryFree >= math.min(maxToGrant, poolSize / minMemoryPerTask )) {
115
135
memoryForTask(taskAttemptId) += toGrant
116
136
return toGrant
117
137
} else {
0 commit comments