Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 8db045c

Browse files
committed
sql/analyzer: resolve GetField indexes in the resolveColumns rule
Signed-off-by: Manuel Carmona <[email protected]>
1 parent 7413160 commit 8db045c

File tree

3 files changed

+31
-75
lines changed

3 files changed

+31
-75
lines changed

sql/analyzer/rules.go

Lines changed: 23 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ var DefaultRules = []Rule{
2020
{"resolve_database", resolveDatabase},
2121
{"resolve_star", resolveStar},
2222
{"resolve_functions", resolveFunctions},
23-
{"resolve_getfield_indexes", resolveGetFieldIndexes},
2423
{"reorder_projection", reorderProjection},
2524
{"pushdown", pushdown},
2625
{"optimize_distinct", optimizeDistinct},
@@ -326,8 +325,6 @@ type column interface {
326325
sql.Expression
327326
}
328327

329-
const unresolvedGetFieldIndex = -1
330-
331328
func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
332329
span, ctx := ctx.Span("resolve_columns")
333330
defer span.Finish()
@@ -350,7 +347,12 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
350347
}
351348
}
352349

353-
return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) {
350+
expressioner, ok := n.(sql.Expressioner)
351+
if !ok {
352+
return n, nil
353+
}
354+
355+
return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
354356
a.Log("transforming expression of type: %T", e)
355357
if n.Resolved() {
356358
return e, nil
@@ -398,10 +400,26 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
398400
}
399401
}
400402

403+
var schema sql.Schema
404+
switch n := n.(type) {
405+
// If expressioner and unary node we must take the
406+
// child's schema to correctly select the indexes
407+
// in the row is going to be evaluated in this node
408+
case *plan.Project, *plan.Filter, *plan.GroupBy, *plan.Sort:
409+
schema = n.Children()[0].Schema()
410+
default:
411+
schema = n.Schema()
412+
}
413+
414+
idx, ok := schema.Contains(col.Name, col.Source)
415+
if !ok {
416+
return nil, ErrColumnNotFound.New(col.Name)
417+
}
418+
401419
a.Log("column resolved to %q.%q", col.Source, col.Name)
402420

403421
return expression.NewGetFieldWithTable(
404-
unresolvedGetFieldIndex,
422+
idx,
405423
col.Type,
406424
col.Source,
407425
col.Name,
@@ -805,66 +823,3 @@ func fixFieldIndexes(schema sql.Schema, exp sql.Expression) (sql.Expression, err
805823
return e, nil
806824
})
807825
}
808-
809-
// resolveGetFieldIndexes set the index attribute for each GetField expression with a value referring to
810-
// the node containing this expression. It should be run after resolveColumns
811-
func resolveGetFieldIndexes(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) {
812-
span, ctx := ctx.Span("fix join fields")
813-
defer span.Finish()
814-
815-
a.Log("fix join fields, node of type: %T", node)
816-
return node.TransformUp(func(n sql.Node) (sql.Node, error) {
817-
a.Log("transforming node of type: %T", n)
818-
if !n.Resolved() {
819-
return n, nil
820-
}
821-
822-
expressioner, ok := n.(sql.Expressioner)
823-
if !ok {
824-
return n, nil
825-
}
826-
827-
return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
828-
a.Log("transforming expression of type: %T", e)
829-
if !e.Resolved() {
830-
return e, nil
831-
}
832-
833-
field, ok := e.(*expression.GetField)
834-
if !ok {
835-
return e, nil
836-
}
837-
838-
if field.Index() != unresolvedGetFieldIndex {
839-
return e, nil
840-
}
841-
842-
var schema sql.Schema
843-
switch n := n.(type) {
844-
// If expressioner and unary node we must take the
845-
// child's schema to correctly select the indexes
846-
// in the row is going to be evaluated in this node
847-
case *plan.Project, *plan.Filter, *plan.GroupBy, *plan.Sort:
848-
schema = n.Children()[0].Schema()
849-
default:
850-
schema = n.Schema()
851-
}
852-
853-
idx, ok := schema.Contains(field.Name(), field.Table())
854-
if !ok {
855-
return nil, ErrColumnNotFound.New(field.Name())
856-
}
857-
858-
fixedField := expression.NewGetFieldWithTable(
859-
idx,
860-
field.Type(),
861-
field.Table(),
862-
field.Name(),
863-
field.IsNullable(),
864-
)
865-
866-
a.Log("fixed expression %T for %T", field, n)
867-
return fixedField, nil
868-
})
869-
})
870-
}

sql/analyzer/rules_test.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -472,21 +472,21 @@ func TestReorderProjection(t *testing.T) {
472472
expression.NewGetField(1, sql.Int64, "bar", false),
473473
},
474474
plan.NewSort(
475-
[]plan.SortField{{Column: expression.NewGetField(unresolvedGetFieldIndex, sql.Int64, "foo", false)}},
475+
[]plan.SortField{{Column: expression.NewGetField(2, sql.Int64, "foo", false)}},
476476
plan.NewProject(
477477
[]sql.Expression{
478-
expression.NewGetFieldWithTable(unresolvedGetFieldIndex, sql.Int64, "mytable", "i", false),
479-
expression.NewGetField(unresolvedGetFieldIndex, sql.Int64, "bar", false),
478+
expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false),
479+
expression.NewGetField(1, sql.Int64, "bar", false),
480480
expression.NewAlias(expression.NewLiteral(1, sql.Int64), "foo"),
481481
},
482482
plan.NewFilter(
483483
expression.NewEquals(
484484
expression.NewLiteral(1, sql.Int64),
485-
expression.NewGetField(unresolvedGetFieldIndex, sql.Int64, "bar", false),
485+
expression.NewGetField(1, sql.Int64, "bar", false),
486486
),
487487
plan.NewProject(
488488
[]sql.Expression{
489-
expression.NewGetFieldWithTable(unresolvedGetFieldIndex, sql.Int64, "mytable", "i", false),
489+
expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false),
490490
expression.NewAlias(expression.NewLiteral(2, sql.Int64), "bar"),
491491
},
492492
table,
@@ -498,6 +498,7 @@ func TestReorderProjection(t *testing.T) {
498498

499499
result, err := f.Apply(sql.NewEmptyContext(), New(nil), node)
500500
require.NoError(err)
501+
501502
require.Equal(expected, result)
502503
}
503504

sql/plan/filter.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,12 @@ func (p *Filter) Expressions() []sql.Expression {
7474

7575
// TransformExpressions implements the Expressioner interface.
7676
func (p *Filter) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) {
77-
expression, err := p.Expression.TransformUp(f)
77+
e, err := p.Expression.TransformUp(f)
7878
if err != nil {
7979
return nil, err
8080
}
8181

82-
return NewFilter(expression, p.Child), nil
82+
return NewFilter(e, p.Child), nil
8383
}
8484

8585
// FilterIter is an iterator that filters another iterator and skips rows that

0 commit comments

Comments
 (0)