Skip to content

Commit c6aaccc

Browse files
committed
Merge branch 'master' into SPARK-29587
2 parents e3857e8 + 456cfe6 commit c6aaccc

File tree

8 files changed

+28
-53
lines changed

8 files changed

+28
-53
lines changed

core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -902,7 +902,7 @@ private[spark] object MapOutputTracker extends Logging {
902902
// the contents don't have to be copied to the new buffer.
903903
val out = new ApacheByteArrayOutputStream()
904904
out.write(DIRECT)
905-
val codec = CompressionCodec.createCodec(conf, "zstd")
905+
val codec = CompressionCodec.createCodec(conf, conf.get(MAP_STATUS_COMPRESSION_CODEC))
906906
val objOut = new ObjectOutputStream(codec.compressedOutputStream(out))
907907
Utils.tryWithSafeFinally {
908908
// Since statuses can be modified in parallel, sync on it
@@ -939,7 +939,7 @@ private[spark] object MapOutputTracker extends Logging {
939939
assert (bytes.length > 0)
940940

941941
def deserializeObject(arr: Array[Byte], off: Int, len: Int): AnyRef = {
942-
val codec = CompressionCodec.createCodec(conf, "zstd")
942+
val codec = CompressionCodec.createCodec(conf, conf.get(MAP_STATUS_COMPRESSION_CODEC))
943943
// The ZStd codec is wrapped in a `BufferedInputStream` which avoids the excessive overhead
944944
// of JNI calls while trying to decompress small amounts of data for each element
945945
// of `MapStatuses`

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,15 @@ package object config {
10161016
.booleanConf
10171017
.createWithDefault(true)
10181018

1019+
private[spark] val MAP_STATUS_COMPRESSION_CODEC =
1020+
ConfigBuilder("spark.shuffle.mapStatus.compression.codec")
1021+
.internal()
1022+
.doc("The codec used to compress MapStatus, which is generated by ShuffleMapTask. " +
1023+
"By default, Spark provides four codecs: lz4, lzf, snappy, and zstd. You can also " +
1024+
"use fully qualified class names to specify the codec.")
1025+
.stringConf
1026+
.createWithDefault("zstd")
1027+
10191028
private[spark] val SHUFFLE_SPILL_INITIAL_MEM_THRESHOLD =
10201029
ConfigBuilder("spark.shuffle.spill.initialMemoryThreshold")
10211030
.internal()

pom.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3013,6 +3013,10 @@
30133013

30143014
<profile>
30153015
<id>scala-2.13</id>
3016+
<properties>
3017+
<scala.version>2.13.1</scala.version>
3018+
<scala.binary.version>2.13</scala.binary.version>
3019+
</properties>
30163020
<dependencyManagement>
30173021
<dependencies>
30183022
<dependency>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
447447
}
448448

449449
private[this] def decimalToTimestamp(d: Decimal): Long = {
450-
(d.toBigDecimal * MICROS_PER_SECOND).longValue()
450+
(d.toBigDecimal * MICROS_PER_SECOND).longValue
451451
}
452452
private[this] def doubleToTimestamp(d: Double): Any = {
453453
if (d.isNaN || d.isInfinite) null else (d * MICROS_PER_SECOND).toLong
@@ -632,7 +632,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
632632
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
633633
case StringType =>
634634
buildCast[UTF8String](_, s => try {
635-
changePrecision(Decimal(new JavaBigDecimal(s.toString)), target)
635+
// According to the benchmark test, `s.toString.trim` is much faster than `s.trim.toString`.
636+
// Please refer to https://github.com/apache/spark/pull/26640
637+
changePrecision(Decimal(new JavaBigDecimal(s.toString.trim)), target)
636638
} catch {
637639
case _: NumberFormatException => null
638640
})
@@ -1128,7 +1130,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
11281130
(c, evPrim, evNull) =>
11291131
code"""
11301132
try {
1131-
Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString()));
1133+
Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString().trim()));
11321134
${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)}
11331135
} catch (java.lang.NumberFormatException e) {
11341136
$evNull = true;

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1934,9 +1934,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
19341934
override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) {
19351935
BigDecimal(ctx.getText) match {
19361936
case v if v.isValidInt =>
1937-
Literal(v.intValue())
1937+
Literal(v.intValue)
19381938
case v if v.isValidLong =>
1939-
Literal(v.longValue())
1939+
Literal(v.longValue)
19401940
case v => Literal(v.underlying())
19411941
}
19421942
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ object EstimationUtils {
6363
}
6464
}
6565

66-
def ceil(bigDecimal: BigDecimal): BigInt = bigDecimal.setScale(0, RoundingMode.CEILING).toBigInt()
66+
def ceil(bigDecimal: BigDecimal): BigInt = bigDecimal.setScale(0, RoundingMode.CEILING).toBigInt
6767

6868
/** Get column stats for output attributes. */
6969
def getOutputMap(inputMap: AttributeMap[ColumnStat], output: Seq[Attribute])

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
188188

189189
def toScalaBigInt: BigInt = {
190190
if (decimalVal.ne(null)) {
191-
decimalVal.toBigInt()
191+
decimalVal.toBigInt
192192
} else {
193193
BigInt(toLong)
194194
}
@@ -220,15 +220,15 @@ final class Decimal extends Ordered[Decimal] with Serializable {
220220
}
221221
}
222222

223-
def toDouble: Double = toBigDecimal.doubleValue()
223+
def toDouble: Double = toBigDecimal.doubleValue
224224

225-
def toFloat: Float = toBigDecimal.floatValue()
225+
def toFloat: Float = toBigDecimal.floatValue
226226

227227
def toLong: Long = {
228228
if (decimalVal.eq(null)) {
229229
longVal / POW_10(_scale)
230230
} else {
231-
decimalVal.longValue()
231+
decimalVal.longValue
232232
}
233233
}
234234

sql/core/src/test/resources/sql-tests/results/cast.sql.out

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 48
2+
-- Number of queries: 43
33

44

55
-- !query 0
@@ -350,44 +350,4 @@ select cast('1.0 ' as DEC)
350350
-- !query 42 schema
351351
struct<CAST(1.0 AS DECIMAL(10,0)):decimal(10,0)>
352352
-- !query 42 output
353-
NULL
354-
355-
356-
-- !query 43
357-
select cast(1.0 as real)
358-
-- !query 43 schema
359-
struct<CAST(1.0 AS FLOAT):float>
360-
-- !query 43 output
361-
1.0
362-
363-
364-
-- !query 44
365-
select cast('1' as real)
366-
-- !query 44 schema
367-
struct<CAST(1 AS FLOAT):float>
368-
-- !query 44 output
369-
1.0
370-
371-
372-
-- !query 45
373-
select cast(1.0 as numeric)
374-
-- !query 45 schema
375-
struct<CAST(1.0 AS DECIMAL(10,0)):decimal(10,0)>
376-
-- !query 45 output
377353
1
378-
379-
380-
-- !query 46
381-
select cast(1.0 as numeric(3))
382-
-- !query 46 schema
383-
struct<CAST(1.0 AS DECIMAL(3,0)):decimal(3,0)>
384-
-- !query 46 output
385-
1
386-
387-
388-
-- !query 47
389-
select cast(1.08 as numeric(3,1))
390-
-- !query 47 schema
391-
struct<CAST(1.08 AS DECIMAL(3,1)):decimal(3,1)>
392-
-- !query 47 output
393-
1.1

0 commit comments

Comments
 (0)