[SPARK-42552][SQL] Correct the two-stage parsing strategy of antlr parser

### What changes were proposed in this pull request?

This PR follows antlr/antlr4#192 (comment) to correct the current implementation of the **two-stage parsing strategy** in `AbstractSqlParser`.

### Why are the changes needed?

This is a long-standing issue: before [SPARK-38385](https://issues.apache.org/jira/browse/SPARK-38385), Spark used `DefaultErrorStrategy`, and after [SPARK-38385](https://issues.apache.org/jira/browse/SPARK-38385) it uses `SparkParserErrorStrategy() extends DefaultErrorStrategy`. Neither is a correct implementation of the "two-stage parsing strategy".

As mentioned in antlr/antlr4#192 (comment) (a minimal sketch of the strategy follows the quote):

> You can save a great deal of time on correct inputs by using a two-stage parsing strategy.
>
> 1. Attempt to parse the input using BailErrorStrategy and PredictionMode.SLL.
>    If no exception is thrown, you know the answer is correct.
> 2. If a ParseCancellationException is thrown, retry the parse using the default
>    settings (DefaultErrorStrategy and PredictionMode.LL).
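
To illustrate, here is a minimal Scala sketch of that strategy against a hypothetical ANTLR-generated `MyParser` with a `statement()` entry rule (placeholder names, not Spark's actual classes; Spark's version in the diff below swaps in `SparkParserBailErrorStrategy` and `SparkParserErrorStrategy`):

```scala
import org.antlr.v4.runtime.{BailErrorStrategy, CommonTokenStream, DefaultErrorStrategy}
import org.antlr.v4.runtime.atn.PredictionMode
import org.antlr.v4.runtime.misc.ParseCancellationException
import org.antlr.v4.runtime.tree.ParseTree

def parseTwoStage(parser: MyParser, tokens: CommonTokenStream): ParseTree = {
  try {
    // Stage 1: fast SLL prediction; BailErrorStrategy throws
    // ParseCancellationException at the first syntax error.
    parser.setErrorHandler(new BailErrorStrategy())
    parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
    parser.statement()
  } catch {
    case _: ParseCancellationException =>
      // Stage 2: rewind and re-parse with full LL prediction and
      // normal error recovery/reporting.
      tokens.seek(0)
      parser.reset()
      parser.setErrorHandler(new DefaultErrorStrategy())
      parser.getInterpreter.setPredictionMode(PredictionMode.LL)
      parser.statement()
  }
}
```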

### Does this PR introduce _any_ user-facing change?

Yes, the Spark SQL parser becomes more powerful: SQL such as `SELECT 1 UNION SELECT 2` parses successfully after this change.
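
For instance, a quick check in `spark-shell` (a hypothetical session; before this change the statement failed with `PARSE_SYNTAX_ERROR`):

```scala
// Now parses as a Distinct over a Union of the two projections.
spark.sql("SELECT 1 UNION SELECT 2").show()
```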

### How was this patch tested?

A new unit test is added.

Closes apache#40835 from pan3793/SPARK-42552.

Authored-by: Cheng Pan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
pan3793 authored and cloud-fan committed Apr 19, 2023
1 parent 56f6af7 commit f8604ad
Showing 11 changed files with 193 additions and 176 deletions.
@@ -114,25 +114,28 @@ abstract class AbstractSqlParser extends ParserInterface with SQLConfHelper with
     parser.addParseListener(UnclosedCommentProcessor(command, tokenStream))
     parser.removeErrorListeners()
     parser.addErrorListener(ParseErrorListener)
-    parser.setErrorHandler(new SparkParserErrorStrategy())
     parser.legacy_setops_precedence_enabled = conf.setOpsPrecedenceEnforced
     parser.legacy_exponent_literal_as_decimal_enabled = conf.exponentLiteralAsDecimalEnabled
     parser.SQL_standard_keyword_behavior = conf.enforceReservedKeywords
     parser.double_quoted_identifiers = conf.doubleQuotedIdentifiers

+    // https://github.com/antlr/antlr4/issues/192#issuecomment-15238595
+    // Save a great deal of time on correct inputs by using a two-stage parsing strategy.
     try {
       try {
-        // first, try parsing with potentially faster SLL mode
+        // first, try parsing with potentially faster SLL mode w/ SparkParserBailErrorStrategy
+        parser.setErrorHandler(new SparkParserBailErrorStrategy())
         parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
         toResult(parser)
       }
       catch {
         case e: ParseCancellationException =>
-          // if we fail, parse with LL mode
+          // if we fail, parse with LL mode w/ SparkParserErrorStrategy
           tokenStream.seek(0) // rewind input stream
           parser.reset()

           // Try Again.
+          parser.setErrorHandler(new SparkParserErrorStrategy())
           parser.getInterpreter.setPredictionMode(PredictionMode.LL)
           toResult(parser)
       }
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.parser

 import org.antlr.v4.runtime.{DefaultErrorStrategy, InputMismatchException, IntStream, NoViableAltException, Parser, ParserRuleContext, RecognitionException, Recognizer, Token}
+import org.antlr.v4.runtime.misc.ParseCancellationException

 /**
  * A [[SparkRecognitionException]] extends the [[RecognitionException]] with more information
@@ -112,3 +113,46 @@ class SparkParserErrorStrategy() extends DefaultErrorStrategy {
     }
   }
 }
+
+/**
+ * Inspired by [[org.antlr.v4.runtime.BailErrorStrategy]], which is used in two-stage parsing:
+ * This error strategy allows the first stage of two-stage parsing to immediately terminate
+ * if an error is encountered, and immediately fall back to the second stage. In addition to
+ * avoiding wasted work by attempting to recover from errors here, the empty implementation
+ * of sync improves the performance of the first stage.
+ */
+class SparkParserBailErrorStrategy() extends SparkParserErrorStrategy {
+
+  /**
+   * Instead of recovering from exception e, re-throw it wrapped
+   * in a [[ParseCancellationException]] so it is not caught by the
+   * rule function catches. Use [[Exception#getCause]] to get the
+   * original [[RecognitionException]].
+   */
+  override def recover(recognizer: Parser, e: RecognitionException): Unit = {
+    var context = recognizer.getContext
+    while (context != null) {
+      context.exception = e
+      context = context.getParent
+    }
+    throw new ParseCancellationException(e)
+  }
+
+  /**
+   * Make sure we don't attempt to recover inline; if the parser
+   * successfully recovers, it won't throw an exception.
+   */
+  @throws[RecognitionException]
+  override def recoverInline(recognizer: Parser): Token = {
+    val e = new InputMismatchException(recognizer)
+    var context = recognizer.getContext
+    while (context != null) {
+      context.exception = e
+      context = context.getParent
+    }
+    throw new ParseCancellationException(e)
+  }
+
+  /** Make sure we don't attempt to recover from problems in subrules. */
+  override def sync(recognizer: Parser): Unit = {}
+}
@@ -66,7 +66,7 @@ class DDLParserSuite extends AnalysisTest {
     checkError(
       exception = parseException(sql),
       errorClass = "PARSE_SYNTAX_ERROR",
-      parameters = Map("error" -> "':'", "hint" -> ": extra input ':'"))
+      parameters = Map("error" -> "':'", "hint" -> ""))
   }

   test("create/replace table - with IF NOT EXISTS") {
@@ -725,7 +725,7 @@ class ExpressionParserSuite extends AnalysisTest {
     checkError(
       exception = parseException(".e3"),
      errorClass = "PARSE_SYNTAX_ERROR",
-      parameters = Map("error" -> "'.'", "hint" -> ": extra input '.'"))
+      parameters = Map("error" -> "'.'", "hint" -> ""))

     // Tiny Int Literal
     assertEqual("10Y", Literal(10.toByte))
@@ -244,6 +244,12 @@ class PlanParserSuite extends AnalysisTest {
         stop = 25))
   }

+  test("SPARK-42552: select and union without parentheses") {
+    val plan = Distinct(OneRowRelation().select(Literal(1))
+      .union(OneRowRelation().select(Literal(1))))
+    assertEqual("select 1 union select 1", plan)
+  }
+
   test("set operations") {
     val a = table("a").select(star())
     val b = table("b").select(star())
@@ -87,7 +87,7 @@ class TableSchemaParserSuite extends SparkFunSuite {
     checkError(
       exception = parseException("a INT,, b long"),
       errorClass = "PARSE_SYNTAX_ERROR",
-      parameters = Map("error" -> "','", "hint" -> ": extra input ','"))
+      parameters = Map("error" -> "','", "hint" -> ""))
     checkError(
       exception = parseException("a INT, b long,,"),
       errorClass = "PARSE_SYNTAX_ERROR",
@@ -100,29 +100,33 @@ Union false, false
 -- !query
 SELECT 1 AS three UNION SELECT 2 UNION SELECT 3 ORDER BY 1
 -- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
-{
-  "errorClass" : "PARSE_SYNTAX_ERROR",
-  "sqlState" : "42601",
-  "messageParameters" : {
-    "error" : "'SELECT'",
-    "hint" : ""
-  }
-}
+Sort [three#x ASC NULLS FIRST], true
++- Distinct
+   +- Union false, false
+      :- Distinct
+      :  +- Union false, false
+      :     :- Project [1 AS three#x]
+      :     :  +- OneRowRelation
+      :     +- Project [2 AS 2#x]
+      :        +- OneRowRelation
+      +- Project [3 AS 3#x]
+         +- OneRowRelation


 -- !query
 SELECT 1 AS two UNION SELECT 2 UNION SELECT 2 ORDER BY 1
 -- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
-{
-  "errorClass" : "PARSE_SYNTAX_ERROR",
-  "sqlState" : "42601",
-  "messageParameters" : {
-    "error" : "'SELECT'",
-    "hint" : ""
-  }
-}
+Sort [two#x ASC NULLS FIRST], true
++- Distinct
+   +- Union false, false
+      :- Distinct
+      :  +- Union false, false
+      :     :- Project [1 AS two#x]
+      :     :  +- OneRowRelation
+      :     +- Project [2 AS 2#x]
+      :        +- OneRowRelation
+      +- Project [2 AS 2#x]
+         +- OneRowRelation


 -- !query
@@ -221,29 +225,37 @@ Sort [two#x ASC NULLS FIRST], true
 -- !query
 SELECT 1.1 AS three UNION SELECT 2 UNION SELECT 3 ORDER BY 1
 -- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
-{
-  "errorClass" : "PARSE_SYNTAX_ERROR",
-  "sqlState" : "42601",
-  "messageParameters" : {
-    "error" : "'SELECT'",
-    "hint" : ""
-  }
-}
+Sort [three#x ASC NULLS FIRST], true
++- Distinct
+   +- Union false, false
+      :- Distinct
+      :  +- Union false, false
+      :     :- Project [cast(three#x as decimal(11,1)) AS three#x]
+      :     :  +- Project [1.1 AS three#x]
+      :     :     +- OneRowRelation
+      :     +- Project [cast(2#x as decimal(11,1)) AS 2#x]
+      :        +- Project [2 AS 2#x]
+      :           +- OneRowRelation
+      +- Project [cast(3#x as decimal(11,1)) AS 3#x]
+         +- Project [3 AS 3#x]
+            +- OneRowRelation


 -- !query
 SELECT double(1.1) AS two UNION SELECT 2 UNION SELECT double(2.0) ORDER BY 1
 -- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
-{
-  "errorClass" : "PARSE_SYNTAX_ERROR",
-  "sqlState" : "42601",
-  "messageParameters" : {
-    "error" : "'SELECT'",
-    "hint" : ""
-  }
-}
+Sort [two#x ASC NULLS FIRST], true
++- Distinct
+   +- Union false, false
+      :- Distinct
+      :  +- Union false, false
+      :     :- Project [cast(1.1 as double) AS two#x]
+      :     :  +- OneRowRelation
+      :     +- Project [cast(2#x as double) AS 2#x]
+      :        +- Project [2 AS 2#x]
+      :           +- OneRowRelation
+      +- Project [cast(2.0 as double) AS 2.0#x]
+         +- OneRowRelation


 -- !query
@@ -606,57 +618,59 @@ Sort [q1#xL ASC NULLS FIRST], true
 -- !query
 (SELECT 1,2,3 UNION SELECT 4,5,6) INTERSECT SELECT 4,5,6
 -- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
-{
-  "errorClass" : "PARSE_SYNTAX_ERROR",
-  "sqlState" : "42601",
-  "messageParameters" : {
-    "error" : "'SELECT'",
-    "hint" : ""
-  }
-}
+Intersect false
+:- Distinct
+:  +- Union false, false
+:     :- Project [1 AS 1#x, 2 AS 2#x, 3 AS 3#x]
+:     :  +- OneRowRelation
+:     +- Project [4 AS 4#x, 5 AS 5#x, 6 AS 6#x]
+:        +- OneRowRelation
++- Project [4 AS 4#x, 5 AS 5#x, 6 AS 6#x]
+   +- OneRowRelation


 -- !query
 (SELECT 1,2,3 UNION SELECT 4,5,6 ORDER BY 1,2) INTERSECT SELECT 4,5,6
 -- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
-{
-  "errorClass" : "PARSE_SYNTAX_ERROR",
-  "sqlState" : "42601",
-  "messageParameters" : {
-    "error" : "'SELECT'",
-    "hint" : ""
-  }
-}
+Intersect false
+:- Sort [1#x ASC NULLS FIRST, 2#x ASC NULLS FIRST], true
+:  +- Distinct
+:     +- Union false, false
+:        :- Project [1 AS 1#x, 2 AS 2#x, 3 AS 3#x]
+:        :  +- OneRowRelation
+:        +- Project [4 AS 4#x, 5 AS 5#x, 6 AS 6#x]
+:           +- OneRowRelation
++- Project [4 AS 4#x, 5 AS 5#x, 6 AS 6#x]
+   +- OneRowRelation


 -- !query
 (SELECT 1,2,3 UNION SELECT 4,5,6) EXCEPT SELECT 4,5,6
 -- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
-{
-  "errorClass" : "PARSE_SYNTAX_ERROR",
-  "sqlState" : "42601",
-  "messageParameters" : {
-    "error" : "'SELECT'",
-    "hint" : ""
-  }
-}
+Except false
+:- Distinct
+:  +- Union false, false
+:     :- Project [1 AS 1#x, 2 AS 2#x, 3 AS 3#x]
+:     :  +- OneRowRelation
+:     +- Project [4 AS 4#x, 5 AS 5#x, 6 AS 6#x]
+:        +- OneRowRelation
++- Project [4 AS 4#x, 5 AS 5#x, 6 AS 6#x]
+   +- OneRowRelation


 -- !query
 (SELECT 1,2,3 UNION SELECT 4,5,6 ORDER BY 1,2) EXCEPT SELECT 4,5,6
 -- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
-{
-  "errorClass" : "PARSE_SYNTAX_ERROR",
-  "sqlState" : "42601",
-  "messageParameters" : {
-    "error" : "'SELECT'",
-    "hint" : ""
-  }
-}
+Except false
+:- Sort [1#x ASC NULLS FIRST, 2#x ASC NULLS FIRST], true
+:  +- Distinct
+:     +- Union false, false
+:        :- Project [1 AS 1#x, 2 AS 2#x, 3 AS 3#x]
+:        :  +- OneRowRelation
+:        +- Project [4 AS 4#x, 5 AS 5#x, 6 AS 6#x]
+:           +- OneRowRelation
++- Project [4 AS 4#x, 5 AS 5#x, 6 AS 6#x]
+   +- OneRowRelation


 -- !query
@@ -1164,15 +1178,14 @@ Except All true
 -- !query
 SELECT cast('3.4' as decimal(38, 18)) UNION SELECT 'foo'
 -- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
-{
-  "errorClass" : "PARSE_SYNTAX_ERROR",
-  "sqlState" : "42601",
-  "messageParameters" : {
-    "error" : "'SELECT'",
-    "hint" : ""
-  }
-}
+Distinct
++- Union false, false
+   :- Project [cast(CAST(3.4 AS DECIMAL(38,18))#x as double) AS CAST(3.4 AS DECIMAL(38,18))#x]
+   :  +- Project [cast(3.4 as decimal(38,18)) AS CAST(3.4 AS DECIMAL(38,18))#x]
+   :     +- OneRowRelation
+   +- Project [cast(foo#x as double) AS foo#x]
+      +- Project [foo AS foo#x]
+         +- OneRowRelation


 -- !query