Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

sql: implement new API for node transformation #773

Merged
merged 1 commit into from
Jul 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mem/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import (
"io"
"strconv"

errors "gopkg.in/src-d/go-errors.v1"
"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
errors "gopkg.in/src-d/go-errors.v1"
)

// Table represents an in-memory database table.
Expand Down Expand Up @@ -312,14 +312,14 @@ func (t *Table) HandledFilters(filters []sql.Expression) []sql.Expression {
var handled []sql.Expression
for _, f := range filters {
var hasOtherFields bool
_, _ = f.TransformUp(func(e sql.Expression) (sql.Expression, error) {
expression.Inspect(f, func(e sql.Expression) bool {
if e, ok := e.(*expression.GetField); ok {
if e.Table() != t.name || !t.schema.Contains(e.Name(), t.name) {
hasOtherFields = true
return false
}
}

return e, nil
return true
})

if !hasOtherFields {
Expand Down
4 changes: 2 additions & 2 deletions sql/analyzer/aggregations.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func reorderAggregations(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, e

a.Log("reorder aggregations, node of type: %T", n)

return n.TransformUp(func(n sql.Node) (sql.Node, error) {
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
switch n := n.(type) {
case *plan.GroupBy:
if !hasHiddenAggregations(n.Aggregate...) {
Expand All @@ -38,7 +38,7 @@ func fixAggregations(projection, grouping []sql.Expression, child sql.Node) (sql

for i, p := range projection {
var transformed bool
e, err := p.TransformUp(func(e sql.Expression) (sql.Expression, error) {
e, err := expression.TransformUp(p, func(e sql.Expression) (sql.Expression, error) {
agg, ok := e.(sql.Aggregation)
if !ok {
return e, nil
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/assign_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ func assignCatalog(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
span, _ := ctx.Span("assign_catalog")
defer span.Finish()

return n.TransformUp(func(n sql.Node) (sql.Node, error) {
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
if !n.Resolved() {
return n, nil
}
Expand Down
10 changes: 5 additions & 5 deletions sql/analyzer/convert_dates.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
// replaced by.
var replacements = make(map[tableCol]string)

return n.TransformUp(func(n sql.Node) (sql.Node, error) {
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
exp, ok := n.(sql.Expressioner)
if !ok {
return n, nil
Expand Down Expand Up @@ -48,7 +48,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
case *plan.GroupBy:
var aggregate = make([]sql.Expression, len(exp.Aggregate))
for i, a := range exp.Aggregate {
agg, err := a.TransformUp(func(e sql.Expression) (sql.Expression, error) {
agg, err := expression.TransformUp(a, func(e sql.Expression) (sql.Expression, error) {
return addDateConvert(e, exp, replacements, nodeReplacements, expressions, true)
})
if err != nil {
Expand All @@ -64,7 +64,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {

var grouping = make([]sql.Expression, len(exp.Grouping))
for i, g := range exp.Grouping {
gr, err := g.TransformUp(func(e sql.Expression) (sql.Expression, error) {
gr, err := expression.TransformUp(g, func(e sql.Expression) (sql.Expression, error) {
return addDateConvert(e, exp, replacements, nodeReplacements, expressions, false)
})
if err != nil {
Expand All @@ -77,7 +77,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
case *plan.Project:
var projections = make([]sql.Expression, len(exp.Projections))
for i, e := range exp.Projections {
expr, err := e.TransformUp(func(e sql.Expression) (sql.Expression, error) {
expr, err := expression.TransformUp(e, func(e sql.Expression) (sql.Expression, error) {
return addDateConvert(e, exp, replacements, nodeReplacements, expressions, true)
})
if err != nil {
Expand All @@ -93,7 +93,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {

result = plan.NewProject(projections, exp.Child)
default:
result, err = exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
result, err = plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) {
return addDateConvert(e, n, replacements, nodeReplacements, expressions, false)
})
}
Expand Down
4 changes: 2 additions & 2 deletions sql/analyzer/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func exprToTableFilters(expr sql.Expression) filters {
for _, expr := range splitExpression(expr) {
var seenTables = make(map[string]struct{})
var lastTable string
_, _ = expr.TransformUp(func(e sql.Expression) (sql.Expression, error) {
expression.Inspect(expr, func(e sql.Expression) bool {
f, ok := e.(*expression.GetField)
if ok {
if _, ok := seenTables[f.Table()]; !ok {
Expand All @@ -29,7 +29,7 @@ func exprToTableFilters(expr sql.Expression) filters {
}
}

return e, nil
return true
})

if len(seenTables) == 1 {
Expand Down
21 changes: 11 additions & 10 deletions sql/analyzer/optimization_rules.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package analyzer

import (
"gopkg.in/src-d/go-errors.v1"
"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
"github.com/src-d/go-mysql-server/sql/plan"
"gopkg.in/src-d/go-errors.v1"
)

func eraseProjection(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) {
Expand All @@ -17,7 +17,7 @@ func eraseProjection(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, er

a.Log("erase projection, node of type: %T", node)

return node.TransformUp(func(node sql.Node) (sql.Node, error) {
return plan.TransformUp(node, func(node sql.Node) (sql.Node, error) {
project, ok := node.(*plan.Project)
if ok && project.Schema().Equals(project.Child.Schema()) {
a.Log("project erased")
Expand All @@ -35,12 +35,13 @@ func optimizeDistinct(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, e
a.Log("optimize distinct, node of type: %T", node)
if n, ok := node.(*plan.Distinct); ok {
var isSorted bool
_, _ = node.TransformUp(func(node sql.Node) (sql.Node, error) {
plan.Inspect(n, func(node sql.Node) bool {
a.Log("checking for optimization in node of type: %T", node)
if _, ok := node.(*plan.Sort); ok {
isSorted = true
return false
}
return node, nil
return true
})

if isSorted {
Expand All @@ -65,7 +66,7 @@ func reorderProjection(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err
a.Log("reorder projection, node of type: %T", n)

// Then we transform the projection
return n.TransformUp(func(node sql.Node) (sql.Node, error) {
return plan.TransformUp(n, func(node sql.Node) (sql.Node, error) {
project, ok := node.(*plan.Project)
// When we transform the projection, the children will always be
// unresolved in the case we want to fix, as the reorder happens just
Expand All @@ -92,7 +93,7 @@ func reorderProjection(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err

// And add projection nodes where needed in the child tree.
var didNeedReorder bool
child, err := project.Child.TransformUp(func(node sql.Node) (sql.Node, error) {
child, err := plan.TransformUp(project.Child, func(node sql.Node) (sql.Node, error) {
var requiredColumns []string
switch node := node.(type) {
case *plan.Sort, *plan.Filter:
Expand Down Expand Up @@ -200,7 +201,7 @@ func moveJoinConditionsToFilter(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.

a.Log("moving join conditions to filter, node of type: %T", n)

return n.TransformUp(func(n sql.Node) (sql.Node, error) {
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
join, ok := n.(*plan.InnerJoin)
if !ok {
return n, nil
Expand Down Expand Up @@ -268,7 +269,7 @@ func removeUnnecessaryConverts(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.N

a.Log("removing unnecessary converts, node of type: %T", n)

return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) {
return plan.TransformExpressionsUp(n, func(e sql.Expression) (sql.Expression, error) {
if c, ok := e.(*expression.Convert); ok && c.Child.Type() == c.Type() {
return c.Child, nil
}
Expand Down Expand Up @@ -336,13 +337,13 @@ func evalFilter(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error)

a.Log("evaluating filters, node of type: %T", node)

return node.TransformUp(func(node sql.Node) (sql.Node, error) {
return plan.TransformUp(node, func(node sql.Node) (sql.Node, error) {
filter, ok := node.(*plan.Filter)
if !ok {
return node, nil
}

e, err := filter.Expression.TransformUp(func(e sql.Expression) (sql.Expression, error) {
e, err := expression.TransformUp(filter.Expression, func(e sql.Expression) (sql.Expression, error) {
switch e := e.(type) {
case *expression.Or:
if isTrue(e.Left) {
Expand Down
30 changes: 8 additions & 22 deletions sql/analyzer/parallelize.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func parallelize(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error)
return node, nil
}

node, err := node.TransformUp(func(node sql.Node) (sql.Node, error) {
node, err := plan.TransformUp(node, func(node sql.Node) (sql.Node, error) {
if !isParallelizable(node) {
return node, nil
}
Expand All @@ -47,7 +47,7 @@ func parallelize(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error)
return nil, err
}

return node.TransformUp(removeRedundantExchanges)
return plan.TransformUp(node, removeRedundantExchanges)
}

// removeRedundantExchanges removes all the exchanges except for the topmost
Expand All @@ -58,13 +58,17 @@ func removeRedundantExchanges(node sql.Node) (sql.Node, error) {
return node, nil
}

e := &protectedExchange{exchange}
return e.TransformUp(func(node sql.Node) (sql.Node, error) {
child, err := plan.TransformUp(exchange.Child, func(node sql.Node) (sql.Node, error) {
if exchange, ok := node.(*plan.Exchange); ok {
return exchange.Child, nil
}
return node, nil
})
if err != nil {
return nil, err
}

return exchange.WithChildren(child)
}

func isParallelizable(node sql.Node) bool {
Expand Down Expand Up @@ -103,21 +107,3 @@ func isParallelizable(node sql.Node) bool {

return ok && tableSeen && lastWasTable
}

// protectedExchange is a placeholder node that protects a certain exchange
// node from being removed during transformations.
type protectedExchange struct {
*plan.Exchange
}

// TransformUp transforms the child with the given transform function but it
// will not call the transform function with the new instance. Instead of
// another protectedExchange, it will return an Exchange.
func (e *protectedExchange) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) {
child, err := e.Child.TransformUp(f)
if err != nil {
return nil, err
}

return plan.NewExchange(e.Parallelism, child), nil
}
4 changes: 2 additions & 2 deletions sql/analyzer/parallelize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ package analyzer
import (
"testing"

"github.com/stretchr/testify/require"
"github.com/src-d/go-mysql-server/mem"
"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
"github.com/src-d/go-mysql-server/sql/plan"
"github.com/stretchr/testify/require"
)

func TestParallelize(t *testing.T) {
Expand Down Expand Up @@ -222,7 +222,7 @@ func TestRemoveRedundantExchanges(t *testing.T) {
),
)

result, err := node.TransformUp(removeRedundantExchanges)
result, err := plan.TransformUp(node, removeRedundantExchanges)
require.NoError(err)
require.Equal(expected, result)
}
4 changes: 2 additions & 2 deletions sql/analyzer/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
processList := a.Catalog.ProcessList

var seen = make(map[string]struct{})
n, err := n.TransformUp(func(n sql.Node) (sql.Node, error) {
n, err := plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
switch n := n.(type) {
case *plan.ResolvedTable:
switch n.Table.(type) {
Expand Down Expand Up @@ -73,7 +73,7 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {

// Remove QueryProcess nodes from the subqueries. Otherwise, the process
// will be marked as done as soon as a subquery finishes.
node, err := n.TransformUp(func(n sql.Node) (sql.Node, error) {
node, err := plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
if sq, ok := n.(*plan.SubqueryAlias); ok {
if qp, ok := sq.Child.(*plan.QueryProcess); ok {
return plan.NewSubqueryAlias(sq.Name(), qp.Child), nil
Expand Down
11 changes: 5 additions & 6 deletions sql/analyzer/prune_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func pruneSubqueries(
n sql.Node,
parentColumns usedColumns,
) (sql.Node, error) {
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
subq, ok := n.(*plan.SubqueryAlias)
if !ok {
return n, nil
Expand All @@ -142,7 +142,7 @@ func pruneSubqueries(
}

func pruneUnusedColumns(n sql.Node, columns usedColumns) (sql.Node, error) {
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
switch n := n.(type) {
case *plan.Project:
return pruneProject(n, columns), nil
Expand All @@ -155,7 +155,7 @@ func pruneUnusedColumns(n sql.Node, columns usedColumns) (sql.Node, error) {
}

func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) {
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
switch n := n.(type) {
case *plan.SubqueryAlias:
child, err := fixRemainingFieldsIndexes(n.Child)
Expand All @@ -165,8 +165,7 @@ func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) {

return plan.NewSubqueryAlias(n.Name(), child), nil
default:
exp, ok := n.(sql.Expressioner)
if !ok {
if _, ok := n.(sql.Expressioner); !ok {
return n, nil
}

Expand All @@ -184,7 +183,7 @@ func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) {
indexes[tableCol{col.Source, col.Name}] = i
}

return exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
return plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) {
gf, ok := e.(*expression.GetField)
if !ok {
return e, nil
Expand Down
Loading