Skip to content

Commit 5217f7b

Browse files
MaxGekk authored and cloud-fan committed
[SPARK-26248][SQL] Infer date type from CSV
## What changes were proposed in this pull request?
The `CSVInferSchema` class is extended to support inferring `DateType` from CSV input. An attempt to infer `DateType` is made only after the attempt to infer `TimestampType` has failed.
## How was this patch tested?
Added a new test for inferring date types from CSV. It was also tested by existing suites such as `CSVInferSchemaSuite`, `CsvExpressionsSuite`, `CsvFunctionsSuite` and `CsvSuite`.
Closes #23202 from MaxGekk/csv-date-inferring.
Lead-authored-by: Maxim Gekk <[email protected]>
Co-authored-by: Maxim Gekk <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent e3e33d8 commit 5217f7b

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,16 +22,20 @@ import scala.util.control.Exception.allCatch
2222
import org.apache.spark.rdd.RDD
2323
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
2424
import org.apache.spark.sql.catalyst.expressions.ExprUtils
25-
import org.apache.spark.sql.catalyst.util.TimestampFormatter
25+
import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter}
2626
import org.apache.spark.sql.types._
2727

2828
class CSVInferSchema(val options: CSVOptions) extends Serializable {
2929

3030
@transient
31-
private lazy val timestampParser = TimestampFormatter(
31+
private lazy val timestampFormatter = TimestampFormatter(
3232
options.timestampFormat,
3333
options.timeZone,
3434
options.locale)
35+
@transient
36+
private lazy val dateFormatter = DateFormatter(
37+
options.dateFormat,
38+
options.locale)
3539

3640
private val decimalParser = {
3741
ExprUtils.getDecimalParser(options.locale)
@@ -104,6 +108,7 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable {
104108
compatibleType(typeSoFar, tryParseDecimal(field)).getOrElse(StringType)
105109
case DoubleType => tryParseDouble(field)
106110
case TimestampType => tryParseTimestamp(field)
111+
case DateType => tryParseDate(field)
107112
case BooleanType => tryParseBoolean(field)
108113
case StringType => StringType
109114
case other: DataType =>
@@ -159,9 +164,16 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable {
159164
}
160165

161166
private def tryParseTimestamp(field: String): DataType = {
162-
// This case infers a custom `dataFormat` is set.
163-
if ((allCatch opt timestampParser.parse(field)).isDefined) {
167+
if ((allCatch opt timestampFormatter.parse(field)).isDefined) {
164168
TimestampType
169+
} else {
170+
tryParseDate(field)
171+
}
172+
}
173+
174+
private def tryParseDate(field: String): DataType = {
175+
if ((allCatch opt dateFormatter.parse(field)).isDefined) {
176+
DateType
165177
} else {
166178
tryParseBoolean(field)
167179
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -187,4 +187,22 @@ class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper {
187187

188188
Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalInfer(_, DecimalType(7, 0)))
189189
}
190+
191+
test("inferring date type") {
192+
var options = new CSVOptions(Map("dateFormat" -> "yyyy/MM/dd"), false, "GMT")
193+
var inferSchema = new CSVInferSchema(options)
194+
assert(inferSchema.inferField(NullType, "2018/12/02") == DateType)
195+
196+
options = new CSVOptions(Map("dateFormat" -> "MMM yyyy"), false, "GMT")
197+
inferSchema = new CSVInferSchema(options)
198+
assert(inferSchema.inferField(NullType, "Dec 2018") == DateType)
199+
200+
options = new CSVOptions(
201+
Map("dateFormat" -> "yyyy-MM-dd", "timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"),
202+
columnPruning = false,
203+
defaultTimeZoneId = "GMT")
204+
inferSchema = new CSVInferSchema(options)
205+
assert(inferSchema.inferField(NullType, "2018-12-03T11:00:00") == TimestampType)
206+
assert(inferSchema.inferField(NullType, "2018-12-03") == DateType)
207+
}
190208
}

0 commit comments

Comments
 (0)