Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

sql/analyzer: resolve correctly GetField expressions indexes #181

Merged
merged 1 commit into from
May 22, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,20 @@ var queries = []struct {
`SELECT i AS foo FROM mytable WHERE foo NOT IN (1, 2, 5)`,
[]sql.Row{{int64(3)}},
},
{
`SELECT * FROM tabletest, mytable mt INNER JOIN othertable ot ON mt.i = ot.i2`,
[]sql.Row{
{"a", int32(1), int64(1), "first row", "third", int64(1)},
{"a", int32(1), int64(2), "second row", "second", int64(2)},
{"a", int32(1), int64(3), "third row", "first", int64(3)},
{"b", int32(2), int64(1), "first row", "third", int64(1)},
{"b", int32(2), int64(2), "second row", "second", int64(2)},
{"b", int32(2), int64(3), "third row", "first", int64(3)},
{"c", int32(3), int64(1), "first row", "third", int64(1)},
{"c", int32(3), int64(2), "second row", "second", int64(2)},
{"c", int32(3), int64(3), "third row", "first", int64(3)},
},
},
}

func TestQueries(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions sql/analyzer/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,9 @@ func TestAddRule(t *testing.T) {
require := require.New(t)

a := New(nil)
require.Len(a.Rules, 13)
require.Len(a.Rules, len(DefaultRules))
a.AddRule("foo", pushdown)
require.Len(a.Rules, 14)
require.Len(a.Rules, len(DefaultRules)+1)
}

func TestAddValidationRule(t *testing.T) {
Expand Down
56 changes: 35 additions & 21 deletions sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,11 +308,6 @@ func resolveStar(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
})
}

type columnInfo struct {
idx int
col *sql.Column
}

// maybeAlias is a wrapper on UnresolvedColumn used only to defer the
// resolution of the column because it could be an alias and that
// phase of the analyzer has not run yet.
Expand Down Expand Up @@ -342,20 +337,23 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
return n, nil
}

colMap := make(map[string][]columnInfo)
idx := 0
colMap := make(map[string][]*sql.Column)
for _, child := range n.Children() {
if !child.Resolved() {
return n, nil
}

for _, col := range child.Schema() {
colMap[col.Name] = append(colMap[col.Name], columnInfo{idx, col})
idx++
colMap[col.Name] = append(colMap[col.Name], col)
}
}

return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) {
expressioner, ok := n.(sql.Expressioner)
if !ok {
return n, nil
}

return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
a.Log("transforming expression of type: %T", e)
if n.Resolved() {
return e, nil
Expand All @@ -366,7 +364,7 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
return e, nil
}

columnsInfo, ok := colMap[uc.Name()]
columns, ok := colMap[uc.Name()]
if !ok {
if uc.Table() != "" {
return nil, ErrColumnTableNotFound.New(uc.Table(), uc.Name())
Expand All @@ -380,11 +378,11 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
}
}

var ci columnInfo
var col *sql.Column
var found bool
for _, c := range columnsInfo {
if c.col.Source == uc.Table() {
ci = c
for _, c := range columns {
if c.Source == uc.Table() {
col = c
found = true
break
}
Expand All @@ -403,14 +401,30 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
}
}

a.Log("column resolved to %q.%q", ci.col.Source, ci.col.Name)
var schema sql.Schema
switch n := n.(type) {
// If expressioner and unary node we must take the
// child's schema to correctly select the indexes
// in the row is going to be evaluated in this node
case *plan.Project, *plan.Filter, *plan.GroupBy, *plan.Sort:
schema = n.Children()[0].Schema()
default:
schema = n.Schema()
}

idx := schema.IndexOf(col.Name, col.Source)
if idx < 0 {
return nil, ErrColumnNotFound.New(col.Name)
}

a.Log("column resolved to %q.%q", col.Source, col.Name)

return expression.NewGetFieldWithTable(
ci.idx,
ci.col.Type,
ci.col.Source,
ci.col.Name,
ci.col.Nullable,
idx,
col.Type,
col.Source,
col.Name,
col.Nullable,
), nil
})
})
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/validation_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func validateIndexCreation(ctx *sql.Context, n sql.Node) error {
expression.Inspect(expr, func(e sql.Expression) bool {
gf, ok := e.(*expression.GetField)
if ok {
if gf.Table() != table || !schema.Contains(gf.Name()) {
if gf.Table() != table || !schema.Contains(gf.Name(), gf.Table()) {
unknownColumns = append(unknownColumns, gf.Name())
}
}
Expand Down
4 changes: 4 additions & 0 deletions sql/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ type Node interface {
type Expressioner interface {
// Expressions returns the list of expressions contained by the node.
Expressions() []Expression
// TransformExpressions applies for each expression in this node
// the expression's TransformUp method with the given function, and
// return a new node with the transformed expressions.
TransformExpressions(TransformExprFunc) (Node, error)
}

// Table represents a SQL table.
Expand Down
19 changes: 11 additions & 8 deletions sql/expression/get_field.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,37 +32,40 @@ func NewGetFieldWithTable(index int, fieldType sql.Type, table, fieldName string
}
}

// Index returns the index where the GetField will look for the value from a sql.Row.
func (p *GetField) Index() int { return p.fieldIndex }

// Children implements the Expression interface.
func (GetField) Children() []sql.Expression {
func (*GetField) Children() []sql.Expression {
return nil
}

// Table returns the name of the field table.
func (p GetField) Table() string { return p.table }
func (p *GetField) Table() string { return p.table }

// Resolved implements the Expression interface.
func (p GetField) Resolved() bool {
func (p *GetField) Resolved() bool {
return true
}

// Name implements the Nameable interface.
func (p GetField) Name() string { return p.name }
func (p *GetField) Name() string { return p.name }

// IsNullable returns whether the field is nullable or not.
func (p GetField) IsNullable() bool {
func (p *GetField) IsNullable() bool {
return p.nullable
}

// Type returns the type of the field.
func (p GetField) Type() sql.Type {
func (p *GetField) Type() sql.Type {
return p.fieldType
}

// ErrIndexOutOfBounds is returned when the field index is out of the bounds.
var ErrIndexOutOfBounds = errors.NewKind("unable to find field with index %d in row of %d columns")

// Eval implements the Expression interface.
func (p GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
func (p *GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
if p.fieldIndex < 0 || p.fieldIndex >= len(row) {
return nil, ErrIndexOutOfBounds.New(p.fieldIndex, len(row))
}
Expand All @@ -75,7 +78,7 @@ func (p *GetField) TransformUp(f sql.TransformExprFunc) (sql.Expression, error)
return f(&n)
}

func (p GetField) String() string {
func (p *GetField) String() string {
if p.table == "" {
return p.name
}
Expand Down
18 changes: 15 additions & 3 deletions sql/plan/filter.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package plan

import "gopkg.in/src-d/go-mysql-server.v0/sql"
import (
"gopkg.in/src-d/go-mysql-server.v0/sql"
)

// Filter skips rows that don't match a certain expression.
type Filter struct {
Expand Down Expand Up @@ -58,18 +60,28 @@ func (p *Filter) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, erro
return NewFilter(expr, child), nil
}

func (p Filter) String() string {
func (p *Filter) String() string {
pr := sql.NewTreePrinter()
_ = pr.WriteNode("Filter(%s)", p.Expression)
_ = pr.WriteChildren(p.Child.String())
return pr.String()
}

// Expressions implements the Expressioner interface.
func (p Filter) Expressions() []sql.Expression {
func (p *Filter) Expressions() []sql.Expression {
return []sql.Expression{p.Expression}
}

// TransformExpressions implements the Expressioner interface.
func (p *Filter) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) {
e, err := p.Expression.TransformUp(f)
if err != nil {
return nil, err
}

return NewFilter(e, p.Child), nil
}

// FilterIter is an iterator that filters another iterator and skips rows that
// don't match the given condition.
type FilterIter struct {
Expand Down
19 changes: 17 additions & 2 deletions sql/plan/group_by.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (p *GroupBy) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, err
return NewGroupBy(aggregate, grouping, child), nil
}

func (p GroupBy) String() string {
func (p *GroupBy) String() string {
pr := sql.NewTreePrinter()
_ = pr.WriteNode("GroupBy")

Expand All @@ -130,13 +130,28 @@ func (p GroupBy) String() string {
}

// Expressions implements the Expressioner interface.
func (p GroupBy) Expressions() []sql.Expression {
func (p *GroupBy) Expressions() []sql.Expression {
var exprs []sql.Expression
exprs = append(exprs, p.Aggregate...)
exprs = append(exprs, p.Grouping...)
return exprs
}

// TransformExpressions implements the Expressioner interface.
func (p *GroupBy) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) {
agg, err := transformExpressionsUp(f, p.Aggregate)
if err != nil {
return nil, err
}

group, err := transformExpressionsUp(f, p.Grouping)
if err != nil {
return nil, err
}

return NewGroupBy(agg, group, p.Child), nil
}

type groupByIter struct {
p *GroupBy
childIter sql.RowIter
Expand Down
14 changes: 12 additions & 2 deletions sql/plan/innerjoin.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,24 @@ func (j *InnerJoin) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, e
return NewInnerJoin(left, right, cond), nil
}

func (j InnerJoin) String() string {
func (j *InnerJoin) String() string {
pr := sql.NewTreePrinter()
_ = pr.WriteNode("InnerJoin(%s)", j.Cond)
_ = pr.WriteChildren(j.Left.String(), j.Right.String())
return pr.String()
}

// Expressions implements the Expressioner interface.
func (j InnerJoin) Expressions() []sql.Expression {
func (j *InnerJoin) Expressions() []sql.Expression {
return []sql.Expression{j.Cond}
}

// TransformExpressions implements the Expressioner interface.
func (j *InnerJoin) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) {
cond, err := j.Cond.TransformUp(f)
if err != nil {
return nil, err
}

return NewInnerJoin(j.Left, j.Right, cond), nil
}
14 changes: 12 additions & 2 deletions sql/plan/project.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (p *Project) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, err
return NewProject(exprs, child), nil
}

func (p Project) String() string {
func (p *Project) String() string {
pr := sql.NewTreePrinter()
var exprs = make([]string, len(p.Projections))
for i, expr := range p.Projections {
Expand All @@ -105,10 +105,20 @@ func (p Project) String() string {
}

// Expressions implements the Expressioner interface.
func (p Project) Expressions() []sql.Expression {
func (p *Project) Expressions() []sql.Expression {
return p.Projections
}

// TransformExpressions implements the Expressioner interface.
func (p *Project) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) {
projects, err := transformExpressionsUp(f, p.Projections)
if err != nil {
return nil, err
}

return NewProject(projects, p.Child), nil
}

type iter struct {
p *Project
childIter sql.RowIter
Expand Down
23 changes: 21 additions & 2 deletions sql/plan/pushdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (t *PushdownProjectionAndFiltersTable) RowIter(ctx *sql.Context) (sql.RowIt
return sql.NewSpanIter(span, iter), nil
}

func (t PushdownProjectionAndFiltersTable) String() string {
func (t *PushdownProjectionAndFiltersTable) String() string {
pr := sql.NewTreePrinter()
_ = pr.WriteNode("PushdownProjectionAndFiltersTable")

Expand All @@ -161,9 +161,28 @@ func (t PushdownProjectionAndFiltersTable) String() string {
}

// Expressions implements the Expressioner interface.
func (t PushdownProjectionAndFiltersTable) Expressions() []sql.Expression {
func (t *PushdownProjectionAndFiltersTable) Expressions() []sql.Expression {
var exprs []sql.Expression
exprs = append(exprs, t.Columns...)
exprs = append(exprs, t.Filters...)
return exprs
}

// TransformExpressions implements the Expressioner interface.
func (t *PushdownProjectionAndFiltersTable) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) {
cols, err := transformExpressionsUp(f, t.Columns)
if err != nil {
return nil, err
}

filters, err := transformExpressionsUp(f, t.Filters)
if err != nil {
return nil, err
}

return NewPushdownProjectionAndFiltersTable(
cols,
filters,
t.PushdownProjectionAndFiltersTable,
), nil
}
Loading