
Commit 211331c

WIP: in-memory columnar compression support
1 parent 85cc59b commit 211331c

11 files changed: +487 −25

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala

Lines changed: 9 additions & 8 deletions
@@ -41,28 +41,29 @@ private[sql] trait ColumnAccessor {
 }
 
 private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](
-    buffer: ByteBuffer, columnType: ColumnType[T, JvmType])
+    protected val buffer: ByteBuffer,
+    protected val columnType: ColumnType[T, JvmType])
   extends ColumnAccessor {
 
   protected def initialize() {}
 
   def hasNext = buffer.hasRemaining
 
   def extractTo(row: MutableRow, ordinal: Int) {
-    columnType.setField(row, ordinal, columnType.extract(buffer))
+    columnType.setField(row, ordinal, extractSingle(buffer))
   }
 
+  def extractSingle(buffer: ByteBuffer) = columnType.extract(buffer)
+
   protected def underlyingBuffer = buffer
 }
 
 private[sql] abstract class NativeColumnAccessor[T <: NativeType](
     buffer: ByteBuffer,
-    val columnType: NativeColumnType[T])
-  extends BasicColumnAccessor[T, T#JvmType](buffer, columnType)
-  with NullableColumnAccessor {
-
-  type JvmType = T#JvmType
-}
+    columnType: NativeColumnType[T])
+  extends BasicColumnAccessor(buffer, columnType)
+  with NullableColumnAccessor
+  with CompressedColumnAccessor[T]
 
 private[sql] class BooleanColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, BOOLEAN)
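
The change above is the classic stackable-trait refactoring: extraction is funneled through the new `extractSingle` hook so that a compression-aware trait can intercept it with `abstract override`. A rough standalone sketch of that mechanism, using toy names and plain `Int` values instead of Spark's column types:

// Minimal sketch of the stackable-trait pattern (hypothetical names).
trait Accessor {
  def extractSingle(): Int
}

class BasicAccessor(values: Iterator[Int]) extends Accessor {
  def extractSingle(): Int = values.next()
}

// `abstract override` lets the trait wrap whatever implementation it is mixed over.
trait DecodingAccessor extends Accessor {
  abstract override def extractSingle(): Int = {
    val raw = super.extractSingle()
    raw - 1 // stand-in for "decode one value"
  }
}

object StackableTraitDemo extends App {
  val accessor = new BasicAccessor(Iterator(2, 3, 4)) with DecodingAccessor
  println(accessor.extractSingle()) // 1
  println(accessor.extractSingle()) // 2
}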

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala

Lines changed: 16 additions & 4 deletions
@@ -29,25 +29,36 @@ private[sql] trait ColumnBuilder {
    */
   def initialize(initialSize: Int, columnName: String = "")
 
+  /**
+   * Gathers statistics information from `row(ordinal)`.
+   */
   def gatherStats(row: Row, ordinal: Int) {}
 
+  /**
+   * Appends `row(ordinal)` to the column builder.
+   */
   def appendFrom(row: Row, ordinal: Int)
 
+  /**
+   * Returns the final columnar byte buffer.
+   */
   def build(): ByteBuffer
 }
 
 private[sql] abstract class BasicColumnBuilder[T <: DataType, JvmType](
     val columnType: ColumnType[T, JvmType])
   extends ColumnBuilder {
 
-  private var columnName: String = _
+  protected var columnName: String = _
 
   protected var buffer: ByteBuffer = _
 
   override def initialize(initialSize: Int, columnName: String = "") = {
     val size = if (initialSize == 0) DEFAULT_INITIAL_BUFFER_SIZE else initialSize
     this.columnName = columnName
-    buffer = ByteBuffer.allocate(4 + 4 + size * columnType.defaultSize)
+
+    // Reserves 4 bytes for column type ID
+    buffer = ByteBuffer.allocate(4 + size * columnType.defaultSize)
     buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId)
   }
 
@@ -66,8 +77,9 @@ private[sql] abstract class BasicColumnBuilder[T <: DataType, JvmType](
 private[sql] abstract class NativeColumnBuilder[T <: NativeType](
     protected val columnStats: ColumnStats[T],
     columnType: NativeColumnType[T])
-  extends BasicColumnBuilder[T, T#JvmType](columnType)
-  with NullableColumnBuilder {
+  extends BasicColumnBuilder(columnType)
+  with NullableColumnBuilder
+  with CompressedColumnBuilder[T] {
 
   override def gatherStats(row: Row, ordinal: Int) {
     columnStats.gatherStats(row, ordinal)
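
The builder protocol is initialize → gatherStats/appendFrom per row → build, with the first 4 bytes of the buffer reserved for the column type ID. A minimal toy illustration of that buffer layout, assuming fixed-width `Int` values and none of the real builder's buffer growth or null handling:

import java.nio.{ByteBuffer, ByteOrder}

// Toy Int column builder: a 4-byte type ID header followed by fixed-width values.
class ToyIntColumnBuilder(initialSize: Int, typeId: Int) {
  private val buffer = ByteBuffer
    .allocate(4 + initialSize * 4) // 4 bytes reserved for the type ID
    .order(ByteOrder.nativeOrder())
  buffer.putInt(typeId)

  def append(value: Int): Unit = buffer.putInt(value)

  def build(): ByteBuffer = {
    buffer.limit(buffer.position())
    buffer.rewind()
    buffer
  }
}

object ToyBuilderDemo extends App {
  val builder = new ToyIntColumnBuilder(initialSize = 3, typeId = 0)
  Seq(10, 20, 30).foreach(builder.append)
  val built = builder.build()
  println(built.getInt()) // type ID: 0
  println(built.getInt()) // 10
}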
sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala

Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.types._
+
+private[sql] sealed abstract class ColumnStats[T <: NativeType] extends Serializable {
+  type JvmType = T#JvmType
+
+  protected var (_lower, _upper) = initialBounds
+
+  protected val ordering: Ordering[JvmType]
+
+  protected def columnType: NativeColumnType[T]
+
+  /**
+   * Closed lower bound of this column.
+   */
+  def lowerBound = _lower
+
+  /**
+   * Closed upper bound of this column.
+   */
+  def upperBound = _upper
+
+  /**
+   * Initial values for the closed lower/upper bounds, in the format of `(lower, upper)`.
+   */
+  protected def initialBounds: (JvmType, JvmType)
+
+  /**
+   * Gathers statistics information from `row(ordinal)`.
+   */
+  @inline def gatherStats(row: Row, ordinal: Int) {
+    val field = columnType.getField(row, ordinal)
+    if (ordering.gt(field, upperBound)) _upper = field
+    if (ordering.lt(field, lowerBound)) _lower = field
+  }
+
+  /**
+   * Returns `true` if `lower <= row(ordinal) <= upper`.
+   */
+  @inline def contains(row: Row, ordinal: Int) = {
+    val field = columnType.getField(row, ordinal)
+    ordering.lteq(lowerBound, field) && ordering.lteq(field, upperBound)
+  }
+
+  /**
+   * Returns `true` if `row(ordinal) < upper` holds.
+   */
+  @inline def isAbove(row: Row, ordinal: Int) = {
+    val field = columnType.getField(row, ordinal)
+    ordering.lt(field, upperBound)
+  }
+
+  /**
+   * Returns `true` if `lower < row(ordinal)` holds.
+   */
+  @inline def isBelow(row: Row, ordinal: Int) = {
+    val field = columnType.getField(row, ordinal)
+    ordering.lt(lowerBound, field)
+  }
+
+  /**
+   * Returns `true` if `row(ordinal) <= upper` holds.
+   */
+  @inline def isAtOrAbove(row: Row, ordinal: Int) = {
+    contains(row, ordinal) || isAbove(row, ordinal)
+  }
+
+  /**
+   * Returns `true` if `lower <= row(ordinal)` holds.
+   */
+  @inline def isAtOrBelow(row: Row, ordinal: Int) = {
+    contains(row, ordinal) || isBelow(row, ordinal)
+  }
+}
+
+private[sql] abstract class BasicColumnStats[T <: NativeType](
+    protected val columnType: NativeColumnType[T])
+  extends ColumnStats[T]
+
+private[sql] class BooleanColumnStats extends BasicColumnStats(BOOLEAN) {
+  override protected val ordering = implicitly[Ordering[JvmType]]
+  override protected def initialBounds = (true, false)
+}
+
+private[sql] class ByteColumnStats extends BasicColumnStats(BYTE) {
+  override protected val ordering = implicitly[Ordering[JvmType]]
+  override protected def initialBounds = (Byte.MaxValue, Byte.MinValue)
+}
+
+private[sql] class ShortColumnStats extends BasicColumnStats(SHORT) {
+  override protected val ordering = implicitly[Ordering[JvmType]]
+  override protected def initialBounds = (Short.MaxValue, Short.MinValue)
+}
+
+private[sql] class LongColumnStats extends BasicColumnStats(LONG) {
+  override protected val ordering = implicitly[Ordering[JvmType]]
+  override protected def initialBounds = (Long.MaxValue, Long.MinValue)
+}
+
+private[sql] class DoubleColumnStats extends BasicColumnStats(DOUBLE) {
+  override protected val ordering = implicitly[Ordering[JvmType]]
+  override protected def initialBounds = (Double.MaxValue, Double.MinValue)
+}
+
+private[sql] class FloatColumnStats extends BasicColumnStats(FLOAT) {
+  override protected val ordering = implicitly[Ordering[JvmType]]
+  override protected def initialBounds = (Float.MaxValue, Float.MinValue)
+}
+
+private[sql] class IntColumnStats extends BasicColumnStats(INT) {
+  private object OrderedState extends Enumeration {
+    val Uninitialized, Initialized, Ascending, Descending, Unordered = Value
+  }
+
+  import OrderedState._
+
+  private var orderedState = Uninitialized
+  private var lastValue: Int = _
+  private var _maxDelta: Int = _
+
+  def isAscending = orderedState != Descending && orderedState != Unordered
+  def isDescending = orderedState != Ascending && orderedState != Unordered
+  def isOrdered = isAscending || isDescending
+  def maxDelta = _maxDelta
+
+  override protected val ordering = implicitly[Ordering[JvmType]]
+  override protected def initialBounds = (Int.MaxValue, Int.MinValue)
+
+  override def gatherStats(row: Row, ordinal: Int) = {
+    val field = columnType.getField(row, ordinal)
+
+    if (field > upperBound) _upper = field
+    if (field < lowerBound) _lower = field
+
+    orderedState = orderedState match {
+      case Uninitialized =>
+        lastValue = field
+        Initialized
+
+      case Initialized =>
+        // If all the integers in the column are the same, ordered state is set to Ascending.
+        // TODO (lian) Confirm whether this is the standard behaviour.
+        val nextState = if (field >= lastValue) Ascending else Descending
+        _maxDelta = math.abs(field - lastValue)
+        lastValue = field
+        nextState
+
+      case Ascending if field < lastValue =>
+        Unordered
+
+      case Descending if field > lastValue =>
+        Unordered
+
+      case state @ (Ascending | Descending) =>
+        // Absolute delta, so that descending sequences are measured correctly too
+        _maxDelta = _maxDelta.max(math.abs(field - lastValue))
+        lastValue = field
+        state
+    }
+  }
+}
+
+private[sql] class StringColumnStats extends BasicColumnStats(STRING) {
+  override protected val ordering = implicitly[Ordering[JvmType]]
+  override protected def initialBounds = (null, null)
+
+  override def contains(row: Row, ordinal: Int) = {
+    !(upperBound eq null) && super.contains(row, ordinal)
+  }
+
+  override def isAbove(row: Row, ordinal: Int) = {
+    !(upperBound eq null) && super.isAbove(row, ordinal)
+  }
+
+  override def isBelow(row: Row, ordinal: Int) = {
+    !(lowerBound eq null) && super.isBelow(row, ordinal)
+  }
+}
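
Each `ColumnStats` keeps closed lower/upper bounds that are updated as rows stream through `gatherStats`, so predicates can later be checked against a whole column batch at once. A stripped-down sketch over plain `Int`s (the real class pulls fields out of a `Row` via its column type):

// Simplified sketch of the min/max bound tracking above.
class ToyIntStats {
  private var lower = Int.MaxValue // closed lower bound
  private var upper = Int.MinValue // closed upper bound

  def gather(value: Int): Unit = {
    if (value > upper) upper = value
    if (value < lower) lower = value
  }

  // True iff `lower <= value <= upper`, mirroring `contains`
  def contains(value: Int): Boolean = lower <= value && value <= upper
}

object ToyStatsDemo extends App {
  val stats = new ToyIntStats
  Seq(3, 1, 4, 1, 5).foreach(stats.gather)
  println(stats.contains(2)) // true: 1 <= 2 <= 5
  println(stats.contains(9)) // false
}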
sql/core/src/main/scala/org/apache/spark/sql/columnar/CompressedColumnAccessor.scala

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar.CompressionAlgorithm.NoopDecoder
+import org.apache.spark.sql.columnar.CompressionType._
+
+private[sql] trait CompressedColumnAccessor[T <: NativeType] extends ColumnAccessor {
+  this: BasicColumnAccessor[T, T#JvmType] =>
+
+  private var decoder: Iterator[T#JvmType] = _
+
+  abstract override protected def initialize() = {
+    super.initialize()
+
+    decoder = underlyingBuffer.getInt() match {
+      case id if id == Noop.id => new NoopDecoder[T](buffer, columnType)
+      case _ => throw new UnsupportedOperationException()
+    }
+  }
+
+  abstract override def extractSingle(buffer: ByteBuffer) = decoder.next()
+}
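
The accessor reads a scheme ID from the head of the compressed buffer and wires up a matching decoder, which is just an `Iterator` over decoded values. A self-contained sketch of a no-op style decoder (names and layout here are illustrative, not the real `NoopDecoder`):

import java.nio.{ByteBuffer, ByteOrder}

// A "decoder" that applies no decompression: it simply iterates the raw
// 4-byte values remaining in the buffer.
class ToyNoopIntDecoder(buffer: ByteBuffer) extends Iterator[Int] {
  override def hasNext: Boolean = buffer.hasRemaining
  override def next(): Int = buffer.getInt()
}

object ToyDecoderDemo extends App {
  val buffer = ByteBuffer.allocate(12).order(ByteOrder.nativeOrder())
  Seq(7, 8, 9).foreach(v => buffer.putInt(v))
  buffer.rewind()

  val decoder = new ToyNoopIntDecoder(buffer)
  while (decoder.hasNext) println(decoder.next()) // 7, 8, 9
}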
sql/core/src/main/scala/org/apache/spark/sql/columnar/CompressedColumnBuilder.scala

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import org.apache.spark.sql.{Logging, Row}
+import org.apache.spark.sql.catalyst.types.NativeType
+
+private[sql] trait CompressedColumnBuilder[T <: NativeType] extends ColumnBuilder with Logging {
+  this: BasicColumnBuilder[T, T#JvmType] =>
+
+  val compressionSchemes = Seq(new CompressionAlgorithm.Noop)
+    .filter(_.supports(columnType))
+
+  def isWorthCompressing(scheme: CompressionAlgorithm) = {
+    scheme.compressionRatio < 0.8
+  }
+
+  abstract override def gatherStats(row: Row, ordinal: Int) {
+    val field = columnType.getField(row, ordinal)
+    compressionSchemes.foreach(_.gatherCompressibilityStats(field, columnType))
+
+    super.gatherStats(row, ordinal)
+  }
+
+  abstract override def build() = {
+    val rawBuffer = super.build()
+
+    if (compressionSchemes.isEmpty) {
+      logger.info(s"Compression scheme chosen for [$columnName] is ${CompressionType.Noop}")
+      new CompressionAlgorithm.Noop().compress(rawBuffer, columnType)
+    } else {
+      val candidateScheme = compressionSchemes.minBy(_.compressionRatio)
+
+      logger.info(
+        s"Compression scheme chosen for [$columnName] is ${candidateScheme.compressionType}, " +
+        s"compression ratio ${candidateScheme.compressionRatio}")
+
+      if (isWorthCompressing(candidateScheme)) {
+        candidateScheme.compress(rawBuffer, columnType)
+      } else {
+        new CompressionAlgorithm.Noop().compress(rawBuffer, columnType)
+      }
+    }
+  }
+}
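
The selection policy in `build()` is: estimate a compression ratio per candidate scheme, take the minimum, and only compress when the estimated ratio beats the 0.8 threshold from `isWorthCompressing`, otherwise fall back to the no-op scheme. A toy rendering of just that policy, with a made-up `Scheme` type in place of `CompressionAlgorithm`:

// Pick the candidate with the smallest estimated ratio; compress only if it
// actually saves enough (ratio < 0.8), else keep the data uncompressed.
case class Scheme(name: String, compressionRatio: Double)

object SchemeSelectionDemo extends App {
  val noop = Scheme("Noop", 1.0)
  val candidates = Seq(Scheme("RunLength", 0.4), Scheme("Dictionary", 0.7))

  def isWorthCompressing(s: Scheme): Boolean = s.compressionRatio < 0.8

  val chosen =
    if (candidates.isEmpty) noop
    else {
      val best = candidates.minBy(_.compressionRatio)
      if (isWorthCompressing(best)) best else noop
    }

  println(s"Chosen scheme: ${chosen.name}") // RunLength
}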
