Skip to content

Commit 52274f7

Browse files
add support for udf_format_number and length for binary
1 parent affbe32 commit 52274f7

File tree

5 files changed

+238
-36
lines changed

5 files changed

+238
-36
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,12 @@ object FunctionRegistry {
152152
expression[Base64]("base64"),
153153
expression[Encode]("encode"),
154154
expression[Decode]("decode"),
155-
expression[StringInstr]("instr"),
155+
expression[FormatNumber]("format_number"),
156156
expression[Lower]("lcase"),
157157
expression[Lower]("lower"),
158-
expression[StringLength]("length"),
158+
expression[Length]("length"),
159159
expression[Levenshtein]("levenshtein"),
160+
expression[StringInstr]("instr"),
160161
expression[StringLocate]("locate"),
161162
expression[StringLPad]("lpad"),
162163
expression[StringTrimLeft]("ltrim"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 84 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import java.text.DecimalFormat
2021
import java.util.Locale
2122
import java.util.regex.Pattern
2223

23-
import org.apache.commons.lang3.StringUtils
24-
2524
import org.apache.spark.sql.catalyst.InternalRow
2625
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
2726
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -553,17 +552,23 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
553552
}
554553

555554
/**
556-
* A function that return the length of the given string expression.
555+
* A function that returns the length of the given string or binary expression.
557556
*/
558-
case class StringLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
557+
case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes {
559558
override def dataType: DataType = IntegerType
560-
override def inputTypes: Seq[DataType] = Seq(StringType)
559+
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))
561560

562-
protected override def nullSafeEval(string: Any): Any =
563-
string.asInstanceOf[UTF8String].numChars
561+
protected override def nullSafeEval(value: Any): Any = child.dataType match {
562+
case StringType => value.asInstanceOf[UTF8String].numChars
563+
case BinaryType => value.asInstanceOf[Array[Byte]].length
564+
}
564565

565566
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
566-
defineCodeGen(ctx, ev, c => s"($c).numChars()")
567+
child.dataType match {
568+
case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()")
569+
case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length")
570+
case NullType => defineCodeGen(ctx, ev, c => s"-1")
571+
}
567572
}
568573

569574
override def prettyName: String = "length"
@@ -668,3 +673,74 @@ case class Encode(value: Expression, charset: Expression)
668673
}
669674
}
670675

676+
/**
677+
* Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
678+
* and returns the result as a string. If D is 0, the result has no decimal point or
679+
* fractional part.
680+
*/
681+
case class FormatNumber(x: Expression, d: Expression)
682+
extends BinaryExpression with ExpectsInputTypes {
683+
684+
override def left: Expression = x
685+
override def right: Expression = d
686+
override def dataType: DataType = StringType
687+
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
688+
override def foldable: Boolean = x.foldable && d.foldable
689+
override def nullable: Boolean = x.nullable || d.nullable
690+
691+
@transient
692+
private var lastDValue: Int = -100
693+
694+
@transient
695+
private val pattern: StringBuffer = new StringBuffer()
696+
697+
@transient
698+
private val numberFormat: DecimalFormat = new DecimalFormat("")
699+
700+
override def eval(input: InternalRow): Any = {
701+
val xObject = x.eval(input)
702+
if (xObject == null) {
703+
return null
704+
}
705+
706+
val dObject = d.eval(input)
707+
708+
if (dObject == null || dObject.asInstanceOf[Int] < 0) {
709+
throw new IllegalArgumentException(
710+
s"Argument 2 of function FORMAT_NUMBER must be >= 0, but $dObject was found")
711+
}
712+
val dValue = dObject.asInstanceOf[Int]
713+
714+
if (dValue != lastDValue) {
715+
// construct a new DecimalFormat only if a new dValue
716+
pattern.delete(0, pattern.length())
717+
pattern.append("#,###,###,###,###,###,##0")
718+
719+
// decimal place
720+
if (dValue > 0) {
721+
pattern.append(".")
722+
723+
var i = 0
724+
while (i < dValue) {
725+
i += 1
726+
pattern.append("0")
727+
}
728+
}
729+
val dFormat = new DecimalFormat(pattern.toString())
730+
lastDValue = dValue;
731+
numberFormat.applyPattern(dFormat.toPattern())
732+
}
733+
734+
x.dataType match {
735+
case ByteType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Byte]))
736+
case ShortType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Short]))
737+
case FloatType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Float]))
738+
case IntegerType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Int]))
739+
case LongType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Long]))
740+
case DoubleType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Double]))
741+
case _: DecimalType =>
742+
UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Decimal].toJavaBigDecimal))
743+
}
744+
}
745+
}
746+

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
22-
import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType}
22+
import org.apache.spark.sql.types._
2323

2424

2525
class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -216,15 +216,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
216216
}
217217
}
218218

219-
test("length for string") {
220-
val a = 'a.string.at(0)
221-
checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef"))
222-
checkEvaluation(StringLength(a), 5, create_row("abdef"))
223-
checkEvaluation(StringLength(a), 0, create_row(""))
224-
checkEvaluation(StringLength(a), null, create_row(null))
225-
checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
226-
}
227-
228219
test("ascii for string") {
229220
val a = 'a.string.at(0)
230221
checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef"))
@@ -426,4 +417,47 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
426417
checkEvaluation(
427418
StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1)
428419
}
420+
421+
test("length for string / binary") {
422+
val a = 'a.string.at(0)
423+
val b = 'b.binary.at(0)
424+
val bytes = Array[Byte](1, 2, 3, 1, 2)
425+
val string = "abdef"
426+
427+
// scalastyle:off
428+
// Non-ASCII characters are not allowed in the source code, so we disable scalastyle here.
429+
checkEvaluation(Length(Literal("a花花c")), 4, create_row(string))
430+
// scalastyle:on
431+
checkEvaluation(Length(Literal(bytes)), 5, create_row(Array[Byte]()))
432+
433+
checkEvaluation(Length(a), 5, create_row(string))
434+
checkEvaluation(Length(b), 5, create_row(bytes))
435+
436+
checkEvaluation(Length(a), 0, create_row(""))
437+
checkEvaluation(Length(b), 0, create_row(Array[Byte]()))
438+
439+
checkEvaluation(Length(a), null, create_row(null))
440+
checkEvaluation(Length(b), null, create_row(null))
441+
442+
checkEvaluation(Length(Literal.create(null, StringType)), null, create_row(string))
443+
checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes))
444+
445+
checkEvaluation(Length(Literal.create(null, NullType)), null, create_row(null))
446+
}
447+
448+
test("number format") {
449+
checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal(3)), "4.000")
450+
checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal(3)), "4.000")
451+
checkEvaluation(FormatNumber(Literal(4.0f), Literal(3)), "4.000")
452+
checkEvaluation(FormatNumber(Literal(4), Literal(3)), "4.000")
453+
checkEvaluation(FormatNumber(Literal(12831273.23481d), Literal(3)), "12,831,273.235")
454+
checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal(0)), "12,831,274")
455+
checkEvaluation(FormatNumber(Literal(123123324123L), Literal(3)), "123,123,324,123.000")
456+
checkEvaluation(
457+
FormatNumber(
458+
Literal(Decimal(123123324123L) * Decimal(123123.21234d)), Literal(4)),
459+
"15,159,339,180,002,773.2778")
460+
checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null)
461+
checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null)
462+
}
429463
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1685,20 +1685,42 @@ object functions {
16851685
//////////////////////////////////////////////////////////////////////////////////////////////
16861686

16871687
/**
1688-
* Computes the length of a given string value.
1688+
* Computes the length of a given string / binary value.
16891689
*
16901690
* @group string_funcs
16911691
* @since 1.5.0
16921692
*/
1693-
def strlen(e: Column): Column = StringLength(e.expr)
1693+
def length(e: Column): Column = Length(e.expr)
16941694

16951695
/**
1696-
* Computes the length of a given string column.
1696+
* Computes the length of a given string / binary column.
16971697
*
16981698
* @group string_funcs
16991699
* @since 1.5.0
17001700
*/
1701-
def strlen(columnName: String): Column = strlen(Column(columnName))
1701+
def length(columnName: String): Column = length(Column(columnName))
1702+
1703+
/**
1704+
* Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
1705+
* and returns the result as a string. If D is 0, the result has no decimal point or
1706+
* fractional part.
1707+
*
1708+
* @group string_funcs
1709+
* @since 1.5.0
1710+
*/
1711+
def formatNumber(x: Column, d: Column): Column = FormatNumber(x.expr, d.expr)
1712+
1713+
/**
1714+
* Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
1715+
* and returns the result as a string. If D is 0, the result has no decimal point or
1716+
* fractional part.
1717+
*
1718+
* @group string_funcs
1719+
* @since 1.5.0
1720+
*/
1721+
def formatNumber(columnXName: String, columnDName: String): Column = {
1722+
formatNumber(Column(columnXName), Column(columnDName))
1723+
}
17021724

17031725
/**
17041726
* Computes the Levenshtein distance of the two given strings.

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 81 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -208,17 +208,6 @@ class DataFrameFunctionsSuite extends QueryTest {
208208
Row(2743272264L, 2180413220L))
209209
}
210210

211-
test("string length function") {
212-
val df = Seq(("abc", "")).toDF("a", "b")
213-
checkAnswer(
214-
df.select(strlen($"a"), strlen("b")),
215-
Row(3, 0))
216-
217-
checkAnswer(
218-
df.selectExpr("length(a)", "length(b)"),
219-
Row(3, 0))
220-
}
221-
222211
test("Levenshtein distance") {
223212
val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
224213
checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1)))
@@ -433,11 +422,91 @@ class DataFrameFunctionsSuite extends QueryTest {
433422
val doubleData = Seq((7.2, 4.1)).toDF("a", "b")
434423
checkAnswer(
435424
doubleData.select(pmod('a, 'b)),
436-
Seq(Row(3.1000000000000005)) // same as hive
425+
Seq(Row(3.1000000000000005)) // same as hive
437426
)
438427
checkAnswer(
439428
doubleData.select(pmod(lit(2), lit(Int.MaxValue))),
440429
Seq(Row(2))
441430
)
442431
}
432+
433+
test("string / binary length function") {
434+
val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c")
435+
checkAnswer(
436+
df.select(length($"a"), length("a"), length($"b"), length("b")),
437+
Row(3, 3, 4, 4))
438+
439+
checkAnswer(
440+
df.selectExpr("length(a)", "length(b)"),
441+
Row(3, 4))
442+
443+
intercept[AnalysisException] {
444+
checkAnswer(
445+
df.selectExpr("length(c)"), // int type of the argument is unacceptable
446+
Row("5.0000"))
447+
}
448+
}
449+
450+
test("number format function") {
451+
val tuple =
452+
("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short],
453+
3.13223f, 4, 5L, 6.48173d, Decimal(7.128381))
454+
val df =
455+
Seq(tuple)
456+
.toDF(
457+
"a", // string "aa"
458+
"b", // byte 1
459+
"c", // short 2
460+
"d", // float 3.13223f
461+
"e", // integer 4
462+
"f", // long 5L
463+
"g", // double 6.48173d
464+
"h") // decimal 7.128381
465+
466+
checkAnswer(
467+
df.select(
468+
formatNumber($"f", $"e"),
469+
formatNumber("f", "e")),
470+
Row("5.0000", "5.0000"))
471+
472+
checkAnswer(
473+
df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer
474+
Row("1.0000"))
475+
476+
checkAnswer(
477+
df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer
478+
Row("2.0000"))
479+
480+
checkAnswer(
481+
df.selectExpr("format_number(d, e)"), // convert the 1st argument to double
482+
Row("3.1322"))
483+
484+
checkAnswer(
485+
df.selectExpr("format_number(e, e)"), // not convert anything
486+
Row("4.0000"))
487+
488+
checkAnswer(
489+
df.selectExpr("format_number(f, e)"), // not convert anything
490+
Row("5.0000"))
491+
492+
checkAnswer(
493+
df.selectExpr("format_number(g, e)"), // not convert anything
494+
Row("6.4817"))
495+
496+
checkAnswer(
497+
df.selectExpr("format_number(h, e)"), // not convert anything
498+
Row("7.1284"))
499+
500+
intercept[AnalysisException] {
501+
checkAnswer(
502+
df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable
503+
Row("5.0000"))
504+
}
505+
506+
intercept[AnalysisException] {
507+
checkAnswer(
508+
df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable
509+
Row("5.0000"))
510+
}
511+
}
443512
}

0 commit comments

Comments
 (0)