Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 128924a

Browse files
committed
sql: implement new API for node transformation
Instead of having TransformUp and TransformExpressionsUp in each node, which was really error prone, now we have a different API. Each node will have a WithChildren method that will receive the children from which a new node of the same type will be created. These nodes must come in the same number and order as the ones returned by the Children method. Expressioner nodes will also have WithExpressions method, which is the same, except it will create a new node with its expressions changed, instead of the children nodes. The plan package will expose 3 new helpers: - TransformUp: which transforms a node from the bottom-up. - TransformExpressionsUp: which transforms expressions of a node from the bottom up. - TransformExpressions: which transforms the expressions of just the given node. Just like with nodes, expressions will also have a new WithChildren method that does the exact same thing as it does in the nodes. The expression package will expose a new helper: - TransformUp: which transforms an expression from the bottom up. Caveats and limitations: One thing that may seem odd is the limitation that WithChildren and WithExpressions must receive the children in the exact same order and number as they were returned from Expressions or Children. This is because without this limitation there is no way to build certain nodes. If we force this limitation on one, it would feel odd not to have it elsewhere. For example, take Case expression into account. It may or may not have Expr, it has a list of branches (each having 2 expressions) and it may or may not have an Else expression. If WithChildren receives N children, how do we know where to put all these expressions? This limitation allows us to build the node beacause we know the shape of children must match the current shape. Signed-off-by: Miguel Molina <[email protected]>
1 parent bbae519 commit 128924a

File tree

127 files changed

+1107
-1865
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

127 files changed

+1107
-1865
lines changed

mem/table.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ import (
77
"io"
88
"strconv"
99

10-
errors "gopkg.in/src-d/go-errors.v1"
1110
"github.com/src-d/go-mysql-server/sql"
1211
"github.com/src-d/go-mysql-server/sql/expression"
12+
errors "gopkg.in/src-d/go-errors.v1"
1313
)
1414

1515
// Table represents an in-memory database table.
@@ -312,14 +312,14 @@ func (t *Table) HandledFilters(filters []sql.Expression) []sql.Expression {
312312
var handled []sql.Expression
313313
for _, f := range filters {
314314
var hasOtherFields bool
315-
_, _ = f.TransformUp(func(e sql.Expression) (sql.Expression, error) {
315+
expression.Inspect(f, func(e sql.Expression) bool {
316316
if e, ok := e.(*expression.GetField); ok {
317317
if e.Table() != t.name || !t.schema.Contains(e.Name(), t.name) {
318318
hasOtherFields = true
319+
return false
319320
}
320321
}
321-
322-
return e, nil
322+
return true
323323
})
324324

325325
if !hasOtherFields {

sql/analyzer/aggregations.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ func reorderAggregations(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, e
1616

1717
a.Log("reorder aggregations, node of type: %T", n)
1818

19-
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
19+
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
2020
switch n := n.(type) {
2121
case *plan.GroupBy:
2222
if !hasHiddenAggregations(n.Aggregate...) {
@@ -38,7 +38,7 @@ func fixAggregations(projection, grouping []sql.Expression, child sql.Node) (sql
3838

3939
for i, p := range projection {
4040
var transformed bool
41-
e, err := p.TransformUp(func(e sql.Expression) (sql.Expression, error) {
41+
e, err := expression.TransformUp(p, func(e sql.Expression) (sql.Expression, error) {
4242
agg, ok := e.(sql.Aggregation)
4343
if !ok {
4444
return e, nil

sql/analyzer/assign_catalog.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ func assignCatalog(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
1010
span, _ := ctx.Span("assign_catalog")
1111
defer span.Finish()
1212

13-
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
13+
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
1414
if !n.Resolved() {
1515
return n, nil
1616
}

sql/analyzer/convert_dates.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
1919
// replaced by.
2020
var replacements = make(map[tableCol]string)
2121

22-
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
22+
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
2323
exp, ok := n.(sql.Expressioner)
2424
if !ok {
2525
return n, nil
@@ -48,7 +48,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
4848
case *plan.GroupBy:
4949
var aggregate = make([]sql.Expression, len(exp.Aggregate))
5050
for i, a := range exp.Aggregate {
51-
agg, err := a.TransformUp(func(e sql.Expression) (sql.Expression, error) {
51+
agg, err := expression.TransformUp(a, func(e sql.Expression) (sql.Expression, error) {
5252
return addDateConvert(e, exp, replacements, nodeReplacements, expressions, true)
5353
})
5454
if err != nil {
@@ -64,7 +64,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
6464

6565
var grouping = make([]sql.Expression, len(exp.Grouping))
6666
for i, g := range exp.Grouping {
67-
gr, err := g.TransformUp(func(e sql.Expression) (sql.Expression, error) {
67+
gr, err := expression.TransformUp(g, func(e sql.Expression) (sql.Expression, error) {
6868
return addDateConvert(e, exp, replacements, nodeReplacements, expressions, false)
6969
})
7070
if err != nil {
@@ -77,7 +77,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
7777
case *plan.Project:
7878
var projections = make([]sql.Expression, len(exp.Projections))
7979
for i, e := range exp.Projections {
80-
expr, err := e.TransformUp(func(e sql.Expression) (sql.Expression, error) {
80+
expr, err := expression.TransformUp(e, func(e sql.Expression) (sql.Expression, error) {
8181
return addDateConvert(e, exp, replacements, nodeReplacements, expressions, true)
8282
})
8383
if err != nil {
@@ -93,7 +93,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
9393

9494
result = plan.NewProject(projections, exp.Child)
9595
default:
96-
result, err = exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
96+
result, err = plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) {
9797
return addDateConvert(e, n, replacements, nodeReplacements, expressions, false)
9898
})
9999
}

sql/analyzer/filters.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ func exprToTableFilters(expr sql.Expression) filters {
2020
for _, expr := range splitExpression(expr) {
2121
var seenTables = make(map[string]struct{})
2222
var lastTable string
23-
_, _ = expr.TransformUp(func(e sql.Expression) (sql.Expression, error) {
23+
expression.Inspect(expr, func(e sql.Expression) bool {
2424
f, ok := e.(*expression.GetField)
2525
if ok {
2626
if _, ok := seenTables[f.Table()]; !ok {
@@ -29,7 +29,7 @@ func exprToTableFilters(expr sql.Expression) filters {
2929
}
3030
}
3131

32-
return e, nil
32+
return true
3333
})
3434

3535
if len(seenTables) == 1 {

sql/analyzer/optimization_rules.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
package analyzer
22

33
import (
4-
"gopkg.in/src-d/go-errors.v1"
54
"github.com/src-d/go-mysql-server/sql"
65
"github.com/src-d/go-mysql-server/sql/expression"
76
"github.com/src-d/go-mysql-server/sql/plan"
7+
"gopkg.in/src-d/go-errors.v1"
88
)
99

1010
func eraseProjection(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) {
@@ -17,7 +17,7 @@ func eraseProjection(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, er
1717

1818
a.Log("erase projection, node of type: %T", node)
1919

20-
return node.TransformUp(func(node sql.Node) (sql.Node, error) {
20+
return plan.TransformUp(node, func(node sql.Node) (sql.Node, error) {
2121
project, ok := node.(*plan.Project)
2222
if ok && project.Schema().Equals(project.Child.Schema()) {
2323
a.Log("project erased")
@@ -35,12 +35,13 @@ func optimizeDistinct(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, e
3535
a.Log("optimize distinct, node of type: %T", node)
3636
if n, ok := node.(*plan.Distinct); ok {
3737
var isSorted bool
38-
_, _ = node.TransformUp(func(node sql.Node) (sql.Node, error) {
38+
plan.Inspect(n, func(node sql.Node) bool {
3939
a.Log("checking for optimization in node of type: %T", node)
4040
if _, ok := node.(*plan.Sort); ok {
4141
isSorted = true
42+
return false
4243
}
43-
return node, nil
44+
return true
4445
})
4546

4647
if isSorted {
@@ -65,7 +66,7 @@ func reorderProjection(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err
6566
a.Log("reorder projection, node of type: %T", n)
6667

6768
// Then we transform the projection
68-
return n.TransformUp(func(node sql.Node) (sql.Node, error) {
69+
return plan.TransformUp(n, func(node sql.Node) (sql.Node, error) {
6970
project, ok := node.(*plan.Project)
7071
// When we transform the projection, the children will always be
7172
// unresolved in the case we want to fix, as the reorder happens just
@@ -92,7 +93,7 @@ func reorderProjection(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err
9293

9394
// And add projection nodes where needed in the child tree.
9495
var didNeedReorder bool
95-
child, err := project.Child.TransformUp(func(node sql.Node) (sql.Node, error) {
96+
child, err := plan.TransformUp(project.Child, func(node sql.Node) (sql.Node, error) {
9697
var requiredColumns []string
9798
switch node := node.(type) {
9899
case *plan.Sort, *plan.Filter:
@@ -200,7 +201,7 @@ func moveJoinConditionsToFilter(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.
200201

201202
a.Log("moving join conditions to filter, node of type: %T", n)
202203

203-
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
204+
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
204205
join, ok := n.(*plan.InnerJoin)
205206
if !ok {
206207
return n, nil
@@ -268,7 +269,7 @@ func removeUnnecessaryConverts(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.N
268269

269270
a.Log("removing unnecessary converts, node of type: %T", n)
270271

271-
return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) {
272+
return plan.TransformExpressionsUp(n, func(e sql.Expression) (sql.Expression, error) {
272273
if c, ok := e.(*expression.Convert); ok && c.Child.Type() == c.Type() {
273274
return c.Child, nil
274275
}
@@ -336,13 +337,13 @@ func evalFilter(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error)
336337

337338
a.Log("evaluating filters, node of type: %T", node)
338339

339-
return node.TransformUp(func(node sql.Node) (sql.Node, error) {
340+
return plan.TransformUp(node, func(node sql.Node) (sql.Node, error) {
340341
filter, ok := node.(*plan.Filter)
341342
if !ok {
342343
return node, nil
343344
}
344345

345-
e, err := filter.Expression.TransformUp(func(e sql.Expression) (sql.Expression, error) {
346+
e, err := expression.TransformUp(filter.Expression, func(e sql.Expression) (sql.Expression, error) {
346347
switch e := e.(type) {
347348
case *expression.Or:
348349
if isTrue(e.Left) {

sql/analyzer/parallelize.go

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func parallelize(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error)
3434
return node, nil
3535
}
3636

37-
node, err := node.TransformUp(func(node sql.Node) (sql.Node, error) {
37+
node, err := plan.TransformUp(node, func(node sql.Node) (sql.Node, error) {
3838
if !isParallelizable(node) {
3939
return node, nil
4040
}
@@ -47,7 +47,7 @@ func parallelize(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error)
4747
return nil, err
4848
}
4949

50-
return node.TransformUp(removeRedundantExchanges)
50+
return plan.TransformUp(node, removeRedundantExchanges)
5151
}
5252

5353
// removeRedundantExchanges removes all the exchanges except for the topmost
@@ -58,13 +58,17 @@ func removeRedundantExchanges(node sql.Node) (sql.Node, error) {
5858
return node, nil
5959
}
6060

61-
e := &protectedExchange{exchange}
62-
return e.TransformUp(func(node sql.Node) (sql.Node, error) {
61+
child, err := plan.TransformUp(exchange.Child, func(node sql.Node) (sql.Node, error) {
6362
if exchange, ok := node.(*plan.Exchange); ok {
6463
return exchange.Child, nil
6564
}
6665
return node, nil
6766
})
67+
if err != nil {
68+
return nil, err
69+
}
70+
71+
return exchange.WithChildren(child)
6872
}
6973

7074
func isParallelizable(node sql.Node) bool {
@@ -103,21 +107,3 @@ func isParallelizable(node sql.Node) bool {
103107

104108
return ok && tableSeen && lastWasTable
105109
}
106-
107-
// protectedExchange is a placeholder node that protects a certain exchange
108-
// node from being removed during transformations.
109-
type protectedExchange struct {
110-
*plan.Exchange
111-
}
112-
113-
// TransformUp transforms the child with the given transform function but it
114-
// will not call the transform function with the new instance. Instead of
115-
// another protectedExchange, it will return an Exchange.
116-
func (e *protectedExchange) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) {
117-
child, err := e.Child.TransformUp(f)
118-
if err != nil {
119-
return nil, err
120-
}
121-
122-
return plan.NewExchange(e.Parallelism, child), nil
123-
}

sql/analyzer/parallelize_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ package analyzer
33
import (
44
"testing"
55

6-
"github.com/stretchr/testify/require"
76
"github.com/src-d/go-mysql-server/mem"
87
"github.com/src-d/go-mysql-server/sql"
98
"github.com/src-d/go-mysql-server/sql/expression"
109
"github.com/src-d/go-mysql-server/sql/plan"
10+
"github.com/stretchr/testify/require"
1111
)
1212

1313
func TestParallelize(t *testing.T) {
@@ -222,7 +222,7 @@ func TestRemoveRedundantExchanges(t *testing.T) {
222222
),
223223
)
224224

225-
result, err := node.TransformUp(removeRedundantExchanges)
225+
result, err := plan.TransformUp(node, removeRedundantExchanges)
226226
require.NoError(err)
227227
require.Equal(expected, result)
228228
}

sql/analyzer/process.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
1919
processList := a.Catalog.ProcessList
2020

2121
var seen = make(map[string]struct{})
22-
n, err := n.TransformUp(func(n sql.Node) (sql.Node, error) {
22+
n, err := plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
2323
switch n := n.(type) {
2424
case *plan.ResolvedTable:
2525
switch n.Table.(type) {
@@ -73,7 +73,7 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
7373

7474
// Remove QueryProcess nodes from the subqueries. Otherwise, the process
7575
// will be marked as done as soon as a subquery finishes.
76-
node, err := n.TransformUp(func(n sql.Node) (sql.Node, error) {
76+
node, err := plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
7777
if sq, ok := n.(*plan.SubqueryAlias); ok {
7878
if qp, ok := sq.Child.(*plan.QueryProcess); ok {
7979
return plan.NewSubqueryAlias(sq.Name(), qp.Child), nil

sql/analyzer/prune_columns.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ func pruneSubqueries(
131131
n sql.Node,
132132
parentColumns usedColumns,
133133
) (sql.Node, error) {
134-
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
134+
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
135135
subq, ok := n.(*plan.SubqueryAlias)
136136
if !ok {
137137
return n, nil
@@ -142,7 +142,7 @@ func pruneSubqueries(
142142
}
143143

144144
func pruneUnusedColumns(n sql.Node, columns usedColumns) (sql.Node, error) {
145-
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
145+
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
146146
switch n := n.(type) {
147147
case *plan.Project:
148148
return pruneProject(n, columns), nil
@@ -155,7 +155,7 @@ func pruneUnusedColumns(n sql.Node, columns usedColumns) (sql.Node, error) {
155155
}
156156

157157
func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) {
158-
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
158+
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
159159
switch n := n.(type) {
160160
case *plan.SubqueryAlias:
161161
child, err := fixRemainingFieldsIndexes(n.Child)
@@ -165,8 +165,7 @@ func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) {
165165

166166
return plan.NewSubqueryAlias(n.Name(), child), nil
167167
default:
168-
exp, ok := n.(sql.Expressioner)
169-
if !ok {
168+
if _, ok := n.(sql.Expressioner); !ok {
170169
return n, nil
171170
}
172171

@@ -184,7 +183,7 @@ func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) {
184183
indexes[tableCol{col.Source, col.Name}] = i
185184
}
186185

187-
return exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
186+
return plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) {
188187
gf, ok := e.(*expression.GetField)
189188
if !ok {
190189
return e, nil

0 commit comments

Comments
 (0)