Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions parser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,7 @@ func (u *UUID) Accept(visitor ASTVisitor) error {
type CreateDatabase struct {
CreatePos Pos // position of CREATE keyword
StatementEnd Pos
Name *Ident
Name Expr
IfNotExists bool // true if 'IF NOT EXISTS' is specified
OnCluster *OnClusterExpr
Engine *EngineExpr
Expand Down Expand Up @@ -1749,19 +1749,19 @@ func (n *NotNullLiteral) Accept(visitor ASTVisitor) error {
}

type NestedIdentifier struct {
Ident *Ident
DotIdent *Ident
Ident Expr
DotIdent Expr
}

func (n *NestedIdentifier) Pos() Pos {
return n.Ident.NamePos
return n.Ident.Pos()
}

func (n *NestedIdentifier) End() Pos {
if n.DotIdent != nil {
return n.DotIdent.NameEnd
return n.DotIdent.End()
}
return n.Ident.NameEnd
return n.Ident.End()
}

func (n *NestedIdentifier) String(int) string {
Expand Down Expand Up @@ -1835,19 +1835,19 @@ func (c *ColumnIdentifier) Accept(visitor ASTVisitor) error {
}

type TableIdentifier struct {
Database *Ident
Table *Ident
Database Expr
Table Expr
}

func (t *TableIdentifier) Pos() Pos {
if t.Database != nil {
return t.Database.NamePos
return t.Database.Pos()
}
return t.Table.NamePos
return t.Table.Pos()
}

func (t *TableIdentifier) End() Pos {
return t.Table.NameEnd
return t.Table.End()
}

func (t *TableIdentifier) String(int) string {
Expand Down Expand Up @@ -1972,12 +1972,12 @@ func (t *TableArgListExpr) Accept(visitor ASTVisitor) error {
}

type TableFunctionExpr struct {
Name *Ident
Name Expr
Args *TableArgListExpr
}

func (t *TableFunctionExpr) Pos() Pos {
return t.Name.NamePos
return t.Name.Pos()
}

func (t *TableFunctionExpr) End() Pos {
Expand Down
10 changes: 5 additions & 5 deletions parser/parser_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ func (p *Parser) parseIdentOrStar() (*Ident, error) {
}
}

func (p *Parser) tryParseDotIdent() (*Ident, error) {
func (p *Parser) tryParseDotIdent(pos Pos) (Expr, error) {
if p.tryConsumeTokenKind(".") == nil {
return nil, nil // nolint
}
return p.parseIdent()
return p.parseIdentOrString(pos)
}

func (p *Parser) parseUUID() (*UUID, error) {
Expand Down Expand Up @@ -257,12 +257,12 @@ func (p *Parser) parseLiteral(pos Pos) (Literal, error) {
}
}

func (p *Parser) ParseNestedIdentifier(_ Pos) (*NestedIdentifier, error) {
ident, err := p.parseIdent()
func (p *Parser) ParseNestedIdentifier(pos Pos) (*NestedIdentifier, error) {
ident, err := p.parseIdentOrString(pos)
if err != nil {
return nil, err
}
dotIdent, err := p.tryParseDotIdent()
dotIdent, err := p.tryParseDotIdent(p.Pos())
if err != nil {
return nil, err
}
Expand Down
4 changes: 1 addition & 3 deletions parser/parser_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,7 @@ func (p *Parser) parseTableExpr(pos Pos) (*TableExpr, error) {
var expr Expr
var err error
switch {
case p.matchTokenKind(TokenString):
expr, err = p.parseString(p.Pos())
case p.matchTokenKind(TokenIdent):
case p.matchTokenKind(TokenString), p.matchTokenKind(TokenIdent):
// table name
tableIdentifier, err := p.parseTableIdentifier(p.Pos())
if err != nil {
Expand Down
14 changes: 7 additions & 7 deletions parser/parser_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (p *Parser) parseCreateDatabase(pos Pos) (*CreateDatabase, error) {
return nil, err
}
// parse database name
name, err := p.parseIdent()
name, err := p.parseIdentOrString(p.Pos())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -254,12 +254,12 @@ func (p *Parser) parseIdentOrFunction(_ Pos) (Expr, error) {
return ident, nil
}

func (p *Parser) parseTableIdentifier(_ Pos) (*TableIdentifier, error) {
ident, err := p.parseIdent()
func (p *Parser) parseTableIdentifier(pos Pos) (*TableIdentifier, error) {
ident, err := p.parseIdentOrString(pos)
if err != nil {
return nil, err
}
dotIdent, err := p.tryParseDotIdent()
dotIdent, err := p.tryParseDotIdent(p.Pos())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -306,13 +306,13 @@ func (p *Parser) parseTableSchemaExpr(pos Pos) (*TableSchemaExpr, error) {
switch {
case p.matchTokenKind("."):
// it's a database.table
dotIdent, err := p.tryParseDotIdent()
dotIdent, err := p.tryParseDotIdent(p.Pos())
if err != nil {
return nil, err
}
return &TableSchemaExpr{
SchemaPos: pos,
SchemaEnd: dotIdent.NameEnd,
SchemaEnd: dotIdent.End(),
AliasTable: &TableIdentifier{
Database: ident,
Table: dotIdent,
Expand Down Expand Up @@ -475,7 +475,7 @@ func (p *Parser) parseTableArgExpr(pos Pos) (Expr, error) {
switch {
// nest identifier
case p.matchTokenKind("."):
dotIdent, err := p.tryParseDotIdent()
dotIdent, err := p.tryParseDotIdent(p.Pos())
if err != nil {
return nil, err
}
Expand Down
11 changes: 11 additions & 0 deletions parser/testdata/query/format/select_with_literal_table_name.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- Origin SQL:
select table_name from "information_schema"."tables" limit 1;


-- Format SQL:

SELECT
table_name
FROM
'information_schema'.'tables'
LIMIT 1;
18 changes: 12 additions & 6 deletions parser/testdata/query/output/select_with_join_only.sql.golden.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@
"TableEnd": 17,
"Alias": null,
"Expr": {
"LiteralPos": 15,
"LiteralEnd": 17,
"Literal": "t1"
"Database": null,
"Table": {
"LiteralPos": 15,
"LiteralEnd": 17,
"Literal": "t1"
}
},
"HasFinal": false
},
Expand All @@ -37,9 +40,12 @@
"TableEnd": 27,
"Alias": null,
"Expr": {
"LiteralPos": 25,
"LiteralEnd": 27,
"Literal": "t2"
"Database": null,
"Table": {
"LiteralPos": 25,
"LiteralEnd": 27,
"Literal": "t2"
}
},
"HasFinal": false
},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
[
{
"SelectPos": 0,
"StatementEnd": 60,
"With": null,
"Top": null,
"SelectColumns": {
"ListPos": 7,
"ListEnd": 17,
"HasDistinct": false,
"Items": [
{
"Name": "table_name",
"Unquoted": false,
"NamePos": 7,
"NameEnd": 17
}
]
},
"From": {
"FromPos": 18,
"Expr": {
"TablePos": 24,
"TableEnd": 51,
"Alias": null,
"Expr": {
"Database": {
"LiteralPos": 24,
"LiteralEnd": 42,
"Literal": "information_schema"
},
"Table": {
"LiteralPos": 45,
"LiteralEnd": 51,
"Literal": "tables"
}
},
"HasFinal": false
}
},
"ArrayJoin": null,
"Window": null,
"Prewhere": null,
"Where": null,
"GroupBy": null,
"WithTotal": false,
"Having": null,
"OrderBy": null,
"LimitBy": null,
"Limit": {
"LimitPos": 53,
"Limit": {
"NumPos": 59,
"NumEnd": 60,
"Literal": "1",
"Base": 10
},
"Offset": null
},
"Settings": null,
"UnionAll": null,
"UnionDistinct": null,
"Except": null
}
]
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,12 @@
"TableEnd": 48,
"Alias": null,
"Expr": {
"LiteralPos": 45,
"LiteralEnd": 48,
"Literal": "abc"
"Database": null,
"Table": {
"LiteralPos": 45,
"LiteralEnd": 48,
"Literal": "abc"
}
},
"HasFinal": false
}
Expand Down
1 change: 1 addition & 0 deletions parser/testdata/query/select_with_literal_table_name.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
select table_name from "information_schema"."tables" limit 1;
8 changes: 4 additions & 4 deletions parser/visitor_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package parser

import (
"fmt"
"os"
"path/filepath"
"strings"
"testing"

"fmt"
"github.com/sebdah/goldie/v2"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -58,8 +58,8 @@ type simpleRewriteVisitor struct {
}

func (v *simpleRewriteVisitor) VisitTableIdentifier(expr *TableIdentifier) error {
if expr.Table.Name == "group_by_all" {
expr.Table.Name = "hack"
if expr.Table.String(0) == "group_by_all" {
expr.Table = &Ident{Name: "hack"}
}
return nil
}
Expand Down Expand Up @@ -95,7 +95,7 @@ type nestedRewriteVisitor struct {
}

func (v *nestedRewriteVisitor) VisitTableIdentifier(expr *TableIdentifier) error {
expr.Table.Name = fmt.Sprintf("table%d", len(v.stack))
expr.Table = &Ident{Name: fmt.Sprintf("table%d", len(v.stack))}
return nil
}

Expand Down