[SPARK-6888][SQL] Export driver quirks #5498

Closed
wants to merge 4 commits
This file was deleted.

@@ -30,7 +30,7 @@ import org.apache.spark.sql.sources._
private[sql] object JDBCRDD extends Logging {
/**
* Maps a JDBC type to a Catalyst type. This function is called only when
- * the DriverQuirks class corresponding to your database driver returns null.
+ * the JdbcDialect class corresponding to your database driver returns null.
*
* @param sqlType - A field of java.sql.Types
* @return The Catalyst type corresponding to sqlType.
@@ -40,7 +40,7 @@ private[sql] object JDBCRDD extends Logging {
case java.sql.Types.ARRAY => null
case java.sql.Types.BIGINT => LongType
case java.sql.Types.BINARY => BinaryType
- case java.sql.Types.BIT => BooleanType // Per JDBC; Quirks handles quirky drivers.
+ case java.sql.Types.BIT => BooleanType // Per JDBC; JdbcDialect handles quirky drivers.
case java.sql.Types.BLOB => BinaryType
case java.sql.Types.BOOLEAN => BooleanType
case java.sql.Types.CHAR => StringType
@@ -92,7 +92,7 @@ private[sql] object JDBCRDD extends Logging {
* @throws SQLException if the table contains an unsupported type.
*/
def resolveTable(url: String, table: String, properties: Properties): StructType = {
- val quirks = DriverQuirks.get(url)
+ val dialect = JdbcDialects.get(url)
val conn: Connection = DriverManager.getConnection(url, properties)
try {
val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery()
@@ -108,7 +108,7 @@ private[sql] object JDBCRDD extends Logging {
val fieldSize = rsmd.getPrecision(i + 1)
val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
val metadata = new MetadataBuilder().putString("name", columnName)
- var columnType = quirks.getCatalystType(dataType, typeName, fieldSize, metadata)
+ var columnType = dialect.getCatalystType(dataType, typeName, fieldSize, metadata)
if (columnType == null) columnType = getCatalystType(dataType)
fields(i) = StructField(columnName, columnType, nullable, metadata.build())
i = i + 1
201 changes: 201 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -0,0 +1,201 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.jdbc

import org.apache.spark.sql.types._
import org.apache.spark.annotation.DeveloperApi

import java.sql.Types


/**
* :: DeveloperApi ::
* Encapsulates everything (extensions, workarounds, quirks) to handle the
* SQL dialect of a certain database or jdbc driver.
* Lots of databases define types that aren't explicitly supported
* by the JDBC spec. Some JDBC drivers also report inaccurate
* information---for instance, BIT(n>1) being reported as a BIT type is quite
* common, even though BIT in JDBC is meant for single-bit values. Also, there
* does not appear to be a standard name for an unbounded string or binary
* type; we use BLOB and CLOB by default but override with database-specific
* alternatives when these are absent or do not behave correctly.
*
* Currently, the only thing done by the dialect is type mapping.
* `getCatalystType` is used when reading from a JDBC table and `getJDBCType`
* is used when writing to a JDBC table. If `getCatalystType` returns `null`,
* the default type handling is used for the given JDBC type. Similarly,
* if `getJDBCType` returns `(null, None)`, the default type handling is used
* for the given Catalyst type.
*/
@DeveloperApi
abstract class JdbcDialect {
/**
* Check if this dialect instance can handle a certain jdbc url.
* @param url the jdbc url.
* @return True if the dialect can be applied on the given jdbc url.
* @throws NullPointerException if the url is null.
*/
def canHandle(url : String): Boolean

/**
* Get the custom datatype mapping for the given jdbc meta information.
* @param sqlType The sql type (see java.sql.Types)
* @param typeName The sql type name (e.g. "BIGINT UNSIGNED")
* @param size The size of the type.
* @param md Result metadata associated with this type.
* @return The actual DataType (subclasses of [[org.apache.spark.sql.types.DataType]])
* or null if the default type mapping should be used.
*/
def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = null

/**
* Retrieve the jdbc / sql type for a given datatype.
* @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
* @return A tuple of sql type name and sql type, or {{{(null, None)}}} for no change.
*/
def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
}

/**
* :: DeveloperApi ::
* Registry of dialects that apply to every new jdbc [[org.apache.spark.sql.DataFrame]].
*
* If multiple matching dialects are registered then all matching ones will be
* tried in reverse order. A user-added dialect will thus be applied first,
* overwriting the defaults.
*
* Note that a registered dialect applies only to jdbc DataFrames created after it
* is registered, so register your dialects before creating the DataFrame.
*/
@DeveloperApi
object JdbcDialects {

private var dialects = List[JdbcDialect]()

/**
* Register a dialect for use on all new matching jdbc [[org.apache.spark.sql.DataFrame]].
* Re-adding an existing dialect moves it to the front of the list.
* @param dialect The new dialect.
*/
def registerDialect(dialect: JdbcDialect) : Unit = {
dialects = dialect :: dialects.filterNot(_ == dialect)
}

/**
* Unregister a dialect. Does nothing if the dialect is not registered.
* @param dialect The jdbc dialect.
*/
def unregisterDialect(dialect : JdbcDialect) : Unit = {
dialects = dialects.filterNot(_ == dialect)
}

registerDialect(MySQLDialect)
registerDialect(PostgresDialect)

/**
* Fetch the JdbcDialect class corresponding to a given database url.
*/
private[sql] def get(url: String): JdbcDialect = {
val matchingDialects = dialects.filter(_.canHandle(url))
matchingDialects.length match {
case 0 => NoopDialect
case 1 => matchingDialects.head
case _ => new AggregatedDialect(matchingDialects)
}
}
}

/**
* :: DeveloperApi ::
* AggregatedDialect can unify multiple dialects into one virtual Dialect.
* Dialects are tried in order, and the first dialect that does not return a
* neutral element wins.
* @param dialects List of dialects.
*/
@DeveloperApi
class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect {

require(!dialects.isEmpty)

def canHandle(url : String): Boolean =
dialects.map(_.canHandle(url)).reduce(_ && _)

override def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType =
dialects.map(_.getCatalystType(sqlType, typeName, size, md)).collectFirst {
case dataType if dataType != null => dataType
}.orNull

override def getJDBCType(dt: DataType): (String, Option[Int]) =
dialects.map(_.getJDBCType(dt)).collectFirst {
case t @ (typeName,sqlType) if typeName != null || sqlType.isDefined => t
}.getOrElse((null, None))

}

/**
* :: DeveloperApi ::
* NOOP dialect object, always returning the neutral element.
*/
@DeveloperApi
case object NoopDialect extends JdbcDialect {
def canHandle(url : String): Boolean = true
}

/**
* :: DeveloperApi ::
* Default postgres dialect, mapping bit/cidr/inet on read and string/binary/boolean on write.
*/
@DeveloperApi
case object PostgresDialect extends JdbcDialect {
def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")
override def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
BinaryType
} else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
StringType
} else if (sqlType == Types.OTHER && typeName.equals("inet")) {
StringType
} else null
}

override def getJDBCType(dt: DataType): (String, Option[Int]) = dt match {
case StringType => ("TEXT", Some(java.sql.Types.CHAR))
case BinaryType => ("BYTEA", Some(java.sql.Types.BINARY))
case BooleanType => ("BOOLEAN", Some(java.sql.Types.BOOLEAN))
case _ => (null, None)
}
}

/**
* :: DeveloperApi ::
* Default mysql dialect to read bit/bitsets correctly.
*/
@DeveloperApi
case object MySQLDialect extends JdbcDialect {
def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")
override def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
// This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
// byte arrays instead of longs.
md.putLong("binarylong", 1)
LongType
} else if (sqlType == Types.BIT && typeName.equals("TINYINT")) {
BooleanType
} else null
}
}
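For orientation, here is a minimal sketch of how a user-defined dialect would plug into the API above. It is not part of this diff; ExampleDialect, the jdbc:exampledb URL prefix, and the JSON-to-StringType / StringType-to-VARCHAR(255) mappings are invented for illustration only.

import java.sql.Types

import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.types._

// Hypothetical dialect for a database reached through jdbc:exampledb URLs.
case object ExampleDialect extends JdbcDialect {

  def canHandle(url: String): Boolean = url.startsWith("jdbc:exampledb")

  // Read side: expose the driver's JSON columns as Catalyst strings; returning
  // null for everything else defers to the default type mapping in JDBCRDD.
  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
    if (sqlType == Types.OTHER && typeName.equals("JSON")) StringType else null
  }

  // Write side: store StringType columns as VARCHAR(255); (null, None) defers
  // every other Catalyst type to the defaults used when writing tables.
  override def getJDBCType(dt: DataType): (String, Option[Int]) = dt match {
    case StringType => ("VARCHAR(255)", Some(Types.VARCHAR))
    case _ => (null, None)
  }
}

// Register before creating a jdbc DataFrame; user-added dialects are tried
// before the built-in MySQLDialect and PostgresDialect.
JdbcDialects.registerDialect(ExampleDialect)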
8 changes: 4 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -123,10 +123,10 @@ package object jdbc {
*/
def schemaString(df: DataFrame, url: String): String = {
val sb = new StringBuilder()
- val quirks = DriverQuirks.get(url)
+ val dialect = JdbcDialects.get(url)
df.schema.fields foreach { field => {
val name = field.name
- var typ: String = quirks.getJDBCType(field.dataType)._1
+ var typ: String = dialect.getJDBCType(field.dataType)._1
if (typ == null) typ = field.dataType match {
case IntegerType => "INTEGER"
case LongType => "BIGINT"
@@ -152,9 +152,9 @@
* Saves the RDD to the database in a single transaction.
*/
def saveTable(df: DataFrame, url: String, table: String) {
- val quirks = DriverQuirks.get(url)
+ val dialect = JdbcDialects.get(url)
var nullTypes: Array[Int] = df.schema.fields.map(field => {
- var nullType: Option[Int] = quirks.getJDBCType(field.dataType)._2
+ var nullType: Option[Int] = dialect.getJDBCType(field.dataType)._2
if (nullType.isEmpty) {
field.dataType match {
case IntegerType => java.sql.Types.INTEGER
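To make the fallback in schemaString and saveTable concrete, a short sketch using the hypothetical ExampleDialect from the note above; the "TEXT" default for StringType is assumed to sit in the elided branch of the match on field.dataType.

import org.apache.spark.sql.types.StringType

// The dialect is consulted first; only when it returns the neutral element
// (a null type name) does the hard-coded default mapping apply.
val fromDialect: String = ExampleDialect.getJDBCType(StringType)._1
val columnType: String = if (fromDialect != null) fromDialect else "TEXT"
// columnType is "VARCHAR(255)" here; it would fall back to "TEXT" had the
// dialect returned (null, None) for StringType.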