17
17
18
18
package org .apache .spark .sql .catalyst .expressions
19
19
20
- import java .sql .Timestamp
20
+ import java .sql .{ Date , Timestamp }
21
21
import java .text .{DateFormat , SimpleDateFormat }
22
22
23
+ import org .apache .spark .Logging
23
24
import org .apache .spark .sql .catalyst .types ._
24
25
25
26
/** Cast the child expression to the target data type. */
26
- case class Cast (child : Expression , dataType : DataType ) extends UnaryExpression {
27
+ case class Cast (child : Expression , dataType : DataType ) extends UnaryExpression with Logging {
27
28
override def foldable = child.foldable
28
29
29
30
override def nullable = (child.dataType, dataType) match {
30
31
case (StringType , _ : NumericType ) => true
31
32
case (StringType , TimestampType ) => true
33
+ case (StringType , DateType ) => true
32
34
case _ => child.nullable
33
35
}
34
36
@@ -42,6 +44,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
42
44
// UDFToString
43
45
private [this ] def castToString : Any => Any = child.dataType match {
44
46
case BinaryType => buildCast[Array [Byte ]](_, new String (_, " UTF-8" ))
47
+ case DateType => buildCast[Date ](_, dateToString)
45
48
case TimestampType => buildCast[Timestamp ](_, timestampToString)
46
49
case _ => buildCast[Any ](_, _.toString)
47
50
}
@@ -56,7 +59,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
56
59
case StringType =>
57
60
buildCast[String ](_, _.length() != 0 )
58
61
case TimestampType =>
59
- buildCast[Timestamp ](_, b => b.getTime() != 0 || b.getNanos() != 0 )
62
+ buildCast[Timestamp ](_, t => t.getTime() != 0 || t.getNanos() != 0 )
63
+ case DateType =>
64
+ buildCast[Date ](_, d => null )
60
65
case LongType =>
61
66
buildCast[Long ](_, _ != 0 )
62
67
case IntegerType =>
@@ -95,6 +100,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
95
100
buildCast[Short ](_, s => new Timestamp (s))
96
101
case ByteType =>
97
102
buildCast[Byte ](_, b => new Timestamp (b))
103
+ case DateType =>
104
+ buildCast[Date ](_, d => Timestamp .valueOf(dateToString(d) + " 00:00:00" ))
98
105
// TimestampWritable.decimalToTimestamp
99
106
case DecimalType =>
100
107
buildCast[BigDecimal ](_, d => decimalToTimestamp(d))
@@ -130,7 +137,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
130
137
// Converts Timestamp to string according to Hive TimestampWritable convention
131
138
private [this ] def timestampToString (ts : Timestamp ): String = {
132
139
val timestampString = ts.toString
133
- val formatted = Cast .threadLocalDateFormat .get.format(ts)
140
+ val formatted = Cast .threadLocalTimestampFormat .get.format(ts)
134
141
135
142
if (timestampString.length > 19 && timestampString.substring(19 ) != " .0" ) {
136
143
formatted + timestampString.substring(19 )
@@ -139,13 +146,48 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
139
146
}
140
147
}
141
148
149
+ // Converts Timestamp to string according to Hive TimestampWritable convention
150
+ private [this ] def timestampToDateString (ts : Timestamp ): String = {
151
+ Cast .threadLocalDateFormat.get.format(ts)
152
+ }
153
+
154
+ // DateConverter
155
+ private [this ] def castToDate : Any => Any = child.dataType match {
156
+ case StringType =>
157
+ buildCast[String ](_, s => if (s.contains(" " )) {
158
+ try castToDate(castToTimestamp(s))
159
+ catch { case _ : java.lang.IllegalArgumentException => null }
160
+ } else {
161
+ try Date .valueOf(s) catch { case _ : java.lang.IllegalArgumentException => null }
162
+ })
163
+ case TimestampType =>
164
+ buildCast[Timestamp ](_, t => Date .valueOf(timestampToDateString(t)))
165
+ // TimestampWritable.decimalToDate
166
+ case _ =>
167
+ _ => null
168
+ }
169
+
170
+ // Date cannot be cast to long, according to hive
171
+ private [this ] def dateToLong (d : Date ) = null
172
+
173
+ // Date cannot be cast to double, according to hive
174
+ private [this ] def dateToDouble (d : Date ) = null
175
+
176
+ // Converts Timestamp to string according to Hive TimestampWritable convention
177
+ private [this ] def dateToString (d : Date ): String = {
178
+ Cast .threadLocalDateFormat.get.format(d)
179
+ }
180
+
181
+ // LongConverter
142
182
private [this ] def castToLong : Any => Any = child.dataType match {
143
183
case StringType =>
144
184
buildCast[String ](_, s => try s.toLong catch {
145
185
case _ : NumberFormatException => null
146
186
})
147
187
case BooleanType =>
148
188
buildCast[Boolean ](_, b => if (b) 1L else 0L )
189
+ case DateType =>
190
+ buildCast[Date ](_, d => dateToLong(d))
149
191
case TimestampType =>
150
192
buildCast[Timestamp ](_, t => timestampToLong(t))
151
193
case DecimalType =>
@@ -154,13 +196,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
154
196
b => x.numeric.asInstanceOf [Numeric [Any ]].toLong(b)
155
197
}
156
198
199
+ // IntConverter
157
200
private [this ] def castToInt : Any => Any = child.dataType match {
158
201
case StringType =>
159
202
buildCast[String ](_, s => try s.toInt catch {
160
203
case _ : NumberFormatException => null
161
204
})
162
205
case BooleanType =>
163
206
buildCast[Boolean ](_, b => if (b) 1 else 0 )
207
+ case DateType =>
208
+ buildCast[Date ](_, d => dateToLong(d))
164
209
case TimestampType =>
165
210
buildCast[Timestamp ](_, t => timestampToLong(t).toInt)
166
211
case DecimalType =>
@@ -169,13 +214,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
169
214
b => x.numeric.asInstanceOf [Numeric [Any ]].toInt(b)
170
215
}
171
216
217
+ // ShortConverter
172
218
private [this ] def castToShort : Any => Any = child.dataType match {
173
219
case StringType =>
174
220
buildCast[String ](_, s => try s.toShort catch {
175
221
case _ : NumberFormatException => null
176
222
})
177
223
case BooleanType =>
178
224
buildCast[Boolean ](_, b => if (b) 1 .toShort else 0 .toShort)
225
+ case DateType =>
226
+ buildCast[Date ](_, d => dateToLong(d))
179
227
case TimestampType =>
180
228
buildCast[Timestamp ](_, t => timestampToLong(t).toShort)
181
229
case DecimalType =>
@@ -184,13 +232,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
184
232
b => x.numeric.asInstanceOf [Numeric [Any ]].toInt(b).toShort
185
233
}
186
234
235
+ // ByteConverter
187
236
private [this ] def castToByte : Any => Any = child.dataType match {
188
237
case StringType =>
189
238
buildCast[String ](_, s => try s.toByte catch {
190
239
case _ : NumberFormatException => null
191
240
})
192
241
case BooleanType =>
193
242
buildCast[Boolean ](_, b => if (b) 1 .toByte else 0 .toByte)
243
+ case DateType =>
244
+ buildCast[Date ](_, d => dateToLong(d))
194
245
case TimestampType =>
195
246
buildCast[Timestamp ](_, t => timestampToLong(t).toByte)
196
247
case DecimalType =>
@@ -199,27 +250,33 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
199
250
b => x.numeric.asInstanceOf [Numeric [Any ]].toInt(b).toByte
200
251
}
201
252
253
+ // DecimalConverter
202
254
private [this ] def castToDecimal : Any => Any = child.dataType match {
203
255
case StringType =>
204
256
buildCast[String ](_, s => try BigDecimal (s.toDouble) catch {
205
257
case _ : NumberFormatException => null
206
258
})
207
259
case BooleanType =>
208
260
buildCast[Boolean ](_, b => if (b) BigDecimal (1 ) else BigDecimal (0 ))
261
+ case DateType =>
262
+ buildCast[Date ](_, d => dateToDouble(d))
209
263
case TimestampType =>
210
264
// Note that we lose precision here.
211
265
buildCast[Timestamp ](_, t => BigDecimal (timestampToDouble(t)))
212
266
case x : NumericType =>
213
267
b => BigDecimal (x.numeric.asInstanceOf [Numeric [Any ]].toDouble(b))
214
268
}
215
269
270
+ // DoubleConverter
216
271
private [this ] def castToDouble : Any => Any = child.dataType match {
217
272
case StringType =>
218
273
buildCast[String ](_, s => try s.toDouble catch {
219
274
case _ : NumberFormatException => null
220
275
})
221
276
case BooleanType =>
222
277
buildCast[Boolean ](_, b => if (b) 1d else 0d )
278
+ case DateType =>
279
+ buildCast[Date ](_, d => dateToDouble(d))
223
280
case TimestampType =>
224
281
buildCast[Timestamp ](_, t => timestampToDouble(t))
225
282
case DecimalType =>
@@ -228,13 +285,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
228
285
b => x.numeric.asInstanceOf [Numeric [Any ]].toDouble(b)
229
286
}
230
287
288
+ // FloatConverter
231
289
private [this ] def castToFloat : Any => Any = child.dataType match {
232
290
case StringType =>
233
291
buildCast[String ](_, s => try s.toFloat catch {
234
292
case _ : NumberFormatException => null
235
293
})
236
294
case BooleanType =>
237
295
buildCast[Boolean ](_, b => if (b) 1f else 0f )
296
+ case DateType =>
297
+ buildCast[Date ](_, d => dateToDouble(d))
238
298
case TimestampType =>
239
299
buildCast[Timestamp ](_, t => timestampToDouble(t).toFloat)
240
300
case DecimalType =>
@@ -245,17 +305,18 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
245
305
246
306
private [this ] lazy val cast : Any => Any = dataType match {
247
307
case dt if dt == child.dataType => identity[Any ]
248
- case StringType => castToString
249
- case BinaryType => castToBinary
250
- case DecimalType => castToDecimal
308
+ case StringType => castToString
309
+ case BinaryType => castToBinary
310
+ case DecimalType => castToDecimal
311
+ case DateType => castToDate
251
312
case TimestampType => castToTimestamp
252
- case BooleanType => castToBoolean
253
- case ByteType => castToByte
254
- case ShortType => castToShort
255
- case IntegerType => castToInt
256
- case FloatType => castToFloat
257
- case LongType => castToLong
258
- case DoubleType => castToDouble
313
+ case BooleanType => castToBoolean
314
+ case ByteType => castToByte
315
+ case ShortType => castToShort
316
+ case IntegerType => castToInt
317
+ case FloatType => castToFloat
318
+ case LongType => castToLong
319
+ case DoubleType => castToDouble
259
320
}
260
321
261
322
override def eval (input : Row ): Any = {
@@ -267,6 +328,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
267
328
object Cast {
268
329
// `SimpleDateFormat` is not thread-safe.
269
330
private [sql] val threadLocalDateFormat = new ThreadLocal [DateFormat ] {
331
+ override def initialValue () = {
332
+ new SimpleDateFormat (" yyyy-MM-dd" )
333
+ }
334
+ }
335
+
336
+ // `SimpleDateFormat` is not thread-safe.
337
+ private [sql] val threadLocalTimestampFormat = new ThreadLocal [DateFormat ] {
270
338
override def initialValue () = {
271
339
new SimpleDateFormat (" yyyy-MM-dd HH:mm:ss" )
272
340
}
0 commit comments