
[SPARK-5009] [SQL] Long keyword support in SQL Parsers #3926


Closed · wants to merge 1 commit
@@ -25,15 +25,42 @@ import scala.util.parsing.input.CharArrayReader.EofCh

import org.apache.spark.sql.catalyst.plans.logical._

private[sql] object KeywordNormalizer {
Contributor

This is kind of a nit, but since this is only used in AbstractSparkSQLParser and its subclasses I'd just make it a protected method to avoid the syntactic overhead of a whole separate object. I believe you are doing further refactoring so maybe that can be done in a followup.

Contributor Author

It's also used within SqlLexical.processIdent, but you're right, we'd better minimize its visibility. I will do that in #4015.
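
As a rough sketch of the suggestion above (not part of this patch), the normalization could instead live as a protected helper on the parser base class; the behavior is the same single toLowerCase call:

private[sql] abstract class AbstractSparkSQLParser
  extends StandardTokenParsers with PackratParsers {

  // Hypothetical alternative to the standalone KeywordNormalizer object:
  // same lowercasing, scoped to the parser hierarchy.
  protected def normalizeKeyword(str: String): String = str.toLowerCase()

  // ... rest of the parser unchanged ...
}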

def apply(str: String) = str.toLowerCase()
}

private[sql] abstract class AbstractSparkSQLParser
extends StandardTokenParsers with PackratParsers {

def apply(input: String): LogicalPlan = phrase(start)(new lexical.Scanner(input)) match {
case Success(plan, _) => plan
case failureOrError => sys.error(failureOrError.toString)
def apply(input: String): LogicalPlan = {
// Initialize the Keywords.
lexical.initialize(reservedWords)
phrase(start)(new lexical.Scanner(input)) match {
case Success(plan, _) => plan
case failureOrError => sys.error(failureOrError.toString)
}
}

protected case class Keyword(str: String)
protected case class Keyword(str: String) {
def normalize = KeywordNormalizer(str)
def parser: Parser[String] = normalize
}

protected implicit def asParser(k: Keyword): Parser[String] = k.parser

// By default, use reflection to find the reserved words defined in the subclass.
// NOTE: since the Keyword properties are defined by the subclass, we cannot call this
// method while the parent class is being instantiated, because the subclass instance
// has not been created yet.
protected lazy val reservedWords: Seq[String] =
this
.getClass
.getMethods
.filter(_.getReturnType == classOf[Keyword])
.map(_.invoke(this).asInstanceOf[Keyword].normalize)

// The keyword list starts out empty; it is filled in later via lexical.initialize().
override val lexical = new SqlLexical

protected def start: Parser[LogicalPlan]

@@ -52,18 +79,27 @@ private[sql] abstract class AbstractSparkSQLParser
}
}

class SqlLexical(val keywords: Seq[String]) extends StdLexical {
class SqlLexical extends StdLexical {
case class FloatLit(chars: String) extends Token {
override def toString = chars
}

reserved ++= keywords.flatMap(w => allCaseVersions(w))
/* This is a workaround to support setting the reserved words lazily. */
def initialize(keywords: Seq[String]): Unit = {
reserved.clear()
reserved ++= keywords
}

delimiters += (
"@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>"
)

protected override def processIdent(name: String) = {
val token = KeywordNormalizer(name)
if (reserved contains token) Keyword(token) else Identifier(name)
}

override lazy val token: Parser[Token] =
( identChar ~ (identChar | digit).* ^^
{ case first ~ rest => processIdent((first :: rest).mkString) }
@@ -94,14 +130,5 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical {
| '-' ~ '-' ~ chrExcept(EofCh, '\n').*
| '/' ~ '*' ~ failure("unclosed comment")
).*

/** Generate all variations of upper and lower case of a given string */
def allCaseVersions(s: String, prefix: String = ""): Stream[String] = {
if (s.isEmpty) {
Stream(prefix)
} else {
allCaseVersions(s.tail, prefix + s.head.toLower) #:::
allCaseVersions(s.tail, prefix + s.head.toUpper)
}
}
}

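The allCaseVersions helper deleted above registered every upper/lower-case variant of each keyword as a reserved token, which is 2^n variants for an n-letter keyword and is presumably what made long keywords impractical (SPARK-5009). The new path instead lowercases each keyword once and lowercases every scanned identifier in processIdent. A minimal sketch of the difference in plain Scala, outside of Spark:

// Old approach (deleted above): enumerate every case variant of a keyword.
def allCaseVersions(s: String, prefix: String = ""): Stream[String] =
  if (s.isEmpty) Stream(prefix)
  else allCaseVersions(s.tail, prefix + s.head.toLower) #:::
    allCaseVersions(s.tail, prefix + s.head.toUpper)

allCaseVersions("AND").size  // 8 variants for a 3-letter keyword
// allCaseVersions("THISISASUPERLONGKEYWORDTEST") would need 2^27 variants

// New approach: normalize once and compare the lowercase forms.
def normalize(s: String): String = s.toLowerCase()
normalize("exEcute") == normalize("EXECUTE")  // true, with no enumeration
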
@@ -36,9 +36,8 @@ import org.apache.spark.sql.types._
* for a SQL like language should checkout the HiveQL support in the sql/hive sub-project.
*/
class SqlParser extends AbstractSparkSQLParser {
protected implicit def asParser(k: Keyword): Parser[String] =
lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)

// Keyword is a convention shared with AbstractSparkSQLParser, which scans all of the `Keyword`
// properties of the class via reflection at runtime to construct the SqlLexical object
protected val ABS = Keyword("ABS")
protected val ALL = Keyword("ALL")
protected val AND = Keyword("AND")
@@ -107,16 +106,6 @@ class SqlParser extends AbstractSparkSQLParser {
protected val WHEN = Keyword("WHEN")
protected val WHERE = Keyword("WHERE")

// Use reflection to find the reserved words defined in this class.
protected val reservedWords =
this
.getClass
.getMethods
.filter(_.getReturnType == classOf[Keyword])
.map(_.invoke(this).asInstanceOf[Keyword].str)

override val lexical = new SqlLexical(reservedWords)

protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = {
exprs.zipWithIndex.map {
case (ne: NamedExpression, _) => ne
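
The reservedWords block deleted above could be an eager val because it lived in the subclass itself; once hoisted into AbstractSparkSQLParser it has to be lazy, because the subclass's Keyword vals are not yet assigned while the superclass constructor runs. A minimal illustration of that initialization-order pitfall in plain Scala (not Spark code):

case class Keyword(str: String)

abstract class Base {
  // Eager: evaluated during Base's constructor, before Sub's fields are assigned.
  val eagerWords: Seq[String] = collectKeywords()
  // Lazy: evaluated on first access, after Sub is fully constructed.
  lazy val lazyWords: Seq[String] = collectKeywords()

  private def collectKeywords(): Seq[String] =
    getClass.getMethods
      .filter(_.getReturnType == classOf[Keyword])
      .map(m => Option(m.invoke(this)).map(_.asInstanceOf[Keyword].str).getOrElse("<null>"))
      .toSeq
}

class Sub extends Base {
  val SELECT = Keyword("select")
}

object InitOrderDemo extends App {
  val sub = new Sub
  println(sub.eagerWords)  // List(<null>): SELECT was still null inside Base's constructor
  println(sub.lazyWords)   // List(select)
}
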
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.Command
import org.scalatest.FunSuite

private[sql] case class TestCommand(cmd: String) extends Command

private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser {
protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST")

override protected lazy val start: Parser[LogicalPlan] = set

private lazy val set: Parser[LogicalPlan] =
EXECUTE ~> ident ^^ {
case fileName => TestCommand(fileName)
}
}

private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser {
protected val EXECUTE = Keyword("EXECUTE")

override protected lazy val start: Parser[LogicalPlan] = set

private lazy val set: Parser[LogicalPlan] =
EXECUTE ~> ident ^^ {
case fileName => TestCommand(fileName)
}
}

class SqlParserSuite extends FunSuite {

test("test long keyword") {
val parser = new SuperLongKeywordTestParser
assert(TestCommand("NotRealCommand") === parser("ThisIsASuperLongKeyWordTest NotRealCommand"))
}

test("test case insensitive") {
val parser = new CaseInsensitiveTestParser
assert(TestCommand("NotRealCommand") === parser("EXECUTE NotRealCommand"))
assert(TestCommand("NotRealCommand") === parser("execute NotRealCommand"))
assert(TestCommand("NotRealCommand") === parser("exEcute NotRealCommand"))
}
}
@@ -105,7 +105,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
}

protected[sql] def parseSql(sql: String): LogicalPlan = {
ddlParser(sql).getOrElse(sqlParser(sql))
ddlParser(sql, false).getOrElse(sqlParser(sql))
}

protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql))
15 changes: 2 additions & 13 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala
@@ -17,9 +17,10 @@

package org.apache.spark.sql


import scala.util.parsing.combinator.RegexParsers

import org.apache.spark.sql.catalyst.{SqlLexical, AbstractSparkSQLParser}
import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{UncacheTableCommand, CacheTableCommand, SetCommand}
@@ -61,18 +62,6 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr
protected val TABLE = Keyword("TABLE")
protected val UNCACHE = Keyword("UNCACHE")

protected implicit def asParser(k: Keyword): Parser[String] =
lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)

private val reservedWords: Seq[String] =
this
.getClass
.getMethods
.filter(_.getReturnType == classOf[Keyword])
.map(_.invoke(this).asInstanceOf[Keyword].str)

override val lexical = new SqlLexical(reservedWords)

override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | others

private lazy val cache: Parser[LogicalPlan] =
39 changes: 15 additions & 24 deletions sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -18,44 +18,42 @@
package org.apache.spark.sql.sources

import scala.language.implicitConversions
import scala.util.parsing.combinator.syntactical.StandardTokenParsers
import scala.util.parsing.combinator.PackratParsers

import org.apache.spark.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.SqlLexical
import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils


/**
* A parser for foreign DDL commands.
*/
private[sql] class DDLParser extends StandardTokenParsers with PackratParsers with Logging {

def apply(input: String): Option[LogicalPlan] = {
phrase(ddl)(new lexical.Scanner(input)) match {
case Success(r, x) => Some(r)
case x =>
logDebug(s"Not recognized as DDL: $x")
None
private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {

def apply(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = {
try {
Some(apply(input))
} catch {
case _ if !exceptionOnError => None
case x: Throwable => throw x
}
}

def parseType(input: String): DataType = {
lexical.initialize(reservedWords)
phrase(dataType)(new lexical.Scanner(input)) match {
case Success(r, x) => r
case x =>
sys.error(s"Unsupported dataType: $x")
}
}

protected case class Keyword(str: String)

protected implicit def asParser(k: Keyword): Parser[String] =
lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)

// Keyword is a convention shared with AbstractSparkSQLParser, which scans all of the `Keyword`
// properties of the class via reflection at runtime to construct the SqlLexical object
protected val CREATE = Keyword("CREATE")
protected val TEMPORARY = Keyword("TEMPORARY")
protected val TABLE = Keyword("TABLE")
@@ -80,17 +78,10 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
protected val MAP = Keyword("MAP")
protected val STRUCT = Keyword("STRUCT")

// Use reflection to find the reserved words defined in this class.
protected val reservedWords =
this.getClass
.getMethods
.filter(_.getReturnType == classOf[Keyword])
.map(_.invoke(this).asInstanceOf[Keyword].str)

override val lexical = new SqlLexical(reservedWords)

protected lazy val ddl: Parser[LogicalPlan] = createTable

protected def start: Parser[LogicalPlan] = ddl

/**
* `CREATE [TEMPORARY] TABLE avroTable
* USING org.apache.spark.sql.avro
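
With the new two-argument apply, the caller decides whether a failed DDL parse surfaces as an exception or simply as None. A hedged usage sketch (assumed to run inside the org.apache.spark.sql package, since DDLParser has restricted visibility; imports omitted):

val ddlParser = new DDLParser
val fallback = new SqlParser
val sql = "SELECT key FROM src"

// exceptionOnError = false: input that does not parse as DDL yields None instead of
// throwing, so the caller can fall back to the regular SQL parser, mirroring
// SQLContext.parseSql above.
val plan = ddlParser(sql, exceptionOnError = false).getOrElse(fallback(sql))

// With exceptionOnError = true, the underlying parse error would propagate instead.
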
@@ -20,30 +20,20 @@ package org.apache.spark.sql.hive
import scala.language.implicitConversions

import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, SqlLexical}
import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.hive.execution.{AddJar, AddFile, HiveNativeCommand}

/**
* A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions.
*/
private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser {
protected implicit def asParser(k: Keyword): Parser[String] =
lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)

// Keyword is a convention shared with AbstractSparkSQLParser, which scans all of the `Keyword`
// properties of the class via reflection at runtime to construct the SqlLexical object
protected val ADD = Keyword("ADD")
protected val DFS = Keyword("DFS")
protected val FILE = Keyword("FILE")
protected val JAR = Keyword("JAR")

private val reservedWords =
this
.getClass
.getMethods
.filter(_.getReturnType == classOf[Keyword])
.map(_.invoke(this).asInstanceOf[Keyword].str)

override val lexical = new SqlLexical(reservedWords)

protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl

protected lazy val hiveQl: Parser[LogicalPlan] =
@@ -91,7 +91,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
if (conf.dialect == "sql") {
super.sql(sqlText)
} else if (conf.dialect == "hiveql") {
new SchemaRDD(this, ddlParser(sqlText).getOrElse(HiveQl.parseSql(sqlText)))
new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(sqlText)))
} else {
sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'")
}