Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 6dacde8

Browse files
authored
Merge pull request #181 from mcarmonaa/fix/rows-indexes-in-getfield
sql/analyzer: resolve correctly GetField expressions indexes
2 parents 16849ea + cfc5c19 commit 6dacde8

File tree

13 files changed

+190
-50
lines changed

13 files changed

+190
-50
lines changed

engine_test.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,20 @@ var queries = []struct {
196196
`SELECT i AS foo FROM mytable WHERE foo NOT IN (1, 2, 5)`,
197197
[]sql.Row{{int64(3)}},
198198
},
199+
{
200+
`SELECT * FROM tabletest, mytable mt INNER JOIN othertable ot ON mt.i = ot.i2`,
201+
[]sql.Row{
202+
{"a", int32(1), int64(1), "first row", "third", int64(1)},
203+
{"a", int32(1), int64(2), "second row", "second", int64(2)},
204+
{"a", int32(1), int64(3), "third row", "first", int64(3)},
205+
{"b", int32(2), int64(1), "first row", "third", int64(1)},
206+
{"b", int32(2), int64(2), "second row", "second", int64(2)},
207+
{"b", int32(2), int64(3), "third row", "first", int64(3)},
208+
{"c", int32(3), int64(1), "first row", "third", int64(1)},
209+
{"c", int32(3), int64(2), "second row", "second", int64(2)},
210+
{"c", int32(3), int64(3), "third row", "first", int64(3)},
211+
},
212+
},
199213
}
200214

201215
func TestQueries(t *testing.T) {

sql/analyzer/rules.go

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -410,11 +410,6 @@ func resolveStar(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
410410
})
411411
}
412412

413-
type columnInfo struct {
414-
idx int
415-
col *sql.Column
416-
}
417-
418413
// maybeAlias is a wrapper on UnresolvedColumn used only to defer the
419414
// resolution of the column because it could be an alias and that
420415
// phase of the analyzer has not run yet.
@@ -444,20 +439,23 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
444439
return n, nil
445440
}
446441

447-
colMap := make(map[string][]columnInfo)
448-
idx := 0
442+
colMap := make(map[string][]*sql.Column)
449443
for _, child := range n.Children() {
450444
if !child.Resolved() {
451445
return n, nil
452446
}
453447

454448
for _, col := range child.Schema() {
455-
colMap[col.Name] = append(colMap[col.Name], columnInfo{idx, col})
456-
idx++
449+
colMap[col.Name] = append(colMap[col.Name], col)
457450
}
458451
}
459452

460-
return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) {
453+
expressioner, ok := n.(sql.Expressioner)
454+
if !ok {
455+
return n, nil
456+
}
457+
458+
return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
461459
a.Log("transforming expression of type: %T", e)
462460
if e.Resolved() {
463461
return e, nil
@@ -468,7 +466,7 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
468466
return e, nil
469467
}
470468

471-
columnsInfo, ok := colMap[uc.Name()]
469+
columns, ok := colMap[uc.Name()]
472470
if !ok {
473471
if uc.Table() != "" {
474472
return nil, ErrColumnTableNotFound.New(uc.Table(), uc.Name())
@@ -482,11 +480,11 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
482480
}
483481
}
484482

485-
var ci columnInfo
483+
var col *sql.Column
486484
var found bool
487-
for _, c := range columnsInfo {
488-
if c.col.Source == uc.Table() {
489-
ci = c
485+
for _, c := range columns {
486+
if c.Source == uc.Table() {
487+
col = c
490488
found = true
491489
break
492490
}
@@ -505,14 +503,30 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
505503
}
506504
}
507505

508-
a.Log("column resolved to %q.%q", ci.col.Source, ci.col.Name)
506+
var schema sql.Schema
507+
switch n := n.(type) {
508+
// If expressioner and unary node we must take the
509+
// child's schema to correctly select the indexes
510+
// in the row is going to be evaluated in this node
511+
case *plan.Project, *plan.Filter, *plan.GroupBy, *plan.Sort:
512+
schema = n.Children()[0].Schema()
513+
default:
514+
schema = n.Schema()
515+
}
516+
517+
idx := schema.IndexOf(col.Name, col.Source)
518+
if idx < 0 {
519+
return nil, ErrColumnNotFound.New(col.Name)
520+
}
521+
522+
a.Log("column resolved to %q.%q", col.Source, col.Name)
509523

510524
return expression.NewGetFieldWithTable(
511-
ci.idx,
512-
ci.col.Type,
513-
ci.col.Source,
514-
ci.col.Name,
515-
ci.col.Nullable,
525+
idx,
526+
col.Type,
527+
col.Source,
528+
col.Name,
529+
col.Nullable,
516530
), nil
517531
})
518532
})

sql/analyzer/validation_rules.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ func validateIndexCreation(ctx *sql.Context, n sql.Node) error {
154154
expression.Inspect(expr, func(e sql.Expression) bool {
155155
gf, ok := e.(*expression.GetField)
156156
if ok {
157-
if gf.Table() != table || !schema.Contains(gf.Name()) {
157+
if gf.Table() != table || !schema.Contains(gf.Name(), gf.Table()) {
158158
unknownColumns = append(unknownColumns, gf.Name())
159159
}
160160
}

sql/core.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ type Node interface {
110110
type Expressioner interface {
111111
// Expressions returns the list of expressions contained by the node.
112112
Expressions() []Expression
113+
// TransformExpressions applies for each expression in this node
114+
// the expression's TransformUp method with the given function, and
115+
// return a new node with the transformed expressions.
116+
TransformExpressions(TransformExprFunc) (Node, error)
113117
}
114118

115119
// Table represents a SQL table.

sql/expression/get_field.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,37 +32,40 @@ func NewGetFieldWithTable(index int, fieldType sql.Type, table, fieldName string
3232
}
3333
}
3434

35+
// Index returns the index where the GetField will look for the value from a sql.Row.
36+
func (p *GetField) Index() int { return p.fieldIndex }
37+
3538
// Children implements the Expression interface.
36-
func (GetField) Children() []sql.Expression {
39+
func (*GetField) Children() []sql.Expression {
3740
return nil
3841
}
3942

4043
// Table returns the name of the field table.
41-
func (p GetField) Table() string { return p.table }
44+
func (p *GetField) Table() string { return p.table }
4245

4346
// Resolved implements the Expression interface.
44-
func (p GetField) Resolved() bool {
47+
func (p *GetField) Resolved() bool {
4548
return true
4649
}
4750

4851
// Name implements the Nameable interface.
49-
func (p GetField) Name() string { return p.name }
52+
func (p *GetField) Name() string { return p.name }
5053

5154
// IsNullable returns whether the field is nullable or not.
52-
func (p GetField) IsNullable() bool {
55+
func (p *GetField) IsNullable() bool {
5356
return p.nullable
5457
}
5558

5659
// Type returns the type of the field.
57-
func (p GetField) Type() sql.Type {
60+
func (p *GetField) Type() sql.Type {
5861
return p.fieldType
5962
}
6063

6164
// ErrIndexOutOfBounds is returned when the field index is out of the bounds.
6265
var ErrIndexOutOfBounds = errors.NewKind("unable to find field with index %d in row of %d columns")
6366

6467
// Eval implements the Expression interface.
65-
func (p GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
68+
func (p *GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
6669
if p.fieldIndex < 0 || p.fieldIndex >= len(row) {
6770
return nil, ErrIndexOutOfBounds.New(p.fieldIndex, len(row))
6871
}
@@ -75,7 +78,7 @@ func (p *GetField) TransformUp(f sql.TransformExprFunc) (sql.Expression, error)
7578
return f(&n)
7679
}
7780

78-
func (p GetField) String() string {
81+
func (p *GetField) String() string {
7982
if p.table == "" {
8083
return p.name
8184
}

sql/plan/filter.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package plan
22

3-
import "gopkg.in/src-d/go-mysql-server.v0/sql"
3+
import (
4+
"gopkg.in/src-d/go-mysql-server.v0/sql"
5+
)
46

57
// Filter skips rows that don't match a certain expression.
68
type Filter struct {
@@ -58,18 +60,28 @@ func (p *Filter) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, erro
5860
return NewFilter(expr, child), nil
5961
}
6062

61-
func (p Filter) String() string {
63+
func (p *Filter) String() string {
6264
pr := sql.NewTreePrinter()
6365
_ = pr.WriteNode("Filter(%s)", p.Expression)
6466
_ = pr.WriteChildren(p.Child.String())
6567
return pr.String()
6668
}
6769

6870
// Expressions implements the Expressioner interface.
69-
func (p Filter) Expressions() []sql.Expression {
71+
func (p *Filter) Expressions() []sql.Expression {
7072
return []sql.Expression{p.Expression}
7173
}
7274

75+
// TransformExpressions implements the Expressioner interface.
76+
func (p *Filter) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) {
77+
e, err := p.Expression.TransformUp(f)
78+
if err != nil {
79+
return nil, err
80+
}
81+
82+
return NewFilter(e, p.Child), nil
83+
}
84+
7385
// FilterIter is an iterator that filters another iterator and skips rows that
7486
// don't match the given condition.
7587
type FilterIter struct {

sql/plan/group_by.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ func (p *GroupBy) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, err
107107
return NewGroupBy(aggregate, grouping, child), nil
108108
}
109109

110-
func (p GroupBy) String() string {
110+
func (p *GroupBy) String() string {
111111
pr := sql.NewTreePrinter()
112112
_ = pr.WriteNode("GroupBy")
113113

@@ -130,13 +130,28 @@ func (p GroupBy) String() string {
130130
}
131131

132132
// Expressions implements the Expressioner interface.
133-
func (p GroupBy) Expressions() []sql.Expression {
133+
func (p *GroupBy) Expressions() []sql.Expression {
134134
var exprs []sql.Expression
135135
exprs = append(exprs, p.Aggregate...)
136136
exprs = append(exprs, p.Grouping...)
137137
return exprs
138138
}
139139

140+
// TransformExpressions implements the Expressioner interface.
141+
func (p *GroupBy) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) {
142+
agg, err := transformExpressionsUp(f, p.Aggregate)
143+
if err != nil {
144+
return nil, err
145+
}
146+
147+
group, err := transformExpressionsUp(f, p.Grouping)
148+
if err != nil {
149+
return nil, err
150+
}
151+
152+
return NewGroupBy(agg, group, p.Child), nil
153+
}
154+
140155
type groupByIter struct {
141156
p *GroupBy
142157
childIter sql.RowIter

sql/plan/innerjoin.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,24 @@ func (j *InnerJoin) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, e
106106
return NewInnerJoin(left, right, cond), nil
107107
}
108108

109-
func (j InnerJoin) String() string {
109+
func (j *InnerJoin) String() string {
110110
pr := sql.NewTreePrinter()
111111
_ = pr.WriteNode("InnerJoin(%s)", j.Cond)
112112
_ = pr.WriteChildren(j.Left.String(), j.Right.String())
113113
return pr.String()
114114
}
115115

116116
// Expressions implements the Expressioner interface.
117-
func (j InnerJoin) Expressions() []sql.Expression {
117+
func (j *InnerJoin) Expressions() []sql.Expression {
118118
return []sql.Expression{j.Cond}
119119
}
120+
121+
// TransformExpressions implements the Expressioner interface.
122+
func (j *InnerJoin) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) {
123+
cond, err := j.Cond.TransformUp(f)
124+
if err != nil {
125+
return nil, err
126+
}
127+
128+
return NewInnerJoin(j.Left, j.Right, cond), nil
129+
}

sql/plan/project.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func (p *Project) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, err
9393
return NewProject(exprs, child), nil
9494
}
9595

96-
func (p Project) String() string {
96+
func (p *Project) String() string {
9797
pr := sql.NewTreePrinter()
9898
var exprs = make([]string, len(p.Projections))
9999
for i, expr := range p.Projections {
@@ -105,10 +105,20 @@ func (p Project) String() string {
105105
}
106106

107107
// Expressions implements the Expressioner interface.
108-
func (p Project) Expressions() []sql.Expression {
108+
func (p *Project) Expressions() []sql.Expression {
109109
return p.Projections
110110
}
111111

112+
// TransformExpressions implements the Expressioner interface.
113+
func (p *Project) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) {
114+
projects, err := transformExpressionsUp(f, p.Projections)
115+
if err != nil {
116+
return nil, err
117+
}
118+
119+
return NewProject(projects, p.Child), nil
120+
}
121+
112122
type iter struct {
113123
p *Project
114124
childIter sql.RowIter

sql/plan/pushdown.go

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ func (t *PushdownProjectionAndFiltersTable) RowIter(ctx *sql.Context) (sql.RowIt
137137
return sql.NewSpanIter(span, iter), nil
138138
}
139139

140-
func (t PushdownProjectionAndFiltersTable) String() string {
140+
func (t *PushdownProjectionAndFiltersTable) String() string {
141141
pr := sql.NewTreePrinter()
142142
_ = pr.WriteNode("PushdownProjectionAndFiltersTable")
143143

@@ -161,9 +161,28 @@ func (t PushdownProjectionAndFiltersTable) String() string {
161161
}
162162

163163
// Expressions implements the Expressioner interface.
164-
func (t PushdownProjectionAndFiltersTable) Expressions() []sql.Expression {
164+
func (t *PushdownProjectionAndFiltersTable) Expressions() []sql.Expression {
165165
var exprs []sql.Expression
166166
exprs = append(exprs, t.Columns...)
167167
exprs = append(exprs, t.Filters...)
168168
return exprs
169169
}
170+
171+
// TransformExpressions implements the Expressioner interface.
172+
func (t *PushdownProjectionAndFiltersTable) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) {
173+
cols, err := transformExpressionsUp(f, t.Columns)
174+
if err != nil {
175+
return nil, err
176+
}
177+
178+
filters, err := transformExpressionsUp(f, t.Filters)
179+
if err != nil {
180+
return nil, err
181+
}
182+
183+
return NewPushdownProjectionAndFiltersTable(
184+
cols,
185+
filters,
186+
t.PushdownProjectionAndFiltersTable,
187+
), nil
188+
}

0 commit comments

Comments
 (0)