1
+ /*
2
+ * Licensed to the Apache Software Foundation (ASF) under one or more
3
+ * contributor license agreements. See the NOTICE file distributed with
4
+ * this work for additional information regarding copyright ownership.
5
+ * The ASF licenses this file to You under the Apache License, Version 2.0
6
+ * (the "License"); you may not use this file except in compliance with
7
+ * the License. You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ package org .apache .spark .util .collection .unsafe .sort ;
19
+
20
+ import com .google .common .annotations .VisibleForTesting ;
21
+ import org .apache .spark .SparkConf ;
22
+ import org .apache .spark .TaskContext ;
23
+ import org .apache .spark .executor .ShuffleWriteMetrics ;
24
+ import org .apache .spark .shuffle .ShuffleMemoryManager ;
25
+ import org .apache .spark .storage .BlockManager ;
26
+ import org .apache .spark .unsafe .PlatformDependent ;
27
+ import org .apache .spark .unsafe .memory .MemoryBlock ;
28
+ import org .apache .spark .unsafe .memory .TaskMemoryManager ;
29
+ import org .slf4j .Logger ;
30
+ import org .slf4j .LoggerFactory ;
31
+
32
+ import java .io .IOException ;
33
+ import java .util .Iterator ;
34
+ import java .util .LinkedList ;
35
+
36
+ /**
37
+ * External sorter based on {@link UnsafeInMemorySorter}.
38
+ */
39
+ public final class UnsafeExternalSorter {
40
+
41
+ private final Logger logger = LoggerFactory .getLogger (UnsafeExternalSorter .class );
42
+
43
+ private static final int PAGE_SIZE = 1024 * 1024 ; // TODO: tune this
44
+
45
+ private final PrefixComparator prefixComparator ;
46
+ private final RecordComparator recordComparator ;
47
+ private final int initialSize ;
48
+ private int numSpills = 0 ;
49
+ private UnsafeInMemorySorter sorter ;
50
+
51
+ private final TaskMemoryManager memoryManager ;
52
+ private final ShuffleMemoryManager shuffleMemoryManager ;
53
+ private final BlockManager blockManager ;
54
+ private final TaskContext taskContext ;
55
+ private final LinkedList <MemoryBlock > allocatedPages = new LinkedList <MemoryBlock >();
56
+ private final boolean spillingEnabled ;
57
+ private final int fileBufferSize ;
58
+ private ShuffleWriteMetrics writeMetrics ;
59
+
60
+
61
+ private MemoryBlock currentPage = null ;
62
+ private long currentPagePosition = -1 ;
63
+
64
+ private final LinkedList <UnsafeSorterSpillWriter > spillWriters =
65
+ new LinkedList <UnsafeSorterSpillWriter >();
66
+
67
+ public UnsafeExternalSorter (
68
+ TaskMemoryManager memoryManager ,
69
+ ShuffleMemoryManager shuffleMemoryManager ,
70
+ BlockManager blockManager ,
71
+ TaskContext taskContext ,
72
+ RecordComparator recordComparator ,
73
+ PrefixComparator prefixComparator ,
74
+ int initialSize ,
75
+ SparkConf conf ) throws IOException {
76
+ this .memoryManager = memoryManager ;
77
+ this .shuffleMemoryManager = shuffleMemoryManager ;
78
+ this .blockManager = blockManager ;
79
+ this .taskContext = taskContext ;
80
+ this .recordComparator = recordComparator ;
81
+ this .prefixComparator = prefixComparator ;
82
+ this .initialSize = initialSize ;
83
+ this .spillingEnabled = conf .getBoolean ("spark.shuffle.spill" , true );
84
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
85
+ this .fileBufferSize = (int ) conf .getSizeAsKb ("spark.shuffle.file.buffer" , "32k" ) * 1024 ;
86
+ openSorter ();
87
+ }
88
+
89
+ // TODO: metrics tracking + integration with shuffle write metrics
90
+
91
+ private void openSorter () throws IOException {
92
+ this .writeMetrics = new ShuffleWriteMetrics ();
93
+ // TODO: connect write metrics to task metrics?
94
+ // TODO: move this sizing calculation logic into a static method of sorter:
95
+ final long memoryRequested = initialSize * 8L * 2 ;
96
+ if (spillingEnabled ) {
97
+ final long memoryAcquired = shuffleMemoryManager .tryToAcquire (memoryRequested );
98
+ if (memoryAcquired != memoryRequested ) {
99
+ shuffleMemoryManager .release (memoryAcquired );
100
+ throw new IOException ("Could not acquire memory!" );
101
+ }
102
+ }
103
+
104
+ this .sorter =
105
+ new UnsafeInMemorySorter (memoryManager , recordComparator , prefixComparator , initialSize );
106
+ }
107
+
108
+ @ VisibleForTesting
109
+ public void spill () throws IOException {
110
+ final UnsafeSorterSpillWriter spillWriter =
111
+ new UnsafeSorterSpillWriter (blockManager , fileBufferSize , writeMetrics );
112
+ spillWriters .add (spillWriter );
113
+ final UnsafeSorterIterator sortedRecords = sorter .getSortedIterator ();
114
+ while (sortedRecords .hasNext ()) {
115
+ sortedRecords .loadNext ();
116
+ final Object baseObject = sortedRecords .getBaseObject ();
117
+ final long baseOffset = sortedRecords .getBaseOffset ();
118
+ // TODO: this assumption that the first long holds a length is not enforced via our interfaces
119
+ // We need to either always store this via the write path (e.g. not require the caller to do
120
+ // it), or provide interfaces / hooks for customizing the physical storage format etc.
121
+ final int recordLength = (int ) PlatformDependent .UNSAFE .getLong (baseObject , baseOffset );
122
+ spillWriter .write (baseObject , baseOffset , recordLength , sortedRecords .getKeyPrefix ());
123
+ }
124
+ spillWriter .close ();
125
+ final long sorterMemoryUsage = sorter .getMemoryUsage ();
126
+ sorter = null ;
127
+ shuffleMemoryManager .release (sorterMemoryUsage );
128
+ final long spillSize = freeMemory ();
129
+ taskContext .taskMetrics ().incMemoryBytesSpilled (spillSize );
130
+ taskContext .taskMetrics ().incDiskBytesSpilled (spillWriter .numberOfSpilledBytes ());
131
+ numSpills ++;
132
+ final long threadId = Thread .currentThread ().getId ();
133
+ // TODO: messy; log _before_ spill
134
+ logger .info ("Thread " + threadId + " spilling in-memory map of " +
135
+ org .apache .spark .util .Utils .bytesToString (spillSize ) + " to disk (" +
136
+ (numSpills + ((numSpills > 1 ) ? " times" : " time" )) + " so far)" );
137
+ openSorter ();
138
+ }
139
+
140
+ private long freeMemory () {
141
+ long memoryFreed = 0 ;
142
+ final Iterator <MemoryBlock > iter = allocatedPages .iterator ();
143
+ while (iter .hasNext ()) {
144
+ memoryManager .freePage (iter .next ());
145
+ shuffleMemoryManager .release (PAGE_SIZE );
146
+ memoryFreed += PAGE_SIZE ;
147
+ iter .remove ();
148
+ }
149
+ currentPage = null ;
150
+ currentPagePosition = -1 ;
151
+ return memoryFreed ;
152
+ }
153
+
154
+ private void ensureSpaceInDataPage (int requiredSpace ) throws Exception {
155
+ // TODO: merge these steps to first calculate total memory requirements for this insert,
156
+ // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
157
+ // data page.
158
+ if (!sorter .hasSpaceForAnotherRecord () && spillingEnabled ) {
159
+ final long oldSortBufferMemoryUsage = sorter .getMemoryUsage ();
160
+ final long memoryToGrowSortBuffer = oldSortBufferMemoryUsage * 2 ;
161
+ final long memoryAcquired = shuffleMemoryManager .tryToAcquire (memoryToGrowSortBuffer );
162
+ if (memoryAcquired < memoryToGrowSortBuffer ) {
163
+ shuffleMemoryManager .release (memoryAcquired );
164
+ spill ();
165
+ } else {
166
+ sorter .expandSortBuffer ();
167
+ shuffleMemoryManager .release (oldSortBufferMemoryUsage );
168
+ }
169
+ }
170
+
171
+ final long spaceInCurrentPage ;
172
+ if (currentPage != null ) {
173
+ spaceInCurrentPage = PAGE_SIZE - (currentPagePosition - currentPage .getBaseOffset ());
174
+ } else {
175
+ spaceInCurrentPage = 0 ;
176
+ }
177
+ if (requiredSpace > PAGE_SIZE ) {
178
+ // TODO: throw a more specific exception?
179
+ throw new Exception ("Required space " + requiredSpace + " is greater than page size (" +
180
+ PAGE_SIZE + ")" );
181
+ } else if (requiredSpace > spaceInCurrentPage ) {
182
+ if (spillingEnabled ) {
183
+ final long memoryAcquired = shuffleMemoryManager .tryToAcquire (PAGE_SIZE );
184
+ if (memoryAcquired < PAGE_SIZE ) {
185
+ shuffleMemoryManager .release (memoryAcquired );
186
+ spill ();
187
+ final long memoryAcquiredAfterSpill = shuffleMemoryManager .tryToAcquire (PAGE_SIZE );
188
+ if (memoryAcquiredAfterSpill != PAGE_SIZE ) {
189
+ shuffleMemoryManager .release (memoryAcquiredAfterSpill );
190
+ throw new Exception ("Can't allocate memory!" );
191
+ }
192
+ }
193
+ }
194
+ currentPage = memoryManager .allocatePage (PAGE_SIZE );
195
+ currentPagePosition = currentPage .getBaseOffset ();
196
+ allocatedPages .add (currentPage );
197
+ logger .info ("Acquired new page! " + allocatedPages .size () * PAGE_SIZE );
198
+ }
199
+ }
200
+
201
+ public void insertRecord (
202
+ Object recordBaseObject ,
203
+ long recordBaseOffset ,
204
+ int lengthInBytes ,
205
+ long prefix ) throws Exception {
206
+ // Need 4 bytes to store the record length.
207
+ ensureSpaceInDataPage (lengthInBytes + 4 );
208
+
209
+ final long recordAddress =
210
+ memoryManager .encodePageNumberAndOffset (currentPage , currentPagePosition );
211
+ final Object dataPageBaseObject = currentPage .getBaseObject ();
212
+ PlatformDependent .UNSAFE .putInt (dataPageBaseObject , currentPagePosition , lengthInBytes );
213
+ currentPagePosition += 4 ;
214
+ PlatformDependent .copyMemory (
215
+ recordBaseObject ,
216
+ recordBaseOffset ,
217
+ dataPageBaseObject ,
218
+ currentPagePosition ,
219
+ lengthInBytes );
220
+ currentPagePosition += lengthInBytes ;
221
+
222
+ sorter .insertRecord (recordAddress , prefix );
223
+ }
224
+
225
+ public UnsafeSorterIterator getSortedIterator () throws IOException {
226
+ final UnsafeSorterSpillMerger spillMerger =
227
+ new UnsafeSorterSpillMerger (recordComparator , prefixComparator );
228
+ for (UnsafeSorterSpillWriter spillWriter : spillWriters ) {
229
+ spillMerger .addSpill (spillWriter .getReader (blockManager ));
230
+ }
231
+ spillWriters .clear ();
232
+ spillMerger .addSpill (sorter .getSortedIterator ());
233
+ return spillMerger .getSortedIterator ();
234
+ }
235
+ }
0 commit comments