Skip to content

Commit 2dfbb5b

Browse files
committed
support date type
1 parent 6f98902 commit 2dfbb5b

File tree

125 files changed

+813
-39
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

125 files changed

+813
-39
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst
1919

20-
import java.sql.Timestamp
20+
import java.sql.{Date, Timestamp}
2121

2222
import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
2323
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
@@ -77,8 +77,9 @@ object ScalaReflection {
7777
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
7878
Schema(MapType(schemaFor(keyType).dataType,
7979
valueDataType, valueContainsNull = valueNullable), nullable = true)
80-
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
80+
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
8181
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
82+
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
8283
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
8384
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
8485
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,20 +220,36 @@ trait HiveTypeCoercion {
220220
case a: BinaryArithmetic if a.right.dataType == StringType =>
221221
a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))
222222

223+
case p: BinaryPredicate if p.left.dataType == StringType
224+
&& p.right.dataType == DateType =>
225+
p.makeCopy(Array(Cast(p.left, DateType), p.right))
226+
case p: BinaryPredicate if p.left.dataType == DateType
227+
&& p.right.dataType == StringType =>
228+
p.makeCopy(Array(p.left, Cast(p.right, DateType)))
223229
case p: BinaryPredicate if p.left.dataType == StringType
224230
&& p.right.dataType == TimestampType =>
225231
p.makeCopy(Array(Cast(p.left, TimestampType), p.right))
226232
case p: BinaryPredicate if p.left.dataType == TimestampType
227233
&& p.right.dataType == StringType =>
228234
p.makeCopy(Array(p.left, Cast(p.right, TimestampType)))
235+
case p: BinaryPredicate if p.left.dataType == TimestampType
236+
&& p.right.dataType == DateType =>
237+
p.makeCopy(Array(Cast(p.left, DateType), p.right))
238+
case p: BinaryPredicate if p.left.dataType == DateType
239+
&& p.right.dataType == TimestampType =>
240+
p.makeCopy(Array(p.left, Cast(p.right, DateType)))
229241

230242
case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType =>
231243
p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
232244
case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType =>
233245
p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))
234246

247+
case i @ In(a,b) if a.dataType == DateType && b.forall(_.dataType == StringType) =>
248+
i.makeCopy(Array(a,b.map(Cast(_,DateType))))
235249
case i @ In(a,b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) =>
236250
i.makeCopy(Array(a,b.map(Cast(_,TimestampType))))
251+
case i @ In(a,b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) =>
252+
i.makeCopy(Array(a,b.map(Cast(_,DateType))))
237253

238254
case Sum(e) if e.dataType == StringType =>
239255
Sum(Cast(e, DoubleType))
@@ -283,6 +299,8 @@ trait HiveTypeCoercion {
283299
// Skip if the type is boolean type already. Note that this extra cast should be removed
284300
// by optimizer.SimplifyCasts.
285301
case Cast(e, BooleanType) if e.dataType == BooleanType => e
302+
// Casting DateType to boolean always evaluates to null, so keep the Cast as-is instead of
// rewriting it into the numeric comparison below.
303+
case Cast(e, BooleanType) if e.dataType == DateType => Cast(e, BooleanType)
286304
// If the data type is not boolean and is being cast boolean, turn it into a comparison
287305
// with the numeric value, i.e. x != 0. This will coerce the type into numeric type.
288306
case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0)))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst
1919

20-
import java.sql.Timestamp
20+
import java.sql.{Date, Timestamp}
2121

2222
import scala.language.implicitConversions
2323

@@ -119,6 +119,7 @@ package object dsl {
119119
implicit def floatToLiteral(f: Float) = Literal(f)
120120
implicit def doubleToLiteral(d: Double) = Literal(d)
121121
implicit def stringToLiteral(s: String) = Literal(s)
122+
implicit def dateToLiteral(d: Date) = Literal(d)
122123
implicit def decimalToLiteral(d: BigDecimal) = Literal(d)
123124
implicit def timestampToLiteral(t: Timestamp) = Literal(t)
124125
implicit def binaryToLiteral(a: Array[Byte]) = Literal(a)
@@ -174,6 +175,9 @@ package object dsl {
174175
/** Creates a new AttributeReference of type string */
175176
def string = AttributeReference(s, StringType, nullable = true)()
176177

178+
/** Creates a new AttributeReference of type date */
179+
def date = AttributeReference(s, DateType, nullable = true)()
180+
177181
/** Creates a new AttributeReference of type decimal */
178182
def decimal = AttributeReference(s, DecimalType, nullable = true)()
179183

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 82 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,20 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import java.sql.Timestamp
20+
import java.sql.{Date, Timestamp}
2121
import java.text.{DateFormat, SimpleDateFormat}
2222

23+
import org.apache.spark.Logging
2324
import org.apache.spark.sql.catalyst.types._
2425

2526
/** Cast the child expression to the target data type. */
26-
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
27+
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
2728
override def foldable = child.foldable
2829

2930
override def nullable = (child.dataType, dataType) match {
3031
case (StringType, _: NumericType) => true
3132
case (StringType, TimestampType) => true
33+
case (StringType, DateType) => true
3234
case _ => child.nullable
3335
}
3436

@@ -42,6 +44,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
4244
// UDFToString
4345
private[this] def castToString: Any => Any = child.dataType match {
4446
case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
47+
case DateType => buildCast[Date](_, dateToString)
4548
case TimestampType => buildCast[Timestamp](_, timestampToString)
4649
case _ => buildCast[Any](_, _.toString)
4750
}
@@ -56,7 +59,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
5659
case StringType =>
5760
buildCast[String](_, _.length() != 0)
5861
case TimestampType =>
59-
buildCast[Timestamp](_, b => b.getTime() != 0 || b.getNanos() != 0)
62+
buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0)
63+
case DateType =>
64+
buildCast[Date](_, d => null)
6065
case LongType =>
6166
buildCast[Long](_, _ != 0)
6267
case IntegerType =>
@@ -95,6 +100,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
95100
buildCast[Short](_, s => new Timestamp(s))
96101
case ByteType =>
97102
buildCast[Byte](_, b => new Timestamp(b))
103+
case DateType =>
104+
buildCast[Date](_, d => Timestamp.valueOf(dateToString(d) + " 00:00:00"))
98105
// TimestampWritable.decimalToTimestamp
99106
case DecimalType =>
100107
buildCast[BigDecimal](_, d => decimalToTimestamp(d))
@@ -130,7 +137,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
130137
// Converts Timestamp to string according to Hive TimestampWritable convention
131138
private[this] def timestampToString(ts: Timestamp): String = {
132139
val timestampString = ts.toString
133-
val formatted = Cast.threadLocalDateFormat.get.format(ts)
140+
val formatted = Cast.threadLocalTimestampFormat.get.format(ts)
134141

135142
if (timestampString.length > 19 && timestampString.substring(19) != ".0") {
136143
formatted + timestampString.substring(19)
@@ -139,13 +146,48 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
139146
}
140147
}
141148

149+
// Formats a Timestamp as a date-only string (yyyy-MM-dd) using Cast.threadLocalDateFormat
150+
private[this] def timestampToDateString(ts: Timestamp): String = {
151+
Cast.threadLocalDateFormat.get.format(ts)
152+
}
153+
154+
// DateConverter
155+
/**
 * Builds the conversion function used when the target type is DateType.
 *
 *  - StringType: a string containing a space is treated as a timestamp literal; it is parsed
 *    with castToTimestamp and then reduced to its date part. Any other string is parsed with
 *    Date.valueOf. Input that cannot be parsed yields null.
 *  - TimestampType: the timestamp is formatted as a date-only string and re-parsed as a Date.
 *  - Any other source type yields null.
 */
private[this] def castToDate: Any => Any = child.dataType match {
  case StringType =>
    buildCast[String](_, s =>
      try {
        if (s.contains(" ")) {
          // Do not re-enter castToDate here: child.dataType is still StringType, so a recursive
          // call would pass the Timestamp to this String branch and fail with a
          // ClassCastException, which the handler below does not catch.
          // NOTE(review): assumes castToTimestamp yields a Timestamp or null for string
          // input -- confirm against its StringType branch.
          castToTimestamp(s) match {
            case ts: Timestamp => Date.valueOf(timestampToDateString(ts))
            case _ => null
          }
        } else {
          Date.valueOf(s)
        }
      } catch {
        // Date.valueOf rejects strings that are not in yyyy-[m]m-[d]d form.
        case _: java.lang.IllegalArgumentException => null
      })
  case TimestampType =>
    buildCast[Timestamp](_, t => Date.valueOf(timestampToDateString(t)))
  case _ =>
    _ => null
}
169+
170+
// Hive does not allow casting Date to long, so this always returns null
171+
private[this] def dateToLong(d: Date) = null
172+
173+
// Hive does not allow casting Date to double, so this always returns null
174+
private[this] def dateToDouble(d: Date) = null
175+
176+
// Converts Date to a yyyy-MM-dd string using Cast.threadLocalDateFormat
177+
private[this] def dateToString(d: Date): String = {
178+
Cast.threadLocalDateFormat.get.format(d)
179+
}
180+
181+
// LongConverter
142182
private[this] def castToLong: Any => Any = child.dataType match {
143183
case StringType =>
144184
buildCast[String](_, s => try s.toLong catch {
145185
case _: NumberFormatException => null
146186
})
147187
case BooleanType =>
148188
buildCast[Boolean](_, b => if (b) 1L else 0L)
189+
case DateType =>
190+
buildCast[Date](_, d => dateToLong(d))
149191
case TimestampType =>
150192
buildCast[Timestamp](_, t => timestampToLong(t))
151193
case DecimalType =>
@@ -154,13 +196,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
154196
b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
155197
}
156198

199+
// IntConverter
157200
private[this] def castToInt: Any => Any = child.dataType match {
158201
case StringType =>
159202
buildCast[String](_, s => try s.toInt catch {
160203
case _: NumberFormatException => null
161204
})
162205
case BooleanType =>
163206
buildCast[Boolean](_, b => if (b) 1 else 0)
207+
case DateType =>
208+
buildCast[Date](_, d => dateToLong(d))
164209
case TimestampType =>
165210
buildCast[Timestamp](_, t => timestampToLong(t).toInt)
166211
case DecimalType =>
@@ -169,13 +214,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
169214
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
170215
}
171216

217+
// ShortConverter
172218
private[this] def castToShort: Any => Any = child.dataType match {
173219
case StringType =>
174220
buildCast[String](_, s => try s.toShort catch {
175221
case _: NumberFormatException => null
176222
})
177223
case BooleanType =>
178224
buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
225+
case DateType =>
226+
buildCast[Date](_, d => dateToLong(d))
179227
case TimestampType =>
180228
buildCast[Timestamp](_, t => timestampToLong(t).toShort)
181229
case DecimalType =>
@@ -184,13 +232,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
184232
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
185233
}
186234

235+
// ByteConverter
187236
private[this] def castToByte: Any => Any = child.dataType match {
188237
case StringType =>
189238
buildCast[String](_, s => try s.toByte catch {
190239
case _: NumberFormatException => null
191240
})
192241
case BooleanType =>
193242
buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
243+
case DateType =>
244+
buildCast[Date](_, d => dateToLong(d))
194245
case TimestampType =>
195246
buildCast[Timestamp](_, t => timestampToLong(t).toByte)
196247
case DecimalType =>
@@ -199,27 +250,33 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
199250
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
200251
}
201252

253+
// DecimalConverter
202254
private[this] def castToDecimal: Any => Any = child.dataType match {
203255
case StringType =>
204256
buildCast[String](_, s => try BigDecimal(s.toDouble) catch {
205257
case _: NumberFormatException => null
206258
})
207259
case BooleanType =>
208260
buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0))
261+
case DateType =>
262+
buildCast[Date](_, d => dateToDouble(d))
209263
case TimestampType =>
210264
// Note that we lose precision here.
211265
buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
212266
case x: NumericType =>
213267
b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
214268
}
215269

270+
// DoubleConverter
216271
private[this] def castToDouble: Any => Any = child.dataType match {
217272
case StringType =>
218273
buildCast[String](_, s => try s.toDouble catch {
219274
case _: NumberFormatException => null
220275
})
221276
case BooleanType =>
222277
buildCast[Boolean](_, b => if (b) 1d else 0d)
278+
case DateType =>
279+
buildCast[Date](_, d => dateToDouble(d))
223280
case TimestampType =>
224281
buildCast[Timestamp](_, t => timestampToDouble(t))
225282
case DecimalType =>
@@ -228,13 +285,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
228285
b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
229286
}
230287

288+
// FloatConverter
231289
private[this] def castToFloat: Any => Any = child.dataType match {
232290
case StringType =>
233291
buildCast[String](_, s => try s.toFloat catch {
234292
case _: NumberFormatException => null
235293
})
236294
case BooleanType =>
237295
buildCast[Boolean](_, b => if (b) 1f else 0f)
296+
case DateType =>
297+
buildCast[Date](_, d => dateToDouble(d))
238298
case TimestampType =>
239299
buildCast[Timestamp](_, t => timestampToDouble(t).toFloat)
240300
case DecimalType =>
@@ -245,17 +305,18 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
245305

246306
private[this] lazy val cast: Any => Any = dataType match {
247307
case dt if dt == child.dataType => identity[Any]
248-
case StringType => castToString
249-
case BinaryType => castToBinary
250-
case DecimalType => castToDecimal
308+
case StringType => castToString
309+
case BinaryType => castToBinary
310+
case DecimalType => castToDecimal
311+
case DateType => castToDate
251312
case TimestampType => castToTimestamp
252-
case BooleanType => castToBoolean
253-
case ByteType => castToByte
254-
case ShortType => castToShort
255-
case IntegerType => castToInt
256-
case FloatType => castToFloat
257-
case LongType => castToLong
258-
case DoubleType => castToDouble
313+
case BooleanType => castToBoolean
314+
case ByteType => castToByte
315+
case ShortType => castToShort
316+
case IntegerType => castToInt
317+
case FloatType => castToFloat
318+
case LongType => castToLong
319+
case DoubleType => castToDouble
259320
}
260321

261322
override def eval(input: Row): Any = {
@@ -267,6 +328,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
267328
object Cast {
268329
// `SimpleDateFormat` is not thread-safe.
269330
private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] {
331+
override def initialValue() = {
332+
new SimpleDateFormat("yyyy-MM-dd")
333+
}
334+
}
335+
336+
// `SimpleDateFormat` is not thread-safe.
337+
private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] {
270338
override def initialValue() = {
271339
new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
272340
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import java.sql.Timestamp
20+
import java.sql.{Date, Timestamp}
2121

2222
import org.apache.spark.sql.catalyst.types._
2323

@@ -33,6 +33,7 @@ object Literal {
3333
case b: Boolean => Literal(b, BooleanType)
3434
case d: BigDecimal => Literal(d, DecimalType)
3535
case t: Timestamp => Literal(t, TimestampType)
36+
case d: Date => Literal(d, DateType)
3637
case a: Array[Byte] => Literal(a, BinaryType)
3738
case null => Literal(null, NullType)
3839
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.types
1919

20-
import java.sql.Timestamp
20+
import java.sql.{Date, Timestamp}
2121

2222
import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral}
2323
import scala.reflect.ClassTag
@@ -250,6 +250,18 @@ case object TimestampType extends NativeType {
250250
}
251251
}
252252

253+
case object DateType extends NativeType {
254+
private[sql] type JvmType = Date
255+
256+
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
257+
258+
private[sql] val ordering = new Ordering[JvmType] {
259+
def compare(x: Date, y: Date) = x.compareTo(y)
260+
}
261+
262+
def simpleString: String = "date"
263+
}
264+
253265
abstract class NumericType extends NativeType with PrimitiveType {
254266
// Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
255267
// implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a

0 commit comments

Comments
 (0)