Skip to content

Commit 42dea3a

Browse files
chenghao-intelrxin
authored andcommitted
[SPARK-8245][SQL] FormatNumber/Length Support for Expression
- `BinaryType` for `Length` - `FormatNumber` Author: Cheng Hao <[email protected]> Closes apache#7034 from chenghao-intel/expression and squashes the following commits: e534b87 [Cheng Hao] python api style issue 601bbf5 [Cheng Hao] add python API support 3ebe288 [Cheng Hao] update as feedback 52274f7 [Cheng Hao] add support for udf_format_number and length for binary
1 parent 9c64a75 commit 42dea3a

File tree

6 files changed

+261
-41
lines changed

6 files changed

+261
-41
lines changed

python/pyspark/sql/functions.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
'coalesce',
4040
'countDistinct',
4141
'explode',
42+
'format_number',
43+
'length',
4244
'log2',
4345
'md5',
4446
'monotonicallyIncreasingId',
@@ -47,7 +49,6 @@
4749
'sha1',
4850
'sha2',
4951
'sparkPartitionId',
50-
'strlen',
5152
'struct',
5253
'udf',
5354
'when']
@@ -506,14 +507,28 @@ def sparkPartitionId():
506507

507508
@ignore_unicode_prefix
508509
@since(1.5)
509-
def strlen(col):
510-
"""Calculates the length of a string expression.
510+
def length(col):
511+
"""Calculates the length of a string or binary expression.
511512
512-
>>> sqlContext.createDataFrame([('ABC',)], ['a']).select(strlen('a').alias('length')).collect()
513+
>>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect()
513514
[Row(length=3)]
514515
"""
515516
sc = SparkContext._active_spark_context
516-
return Column(sc._jvm.functions.strlen(_to_java_column(col)))
517+
return Column(sc._jvm.functions.length(_to_java_column(col)))
518+
519+
520+
@ignore_unicode_prefix
521+
@since(1.5)
522+
def format_number(col, d):
523+
"""Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
524+
and returns the result as a string.
525+
:param col: the column name of the numeric value to be formatted
526+
:param d: the N decimal places
527+
>>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
528+
[Row(v=u'5.0000')]
529+
"""
530+
sc = SparkContext._active_spark_context
531+
return Column(sc._jvm.functions.format_number(_to_java_column(col), d))
517532

518533

519534
@ignore_unicode_prefix

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,12 @@ object FunctionRegistry {
152152
expression[Base64]("base64"),
153153
expression[Encode]("encode"),
154154
expression[Decode]("decode"),
155-
expression[StringInstr]("instr"),
155+
expression[FormatNumber]("format_number"),
156156
expression[Lower]("lcase"),
157157
expression[Lower]("lower"),
158-
expression[StringLength]("length"),
158+
expression[Length]("length"),
159159
expression[Levenshtein]("levenshtein"),
160+
expression[StringInstr]("instr"),
160161
expression[StringLocate]("locate"),
161162
expression[StringLPad]("lpad"),
162163
expression[StringTrimLeft]("ltrim"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import java.text.DecimalFormat
2021
import java.util.Locale
2122
import java.util.regex.Pattern
2223

23-
import org.apache.commons.lang3.StringUtils
24-
2524
import org.apache.spark.sql.catalyst.InternalRow
2625
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
2726
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -553,17 +552,22 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
553552
}
554553

555554
/**
556-
* A function that return the length of the given string expression.
555+
* A function that return the length of the given string or binary expression.
557556
*/
558-
case class StringLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
557+
case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes {
559558
override def dataType: DataType = IntegerType
560-
override def inputTypes: Seq[DataType] = Seq(StringType)
559+
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))
561560

562-
protected override def nullSafeEval(string: Any): Any =
563-
string.asInstanceOf[UTF8String].numChars
561+
protected override def nullSafeEval(value: Any): Any = child.dataType match {
562+
case StringType => value.asInstanceOf[UTF8String].numChars
563+
case BinaryType => value.asInstanceOf[Array[Byte]].length
564+
}
564565

565566
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
566-
defineCodeGen(ctx, ev, c => s"($c).numChars()")
567+
child.dataType match {
568+
case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()")
569+
case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length")
570+
}
567571
}
568572

569573
override def prettyName: String = "length"
@@ -668,3 +672,77 @@ case class Encode(value: Expression, charset: Expression)
668672
}
669673
}
670674

675+
/**
676+
* Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
677+
* and returns the result as a string. If D is 0, the result has no decimal point or
678+
* fractional part.
679+
*/
680+
case class FormatNumber(x: Expression, d: Expression)
681+
extends BinaryExpression with ExpectsInputTypes {
682+
683+
override def left: Expression = x
684+
override def right: Expression = d
685+
override def dataType: DataType = StringType
686+
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
687+
688+
// Associated with the pattern, for the last d value, and we will update the
689+
// pattern (DecimalFormat) once the new coming d value differ with the last one.
690+
@transient
691+
private var lastDValue: Int = -100
692+
693+
// A cached DecimalFormat, for performance concern, we will change it
694+
// only if the d value changed.
695+
@transient
696+
private val pattern: StringBuffer = new StringBuffer()
697+
698+
@transient
699+
private val numberFormat: DecimalFormat = new DecimalFormat("")
700+
701+
override def eval(input: InternalRow): Any = {
702+
val xObject = x.eval(input)
703+
if (xObject == null) {
704+
return null
705+
}
706+
707+
val dObject = d.eval(input)
708+
709+
if (dObject == null || dObject.asInstanceOf[Int] < 0) {
710+
return null
711+
}
712+
val dValue = dObject.asInstanceOf[Int]
713+
714+
if (dValue != lastDValue) {
715+
// construct a new DecimalFormat only if a new dValue
716+
pattern.delete(0, pattern.length())
717+
pattern.append("#,###,###,###,###,###,##0")
718+
719+
// decimal place
720+
if (dValue > 0) {
721+
pattern.append(".")
722+
723+
var i = 0
724+
while (i < dValue) {
725+
i += 1
726+
pattern.append("0")
727+
}
728+
}
729+
val dFormat = new DecimalFormat(pattern.toString())
730+
lastDValue = dValue;
731+
numberFormat.applyPattern(dFormat.toPattern())
732+
}
733+
734+
x.dataType match {
735+
case ByteType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Byte]))
736+
case ShortType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Short]))
737+
case FloatType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Float]))
738+
case IntegerType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Int]))
739+
case LongType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Long]))
740+
case DoubleType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Double]))
741+
case _: DecimalType =>
742+
UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Decimal].toJavaBigDecimal))
743+
}
744+
}
745+
746+
override def prettyName: String = "format_number"
747+
}
748+

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
22-
import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType}
22+
import org.apache.spark.sql.types._
2323

2424

2525
class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -216,15 +216,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
216216
}
217217
}
218218

219-
test("length for string") {
220-
val a = 'a.string.at(0)
221-
checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef"))
222-
checkEvaluation(StringLength(a), 5, create_row("abdef"))
223-
checkEvaluation(StringLength(a), 0, create_row(""))
224-
checkEvaluation(StringLength(a), null, create_row(null))
225-
checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
226-
}
227-
228219
test("ascii for string") {
229220
val a = 'a.string.at(0)
230221
checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef"))
@@ -426,4 +417,46 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
426417
checkEvaluation(
427418
StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1)
428419
}
420+
421+
test("length for string / binary") {
422+
val a = 'a.string.at(0)
423+
val b = 'b.binary.at(0)
424+
val bytes = Array[Byte](1, 2, 3, 1, 2)
425+
val string = "abdef"
426+
427+
// scalastyle:off
428+
// non ascii characters are not allowed in the source code, so we disable the scalastyle.
429+
checkEvaluation(Length(Literal("a花花c")), 4, create_row(string))
430+
// scalastyle:on
431+
checkEvaluation(Length(Literal(bytes)), 5, create_row(Array[Byte]()))
432+
433+
checkEvaluation(Length(a), 5, create_row(string))
434+
checkEvaluation(Length(b), 5, create_row(bytes))
435+
436+
checkEvaluation(Length(a), 0, create_row(""))
437+
checkEvaluation(Length(b), 0, create_row(Array[Byte]()))
438+
439+
checkEvaluation(Length(a), null, create_row(null))
440+
checkEvaluation(Length(b), null, create_row(null))
441+
442+
checkEvaluation(Length(Literal.create(null, StringType)), null, create_row(string))
443+
checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes))
444+
}
445+
446+
test("number format") {
447+
checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal(3)), "4.000")
448+
checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal(3)), "4.000")
449+
checkEvaluation(FormatNumber(Literal(4.0f), Literal(3)), "4.000")
450+
checkEvaluation(FormatNumber(Literal(4), Literal(3)), "4.000")
451+
checkEvaluation(FormatNumber(Literal(12831273.23481d), Literal(3)), "12,831,273.235")
452+
checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal(0)), "12,831,274")
453+
checkEvaluation(FormatNumber(Literal(123123324123L), Literal(3)), "123,123,324,123.000")
454+
checkEvaluation(FormatNumber(Literal(123123324123L), Literal(-1)), null)
455+
checkEvaluation(
456+
FormatNumber(
457+
Literal(Decimal(123123324123L) * Decimal(123123.21234d)), Literal(4)),
458+
"15,159,339,180,002,773.2778")
459+
checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null)
460+
checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null)
461+
}
429462
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1685,20 +1685,44 @@ object functions {
16851685
//////////////////////////////////////////////////////////////////////////////////////////////
16861686

16871687
/**
1688-
* Computes the length of a given string value.
1688+
* Computes the length of a given string / binary value.
16891689
*
16901690
* @group string_funcs
16911691
* @since 1.5.0
16921692
*/
1693-
def strlen(e: Column): Column = StringLength(e.expr)
1693+
def length(e: Column): Column = Length(e.expr)
16941694

16951695
/**
1696-
* Computes the length of a given string column.
1696+
* Computes the length of a given string / binary column.
16971697
*
16981698
* @group string_funcs
16991699
* @since 1.5.0
17001700
*/
1701-
def strlen(columnName: String): Column = strlen(Column(columnName))
1701+
def length(columnName: String): Column = length(Column(columnName))
1702+
1703+
/**
1704+
* Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
1705+
* and returns the result as a string.
1706+
* If d is 0, the result has no decimal point or fractional part.
1707+
* If d < 0, the result will be null.
1708+
*
1709+
* @group string_funcs
1710+
* @since 1.5.0
1711+
*/
1712+
def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr)
1713+
1714+
/**
1715+
* Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
1716+
* and returns the result as a string.
1717+
* If d is 0, the result has no decimal point or fractional part.
1718+
* If d < 0, the result will be null.
1719+
*
1720+
* @group string_funcs
1721+
* @since 1.5.0
1722+
*/
1723+
def format_number(columnXName: String, d: Int): Column = {
1724+
format_number(Column(columnXName), d)
1725+
}
17021726

17031727
/**
17041728
* Computes the Levenshtein distance of the two given strings.

0 commit comments

Comments
 (0)