diff --git a/mem/table.go b/mem/table.go index 8217a1211..d12a1e316 100644 --- a/mem/table.go +++ b/mem/table.go @@ -7,9 +7,9 @@ import ( "io" "strconv" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" + errors "gopkg.in/src-d/go-errors.v1" ) // Table represents an in-memory database table. @@ -312,14 +312,14 @@ func (t *Table) HandledFilters(filters []sql.Expression) []sql.Expression { var handled []sql.Expression for _, f := range filters { var hasOtherFields bool - _, _ = f.TransformUp(func(e sql.Expression) (sql.Expression, error) { + expression.Inspect(f, func(e sql.Expression) bool { if e, ok := e.(*expression.GetField); ok { if e.Table() != t.name || !t.schema.Contains(e.Name(), t.name) { hasOtherFields = true + return false } } - - return e, nil + return true }) if !hasOtherFields { diff --git a/sql/analyzer/aggregations.go b/sql/analyzer/aggregations.go index 90e5baab9..9afa5b49f 100644 --- a/sql/analyzer/aggregations.go +++ b/sql/analyzer/aggregations.go @@ -16,7 +16,7 @@ func reorderAggregations(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, e a.Log("reorder aggregations, node of type: %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { switch n := n.(type) { case *plan.GroupBy: if !hasHiddenAggregations(n.Aggregate...) { @@ -38,7 +38,7 @@ func fixAggregations(projection, grouping []sql.Expression, child sql.Node) (sql for i, p := range projection { var transformed bool - e, err := p.TransformUp(func(e sql.Expression) (sql.Expression, error) { + e, err := expression.TransformUp(p, func(e sql.Expression) (sql.Expression, error) { agg, ok := e.(sql.Aggregation) if !ok { return e, nil diff --git a/sql/analyzer/assign_catalog.go b/sql/analyzer/assign_catalog.go index 409ae082e..0a8f76ee8 100644 --- a/sql/analyzer/assign_catalog.go +++ b/sql/analyzer/assign_catalog.go @@ -10,7 +10,7 @@ func assignCatalog(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) span, _ := ctx.Span("assign_catalog") defer span.Finish() - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { if !n.Resolved() { return n, nil } diff --git a/sql/analyzer/convert_dates.go b/sql/analyzer/convert_dates.go index 1397d867a..c7af55431 100644 --- a/sql/analyzer/convert_dates.go +++ b/sql/analyzer/convert_dates.go @@ -19,7 +19,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { // replaced by. var replacements = make(map[tableCol]string) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { exp, ok := n.(sql.Expressioner) if !ok { return n, nil @@ -48,7 +48,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { case *plan.GroupBy: var aggregate = make([]sql.Expression, len(exp.Aggregate)) for i, a := range exp.Aggregate { - agg, err := a.TransformUp(func(e sql.Expression) (sql.Expression, error) { + agg, err := expression.TransformUp(a, func(e sql.Expression) (sql.Expression, error) { return addDateConvert(e, exp, replacements, nodeReplacements, expressions, true) }) if err != nil { @@ -64,7 +64,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { var grouping = make([]sql.Expression, len(exp.Grouping)) for i, g := range exp.Grouping { - gr, err := g.TransformUp(func(e sql.Expression) (sql.Expression, error) { + gr, err := expression.TransformUp(g, func(e sql.Expression) (sql.Expression, error) { return addDateConvert(e, exp, replacements, nodeReplacements, expressions, false) }) if err != nil { @@ -77,7 +77,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { case *plan.Project: var projections = make([]sql.Expression, len(exp.Projections)) for i, e := range exp.Projections { - expr, err := e.TransformUp(func(e sql.Expression) (sql.Expression, error) { + expr, err := expression.TransformUp(e, func(e sql.Expression) (sql.Expression, error) { return addDateConvert(e, exp, replacements, nodeReplacements, expressions, true) }) if err != nil { @@ -93,7 +93,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { result = plan.NewProject(projections, exp.Child) default: - result, err = exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { + result, err = plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) { return addDateConvert(e, n, replacements, nodeReplacements, expressions, false) }) } diff --git a/sql/analyzer/filters.go b/sql/analyzer/filters.go index 9cad5207f..bbe54adc7 100644 --- a/sql/analyzer/filters.go +++ b/sql/analyzer/filters.go @@ -20,7 +20,7 @@ func exprToTableFilters(expr sql.Expression) filters { for _, expr := range splitExpression(expr) { var seenTables = make(map[string]struct{}) var lastTable string - _, _ = expr.TransformUp(func(e sql.Expression) (sql.Expression, error) { + expression.Inspect(expr, func(e sql.Expression) bool { f, ok := e.(*expression.GetField) if ok { if _, ok := seenTables[f.Table()]; !ok { @@ -29,7 +29,7 @@ func exprToTableFilters(expr sql.Expression) filters { } } - return e, nil + return true }) if len(seenTables) == 1 { diff --git a/sql/analyzer/optimization_rules.go b/sql/analyzer/optimization_rules.go index bfa83aeaa..48f70f159 100644 --- a/sql/analyzer/optimization_rules.go +++ b/sql/analyzer/optimization_rules.go @@ -1,10 +1,10 @@ package analyzer import ( - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/plan" + "gopkg.in/src-d/go-errors.v1" ) func eraseProjection(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) { @@ -17,7 +17,7 @@ func eraseProjection(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, er a.Log("erase projection, node of type: %T", node) - return node.TransformUp(func(node sql.Node) (sql.Node, error) { + return plan.TransformUp(node, func(node sql.Node) (sql.Node, error) { project, ok := node.(*plan.Project) if ok && project.Schema().Equals(project.Child.Schema()) { a.Log("project erased") @@ -35,12 +35,13 @@ func optimizeDistinct(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, e a.Log("optimize distinct, node of type: %T", node) if n, ok := node.(*plan.Distinct); ok { var isSorted bool - _, _ = node.TransformUp(func(node sql.Node) (sql.Node, error) { + plan.Inspect(n, func(node sql.Node) bool { a.Log("checking for optimization in node of type: %T", node) if _, ok := node.(*plan.Sort); ok { isSorted = true + return false } - return node, nil + return true }) if isSorted { @@ -65,7 +66,7 @@ func reorderProjection(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err a.Log("reorder projection, node of type: %T", n) // Then we transform the projection - return n.TransformUp(func(node sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(node sql.Node) (sql.Node, error) { project, ok := node.(*plan.Project) // When we transform the projection, the children will always be // unresolved in the case we want to fix, as the reorder happens just @@ -92,7 +93,7 @@ func reorderProjection(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err // And add projection nodes where needed in the child tree. var didNeedReorder bool - child, err := project.Child.TransformUp(func(node sql.Node) (sql.Node, error) { + child, err := plan.TransformUp(project.Child, func(node sql.Node) (sql.Node, error) { var requiredColumns []string switch node := node.(type) { case *plan.Sort, *plan.Filter: @@ -200,7 +201,7 @@ func moveJoinConditionsToFilter(ctx *sql.Context, a *Analyzer, n sql.Node) (sql. a.Log("moving join conditions to filter, node of type: %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { join, ok := n.(*plan.InnerJoin) if !ok { return n, nil @@ -268,7 +269,7 @@ func removeUnnecessaryConverts(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.N a.Log("removing unnecessary converts, node of type: %T", n) - return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) { + return plan.TransformExpressionsUp(n, func(e sql.Expression) (sql.Expression, error) { if c, ok := e.(*expression.Convert); ok && c.Child.Type() == c.Type() { return c.Child, nil } @@ -336,13 +337,13 @@ func evalFilter(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) a.Log("evaluating filters, node of type: %T", node) - return node.TransformUp(func(node sql.Node) (sql.Node, error) { + return plan.TransformUp(node, func(node sql.Node) (sql.Node, error) { filter, ok := node.(*plan.Filter) if !ok { return node, nil } - e, err := filter.Expression.TransformUp(func(e sql.Expression) (sql.Expression, error) { + e, err := expression.TransformUp(filter.Expression, func(e sql.Expression) (sql.Expression, error) { switch e := e.(type) { case *expression.Or: if isTrue(e.Left) { diff --git a/sql/analyzer/parallelize.go b/sql/analyzer/parallelize.go index 9af592339..a56b9479e 100644 --- a/sql/analyzer/parallelize.go +++ b/sql/analyzer/parallelize.go @@ -34,7 +34,7 @@ func parallelize(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) return node, nil } - node, err := node.TransformUp(func(node sql.Node) (sql.Node, error) { + node, err := plan.TransformUp(node, func(node sql.Node) (sql.Node, error) { if !isParallelizable(node) { return node, nil } @@ -47,7 +47,7 @@ func parallelize(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) return nil, err } - return node.TransformUp(removeRedundantExchanges) + return plan.TransformUp(node, removeRedundantExchanges) } // removeRedundantExchanges removes all the exchanges except for the topmost @@ -58,13 +58,17 @@ func removeRedundantExchanges(node sql.Node) (sql.Node, error) { return node, nil } - e := &protectedExchange{exchange} - return e.TransformUp(func(node sql.Node) (sql.Node, error) { + child, err := plan.TransformUp(exchange.Child, func(node sql.Node) (sql.Node, error) { if exchange, ok := node.(*plan.Exchange); ok { return exchange.Child, nil } return node, nil }) + if err != nil { + return nil, err + } + + return exchange.WithChildren(child) } func isParallelizable(node sql.Node) bool { @@ -103,21 +107,3 @@ func isParallelizable(node sql.Node) bool { return ok && tableSeen && lastWasTable } - -// protectedExchange is a placeholder node that protects a certain exchange -// node from being removed during transformations. -type protectedExchange struct { - *plan.Exchange -} - -// TransformUp transforms the child with the given transform function but it -// will not call the transform function with the new instance. Instead of -// another protectedExchange, it will return an Exchange. -func (e *protectedExchange) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := e.Child.TransformUp(f) - if err != nil { - return nil, err - } - - return plan.NewExchange(e.Parallelism, child), nil -} diff --git a/sql/analyzer/parallelize_test.go b/sql/analyzer/parallelize_test.go index 7af884aff..41282c03e 100644 --- a/sql/analyzer/parallelize_test.go +++ b/sql/analyzer/parallelize_test.go @@ -3,11 +3,11 @@ package analyzer import ( "testing" - "github.com/stretchr/testify/require" "github.com/src-d/go-mysql-server/mem" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" ) func TestParallelize(t *testing.T) { @@ -222,7 +222,7 @@ func TestRemoveRedundantExchanges(t *testing.T) { ), ) - result, err := node.TransformUp(removeRedundantExchanges) + result, err := plan.TransformUp(node, removeRedundantExchanges) require.NoError(err) require.Equal(expected, result) } diff --git a/sql/analyzer/process.go b/sql/analyzer/process.go index 2d2d0f3f1..926d52c83 100644 --- a/sql/analyzer/process.go +++ b/sql/analyzer/process.go @@ -19,7 +19,7 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { processList := a.Catalog.ProcessList var seen = make(map[string]struct{}) - n, err := n.TransformUp(func(n sql.Node) (sql.Node, error) { + n, err := plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { switch n := n.(type) { case *plan.ResolvedTable: switch n.Table.(type) { @@ -73,7 +73,7 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { // Remove QueryProcess nodes from the subqueries. Otherwise, the process // will be marked as done as soon as a subquery finishes. - node, err := n.TransformUp(func(n sql.Node) (sql.Node, error) { + node, err := plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { if sq, ok := n.(*plan.SubqueryAlias); ok { if qp, ok := sq.Child.(*plan.QueryProcess); ok { return plan.NewSubqueryAlias(sq.Name(), qp.Child), nil diff --git a/sql/analyzer/prune_columns.go b/sql/analyzer/prune_columns.go index db424de9f..0b4121ebe 100644 --- a/sql/analyzer/prune_columns.go +++ b/sql/analyzer/prune_columns.go @@ -131,7 +131,7 @@ func pruneSubqueries( n sql.Node, parentColumns usedColumns, ) (sql.Node, error) { - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { subq, ok := n.(*plan.SubqueryAlias) if !ok { return n, nil @@ -142,7 +142,7 @@ func pruneSubqueries( } func pruneUnusedColumns(n sql.Node, columns usedColumns) (sql.Node, error) { - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { switch n := n.(type) { case *plan.Project: return pruneProject(n, columns), nil @@ -155,7 +155,7 @@ func pruneUnusedColumns(n sql.Node, columns usedColumns) (sql.Node, error) { } func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) { - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { switch n := n.(type) { case *plan.SubqueryAlias: child, err := fixRemainingFieldsIndexes(n.Child) @@ -165,8 +165,7 @@ func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) { return plan.NewSubqueryAlias(n.Name(), child), nil default: - exp, ok := n.(sql.Expressioner) - if !ok { + if _, ok := n.(sql.Expressioner); !ok { return n, nil } @@ -184,7 +183,7 @@ func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) { indexes[tableCol{col.Source, col.Name}] = i } - return exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { + return plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) { gf, ok := e.(*expression.GetField) if !ok { return e, nil diff --git a/sql/analyzer/pushdown.go b/sql/analyzer/pushdown.go index 55321f54d..4c500ffcb 100644 --- a/sql/analyzer/pushdown.go +++ b/sql/analyzer/pushdown.go @@ -66,7 +66,7 @@ func fixFieldIndexesOnExpressions(schema sql.Schema, expressions ...sql.Expressi // for GetField expressions according to the schema of the row in the table // and not the one where the filter came from. func fixFieldIndexes(schema sql.Schema, exp sql.Expression) (sql.Expression, error) { - return exp.TransformUp(func(e sql.Expression) (sql.Expression, error) { + return expression.TransformUp(exp, func(e sql.Expression) (sql.Expression, error) { switch e := e.(type) { case *expression.GetField: // we need to rewrite the indexes for the table row @@ -134,7 +134,7 @@ func transformPushdown( var handledFilters []sql.Expression var queryIndexes []sql.Index - node, err := n.TransformUp(func(node sql.Node) (sql.Node, error) { + node, err := plan.TransformUp(n, func(node sql.Node) (sql.Node, error) { a.Log("transforming node of type: %T", node) switch node := node.(type) { case *plan.Filter: @@ -173,8 +173,7 @@ func transformPushdown( } func transformExpressioners(node sql.Node) (sql.Node, error) { - expressioner, ok := node.(sql.Expressioner) - if !ok { + if _, ok := node.(sql.Expressioner); !ok { return node, nil } @@ -187,7 +186,7 @@ func transformExpressioners(node sql.Node) (sql.Node, error) { return node, nil } - n, err := expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { + n, err := plan.TransformExpressions(node, func(e sql.Expression) (sql.Expression, error) { for _, schema := range schemas { fixed, err := fixFieldIndexes(schema, e) if err == nil { @@ -338,20 +337,11 @@ func (r *releaser) Schema() sql.Schema { return r.Child.Schema() } -func (r *releaser) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := r.Child.TransformUp(f) - if err != nil { - return nil, err - } - return f(&releaser{child, r.Release}) -} - -func (r *releaser) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := r.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err +func (r *releaser) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 1) } - return &releaser{child, r.Release}, nil + return &releaser{children[0], r.Release}, nil } func (r *releaser) String() string { diff --git a/sql/analyzer/pushdown_test.go b/sql/analyzer/pushdown_test.go index b30a54630..d0cb737ad 100644 --- a/sql/analyzer/pushdown_test.go +++ b/sql/analyzer/pushdown_test.go @@ -3,11 +3,11 @@ package analyzer import ( "testing" - "github.com/stretchr/testify/require" "github.com/src-d/go-mysql-server/mem" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" ) func TestPushdownProjectionAndFilters(t *testing.T) { @@ -212,7 +212,7 @@ func TestPushdownIndexable(t *testing.T) { require.NoError(err) // we need to remove the release function to compare, otherwise it will fail - result, err = result.TransformUp(func(node sql.Node) (sql.Node, error) { + result, err = plan.TransformUp(result, func(node sql.Node) (sql.Node, error) { switch node := node.(type) { case *releaser: return &releaser{Child: node.Child}, nil diff --git a/sql/analyzer/resolve_columns.go b/sql/analyzer/resolve_columns.go index 5c4c33769..79e976739 100644 --- a/sql/analyzer/resolve_columns.go +++ b/sql/analyzer/resolve_columns.go @@ -5,11 +5,11 @@ import ( "sort" "strings" - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/internal/similartext" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/plan" + "gopkg.in/src-d/go-errors.v1" "vitess.io/vitess/go/vt/sqlparser" ) @@ -98,8 +98,15 @@ func (deferredColumn) IsNullable() bool { return true } -func (e deferredColumn) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - return fn(e) +// Children implements the Expression interface. +func (deferredColumn) Children() []sql.Expression { return nil } + +// WithChildren implements the Expression interface. +func (e deferredColumn) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 0) + } + return e, nil } type tableCol struct { @@ -120,16 +127,15 @@ type column interface { } func qualifyColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { - return n.TransformUp(func(n sql.Node) (sql.Node, error) { - exp, ok := n.(sql.Expressioner) - if !ok || n.Resolved() { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + if _, ok := n.(sql.Expressioner); !ok || n.Resolved() { return n, nil } columns := getNodeAvailableColumns(n) tables := getNodeAvailableTables(n) - return exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { + return plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) { return qualifyExpression(e, columns, tables) }) }) @@ -198,7 +204,7 @@ func qualifyExpression( default: // If any other kind of expression has a star, just replace it // with an unqualified star because it cannot be expanded. - return e.TransformUp(func(e sql.Expression) (sql.Expression, error) { + return expression.TransformUp(e, func(e sql.Expression) (sql.Expression, error) { if _, ok := e.(*expression.Star); ok { return expression.NewStar(), nil } @@ -289,14 +295,13 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) defer span.Finish() a.Log("resolve columns, node of type: %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { a.Log("transforming node of type: %T", n) if n.Resolved() { return n, nil } - expressioner, ok := n.(sql.Expressioner) - if !ok { + if _, ok := n.(sql.Expressioner); !ok { return n, nil } @@ -308,7 +313,7 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) } columns := findChildIndexedColumns(n) - return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { + return plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) { a.Log("transforming expression of type: %T", e) uc, ok := e.(column) @@ -394,7 +399,7 @@ func resolveGroupingColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node return n, nil } - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { g, ok := n.(*plan.GroupBy) if n.Resolved() || !ok || len(g.Grouping) == 0 { return n, nil @@ -510,7 +515,7 @@ func resolveGroupingColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node if len(renames) > 0 { for i, expr := range newAggregate { var err error - newAggregate[i], err = expr.TransformUp(func(e sql.Expression) (sql.Expression, error) { + newAggregate[i], err = expression.TransformUp(expr, func(e sql.Expression) (sql.Expression, error) { col, ok := e.(*expression.UnresolvedColumn) if ok { // We need to make sure we don't rename the reference to the diff --git a/sql/analyzer/resolve_database.go b/sql/analyzer/resolve_database.go index 6590e0748..2c5d0f628 100644 --- a/sql/analyzer/resolve_database.go +++ b/sql/analyzer/resolve_database.go @@ -2,6 +2,7 @@ package analyzer import ( "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/plan" ) func resolveDatabase(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { @@ -10,7 +11,7 @@ func resolveDatabase(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error a.Log("resolve database, node of type: %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { d, ok := n.(sql.Databaser) if !ok { return n, nil diff --git a/sql/analyzer/resolve_functions.go b/sql/analyzer/resolve_functions.go index a1a200a87..34bac11b0 100644 --- a/sql/analyzer/resolve_functions.go +++ b/sql/analyzer/resolve_functions.go @@ -3,6 +3,7 @@ package analyzer import ( "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" ) func resolveFunctions(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { @@ -10,13 +11,13 @@ func resolveFunctions(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, erro defer span.Finish() a.Log("resolve functions, node of type %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { a.Log("transforming node of type: %T", n) if n.Resolved() { return n, nil } - return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) { + return plan.TransformExpressionsUp(n, func(e sql.Expression) (sql.Expression, error) { a.Log("transforming expression of type: %T", e) if e.Resolved() { return e, nil diff --git a/sql/analyzer/resolve_generators.go b/sql/analyzer/resolve_generators.go index 437cf332a..4635e24d6 100644 --- a/sql/analyzer/resolve_generators.go +++ b/sql/analyzer/resolve_generators.go @@ -1,11 +1,11 @@ package analyzer import ( - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/expression/function" "github.com/src-d/go-mysql-server/sql/plan" + "gopkg.in/src-d/go-errors.v1" ) var ( @@ -14,7 +14,7 @@ var ( ) func resolveGenerators(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { p, ok := n.(*plan.Project) if !ok { return n, nil diff --git a/sql/analyzer/resolve_having.go b/sql/analyzer/resolve_having.go index eee00bfbb..de520d2bc 100644 --- a/sql/analyzer/resolve_having.go +++ b/sql/analyzer/resolve_having.go @@ -12,7 +12,7 @@ import ( ) func resolveHaving(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) { - return node.TransformUp(func(node sql.Node) (sql.Node, error) { + return plan.TransformUp(node, func(node sql.Node) (sql.Node, error) { having, ok := node.(*plan.Having) if !ok { return node, nil @@ -163,44 +163,19 @@ func addColumnsToGroupBy(node sql.Node, columns []sql.Expression) (sql.Node, err } return plan.NewProject(append(node.Projections, newProjections...), child), nil - case *plan.Filter: - child, err := addColumnsToGroupBy(node.Child, columns) - if err != nil { - return nil, err - } - return plan.NewFilter(node.Expression, child), nil - case *plan.Sort: - child, err := addColumnsToGroupBy(node.Child, columns) + case *plan.Filter, + *plan.Sort, + *plan.Limit, + *plan.Offset, + *plan.Distinct, + *plan.Having: + child, err := addColumnsToGroupBy(node.Children()[0], columns) if err != nil { return nil, err } - return plan.NewSort(node.SortFields, child), nil - case *plan.Limit: - child, err := addColumnsToGroupBy(node.Child, columns) - if err != nil { - return nil, err - } - return plan.NewLimit(node.Limit, child), nil - case *plan.Offset: - child, err := addColumnsToGroupBy(node.Child, columns) - if err != nil { - return nil, err - } - return plan.NewOffset(node.Offset, child), nil - case *plan.Distinct: - child, err := addColumnsToGroupBy(node.Child, columns) - if err != nil { - return nil, err - } - return plan.NewDistinct(child), nil + return node.WithChildren(child) case *plan.GroupBy: return plan.NewGroupBy(append(node.Aggregate, columns...), node.Grouping, node.Child), nil - case *plan.Having: - child, err := addColumnsToGroupBy(node.Child, columns) - if err != nil { - return nil, err - } - return plan.NewHaving(node.Cond, child), nil default: return nil, errHavingNeedsGroupBy.New() } @@ -318,7 +293,7 @@ func replaceAggregations(having *plan.Having) (*plan.Having, bool, error) { // indexes after they have been pushed up. This is because some of these // may have already been projected in some projection and we cannot ensure // from here what the final index will be. - cond, err := having.Cond.TransformUp(func(e sql.Expression) (sql.Expression, error) { + cond, err := expression.TransformUp(having.Cond, func(e sql.Expression) (sql.Expression, error) { agg, ok := e.(sql.Aggregation) if !ok { return e, nil @@ -372,7 +347,7 @@ func replaceAggregations(having *plan.Having) (*plan.Having, bool, error) { // Now, the tokens are replaced with the actual columns, now that we know // what the indexes are. - cond, err = having.Cond.TransformUp(func(e sql.Expression) (sql.Expression, error) { + cond, err = expression.TransformUp(having.Cond, func(e sql.Expression) (sql.Expression, error) { f, ok := e.(*expression.GetField) if !ok { return e, nil @@ -474,7 +449,7 @@ func aggregationChildEquals(a, b sql.Expression) bool { return true }) - a, err := a.TransformUp(func(e sql.Expression) (sql.Expression, error) { + a, err := expression.TransformUp(a, func(e sql.Expression) (sql.Expression, error) { var table, name string switch e := e.(type) { case *expression.UnresolvedColumn: diff --git a/sql/analyzer/resolve_natural_joins.go b/sql/analyzer/resolve_natural_joins.go index 0ada6eb36..a6cf1fdb9 100644 --- a/sql/analyzer/resolve_natural_joins.go +++ b/sql/analyzer/resolve_natural_joins.go @@ -15,8 +15,8 @@ func resolveNaturalJoins(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, e var replacements = make(map[tableCol]tableCol) var tableAliases = make(map[string]string) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { - switch n := n.(type) { + return plan.TransformUp(n, func(node sql.Node) (sql.Node, error) { + switch n := node.(type) { case *plan.TableAlias: alias := n.Name() table := n.Child.(*plan.ResolvedTable).Name() @@ -25,7 +25,7 @@ func resolveNaturalJoins(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, e case *plan.NaturalJoin: return resolveNaturalJoin(n, replacements) case sql.Expressioner: - return replaceExpressions(n, replacements, tableAliases) + return replaceExpressions(node, replacements, tableAliases) default: return n, nil } @@ -115,11 +115,11 @@ func findCol(s sql.Schema, name string) (int, *sql.Column) { } func replaceExpressions( - n sql.Expressioner, + n sql.Node, replacements map[tableCol]tableCol, tableAliases map[string]string, ) (sql.Node, error) { - return n.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { + return plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) { switch e := e.(type) { case *expression.GetField, *expression.UnresolvedColumn: var tableName = e.(sql.Tableable).Table() diff --git a/sql/analyzer/resolve_orderby.go b/sql/analyzer/resolve_orderby.go index dc24e674a..706ace606 100644 --- a/sql/analyzer/resolve_orderby.go +++ b/sql/analyzer/resolve_orderby.go @@ -3,10 +3,10 @@ package analyzer import ( "strings" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/plan" + errors "gopkg.in/src-d/go-errors.v1" ) func resolveOrderBy(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { @@ -14,7 +14,7 @@ func resolveOrderBy(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) defer span.Finish() a.Log("resolving order bys, node of type: %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { a.Log("transforming node of type: %T", n) sort, ok := n.(*plan.Sort) if !ok { @@ -175,7 +175,7 @@ func pushSortDown(sort *plan.Sort) (sql.Node, error) { func resolveOrderByLiterals(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { a.Log("resolve order by literals") - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { sort, ok := n.(*plan.Sort) if !ok { return n, nil diff --git a/sql/analyzer/resolve_stars.go b/sql/analyzer/resolve_stars.go index 9f8e8534f..2261752f3 100644 --- a/sql/analyzer/resolve_stars.go +++ b/sql/analyzer/resolve_stars.go @@ -11,7 +11,7 @@ func resolveStar(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { defer span.Finish() a.Log("resolving star, node of type: %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { a.Log("transforming node of type: %T", n) if n.Resolved() { return n, nil diff --git a/sql/analyzer/resolve_subqueries.go b/sql/analyzer/resolve_subqueries.go index 6c315b50b..5015253c1 100644 --- a/sql/analyzer/resolve_subqueries.go +++ b/sql/analyzer/resolve_subqueries.go @@ -10,7 +10,7 @@ func resolveSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err defer span.Finish() a.Log("resolving subqueries") - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { switch n := n.(type) { case *plan.SubqueryAlias: a.Log("found subquery %q with child of type %T", n.Name(), n.Child) diff --git a/sql/analyzer/resolve_tables.go b/sql/analyzer/resolve_tables.go index afe00d47f..7632d1a48 100644 --- a/sql/analyzer/resolve_tables.go +++ b/sql/analyzer/resolve_tables.go @@ -21,7 +21,7 @@ func resolveTables(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) defer span.Finish() a.Log("resolve table, node of type: %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { a.Log("transforming node of type: %T", n) if n.Resolved() { return n, nil diff --git a/sql/analyzer/validation_rules_test.go b/sql/analyzer/validation_rules_test.go index 7bb83d1bb..2d87614b1 100644 --- a/sql/analyzer/validation_rules_test.go +++ b/sql/analyzer/validation_rules_test.go @@ -676,17 +676,12 @@ func TestValidateExplodeUsage(t *testing.T) { type dummyNode struct{ resolved bool } -func (n dummyNode) String() string { return "dummynode" } -func (n dummyNode) Resolved() bool { return n.resolved } -func (dummyNode) Schema() sql.Schema { return nil } -func (dummyNode) Children() []sql.Node { return nil } -func (dummyNode) RowIter(*sql.Context) (sql.RowIter, error) { return nil, nil } -func (dummyNode) TransformUp(sql.TransformNodeFunc) (sql.Node, error) { return nil, nil } -func (dummyNode) TransformExpressionsUp( - sql.TransformExprFunc, -) (sql.Node, error) { - return nil, nil -} +func (n dummyNode) String() string { return "dummynode" } +func (n dummyNode) Resolved() bool { return n.resolved } +func (dummyNode) Schema() sql.Schema { return nil } +func (dummyNode) Children() []sql.Node { return nil } +func (dummyNode) RowIter(*sql.Context) (sql.RowIter, error) { return nil, nil } +func (dummyNode) WithChildren(...sql.Node) (sql.Node, error) { return nil, nil } func getValidationRule(name string) Rule { for _, rule := range DefaultValidationRules { diff --git a/sql/core.go b/sql/core.go index 14c916ef3..d676372ec 100644 --- a/sql/core.go +++ b/sql/core.go @@ -25,6 +25,10 @@ var ( //ErrUnexpectedRowLength is thrown when the obtained row has more columns than the schema ErrUnexpectedRowLength = errors.NewKind("expected %d values, got %d") + + // ErrInvalidChildrenNumber is returned when the WithChildren method of a + // node or expression is called with an invalid number of arguments. + ErrInvalidChildrenNumber = errors.NewKind("%T: invalid children number, got %d, expected %d") ) // Nameable is something that has a name. @@ -45,17 +49,6 @@ type Resolvable interface { Resolved() bool } -// Transformable is a node which can be transformed. -type Transformable interface { - // TransformUp transforms all nodes and returns the result of this transformation. - // Transformation is not propagated to subqueries. - TransformUp(TransformNodeFunc) (Node, error) - // TransformExpressionsUp transforms all expressions inside the node and all its - // children and returns a node with the result of the transformations. - // Transformation is not propagated to subqueries. - TransformExpressionsUp(TransformExprFunc) (Node, error) -} - // TransformNodeFunc is a function that given a node will return that node // as is or transformed along with an error, if any. type TransformNodeFunc func(Node) (Node, error) @@ -74,11 +67,13 @@ type Expression interface { IsNullable() bool // Eval evaluates the given row and returns a result. Eval(*Context, Row) (interface{}, error) - // TransformUp transforms the expression and all its children with the - // given transform function. - TransformUp(TransformExprFunc) (Expression, error) // Children returns the children expressions of this expression. Children() []Expression + // WithChildren returns a copy of the expression with children replaced. + // It will return an error if the number of children is different than + // the current number of children. They must be given in the same order + // as they are returned by Children. + WithChildren(...Expression) (Expression, error) } // Aggregation implements an aggregation expression, where an @@ -100,7 +95,6 @@ type Aggregation interface { // Node is a node in the execution plan tree. type Node interface { Resolvable - Transformable fmt.Stringer // Schema of the node. Schema() Schema @@ -108,16 +102,30 @@ type Node interface { Children() []Node // RowIter produces a row iterator from this node. RowIter(*Context) (RowIter, error) + // WithChildren returns a copy of the node with children replaced. + // It will return an error if the number of children is different than + // the current number of children. They must be given in the same order + // as they are returned by Children. + WithChildren(...Node) (Node, error) +} + +// OpaqueNode is a node that doesn't allow transformations to its children and +// acts a a black box. +type OpaqueNode interface { + Node + // Opaque reports whether the node is opaque or not. + Opaque() bool } // Expressioner is a node that contains expressions. type Expressioner interface { // Expressions returns the list of expressions contained by the node. Expressions() []Expression - // TransformExpressions applies for each expression in this node - // the expression's TransformUp method with the given function, and - // return a new node with the transformed expressions. - TransformExpressions(TransformExprFunc) (Node, error) + // WithExpressions returns a copy of the node with expressions replaced. + // It will return an error if the number of expressions is different than + // the current number of expressions. They must be given in the same order + // as they are returned by Expressions. + WithExpressions(...Expression) (Node, error) } // Databaser is a node that contains a reference to a database. diff --git a/sql/expression/alias.go b/sql/expression/alias.go index 6b9b1edd4..c7485dfd9 100644 --- a/sql/expression/alias.go +++ b/sql/expression/alias.go @@ -31,13 +31,12 @@ func (e *Alias) String() string { return fmt.Sprintf("%s as %s", e.Child, e.name) } -// TransformUp implements the Expression interface. -func (e *Alias) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := e.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *Alias) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) } - return f(NewAlias(child, e.name)) + return NewAlias(children[0], e.name), nil } // Name implements the Nameable interface. diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index ecdf26961..d7044e06e 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -137,19 +137,12 @@ func isInterval(expr sql.Expression) bool { return ok } -// TransformUp implements the Expression interface. -func (a *Arithmetic) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - l, err := a.Left.TransformUp(f) - if err != nil { - return nil, err - } - - r, err := a.Right.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (a *Arithmetic) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 2) } - - return f(NewArithmetic(l, r, a.Op)) + return NewArithmetic(children[0], children[1], a.Op), nil } // Eval implements the Expression interface. @@ -549,12 +542,10 @@ func (e *UnaryMinus) String() string { return fmt.Sprintf("-%s", e.Child) } -// TransformUp implements the sql.Expression interface. -func (e *UnaryMinus) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - c, err := e.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *UnaryMinus) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) } - - return f(NewUnaryMinus(c)) + return NewUnaryMinus(children[0]), nil } diff --git a/sql/expression/between.go b/sql/expression/between.go index ce892158a..15114890b 100644 --- a/sql/expression/between.go +++ b/sql/expression/between.go @@ -98,22 +98,10 @@ func (b *Between) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return cmpLower >= 0 && cmpUpper <= 0, nil } -// TransformUp implements the Expression interface. -func (b *Between) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - val, err := b.Val.TransformUp(f) - if err != nil { - return nil, err - } - - lower, err := b.Lower.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (b *Between) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 3 { + return nil, sql.ErrInvalidChildrenNumber.New(b, len(children), 3) } - - upper, err := b.Upper.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewBetween(val, lower, upper)) + return NewBetween(children[0], children[1], children[2]), nil } diff --git a/sql/expression/boolean.go b/sql/expression/boolean.go index c4991029a..73815fb7d 100644 --- a/sql/expression/boolean.go +++ b/sql/expression/boolean.go @@ -52,11 +52,10 @@ func (e *Not) String() string { return fmt.Sprintf("NOT(%s)", e.Child) } -// TransformUp implements the Expression interface. -func (e *Not) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := e.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *Not) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) } - return f(NewNot(child)) + return NewNot(children[0]), nil } diff --git a/sql/expression/case.go b/sql/expression/case.go index 28eef6d03..31f037fb3 100644 --- a/sql/expression/case.go +++ b/sql/expression/case.go @@ -130,44 +130,41 @@ func (c *Case) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } -// TransformUp implements the sql.Expression interface. -func (c *Case) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - var expr sql.Expression - var err error - +// WithChildren implements the Expression interface. +func (c *Case) WithChildren(children ...sql.Expression) (sql.Expression, error) { + var expected = len(c.Branches) * 2 if c.Expr != nil { - expr, err = c.Expr.TransformUp(f) - if err != nil { - return nil, err - } + expected++ } - var branches []CaseBranch - for _, b := range c.Branches { - var nb CaseBranch - - nb.Cond, err = b.Cond.TransformUp(f) - if err != nil { - return nil, err - } + if c.Else != nil { + expected++ + } - nb.Value, err = b.Value.TransformUp(f) - if err != nil { - return nil, err - } + if len(children) != expected { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), expected) + } - branches = append(branches, nb) + var expr, elseExpr sql.Expression + if c.Expr != nil { + expr = children[0] + children = children[1:] } - var elseExpr sql.Expression if c.Else != nil { - elseExpr, err = c.Else.TransformUp(f) - if err != nil { - return nil, err - } + elseExpr = children[len(children)-1] + children = children[:len(children)-1] + } + + var branches []CaseBranch + for i := 0; i < len(children); i += 2 { + branches = append(branches, CaseBranch{ + Cond: children[i], + Value: children[i+1], + }) } - return f(NewCase(expr, branches, elseExpr)) + return NewCase(expr, branches, elseExpr), nil } func (c *Case) String() string { diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index c1fd93e04..8491f9cef 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -4,9 +4,9 @@ import ( "fmt" "sync" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/internal/regex" "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" ) // Comparer implements a comparison expression. @@ -157,19 +157,12 @@ func (e *Equals) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return result == 0, nil } -// TransformUp implements the Expression interface. -func (e *Equals) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := e.Left().TransformUp(f) - if err != nil { - return nil, err - } - - right, err := e.Right().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *Equals) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 2) } - - return f(NewEquals(left, right)) + return NewEquals(children[0], children[1]), nil } func (e *Equals) String() string { @@ -278,19 +271,12 @@ func (re *Regexp) compareRegexp(ctx *sql.Context, row sql.Row) (interface{}, err return ok, nil } -// TransformUp implements the Expression interface. -func (re *Regexp) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := re.Left().TransformUp(f) - if err != nil { - return nil, err - } - - right, err := re.Right().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (re *Regexp) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(re, len(children), 2) } - - return f(NewRegexp(left, right)) + return NewRegexp(children[0], children[1]), nil } func (re *Regexp) String() string { @@ -321,19 +307,12 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return result == 1, nil } -// TransformUp implements the Expression interface. -func (gt *GreaterThan) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := gt.Left().TransformUp(f) - if err != nil { - return nil, err - } - - right, err := gt.Right().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (gt *GreaterThan) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(gt, len(children), 2) } - - return f(NewGreaterThan(left, right)) + return NewGreaterThan(children[0], children[1]), nil } func (gt *GreaterThan) String() string { @@ -364,19 +343,12 @@ func (lt *LessThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return result == -1, nil } -// TransformUp implements the Expression interface. -func (lt *LessThan) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := lt.Left().TransformUp(f) - if err != nil { - return nil, err - } - - right, err := lt.Right().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (lt *LessThan) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(lt, len(children), 2) } - - return f(NewLessThan(left, right)) + return NewLessThan(children[0], children[1]), nil } func (lt *LessThan) String() string { @@ -408,19 +380,12 @@ func (gte *GreaterThanOrEqual) Eval(ctx *sql.Context, row sql.Row) (interface{}, return result > -1, nil } -// TransformUp implements the Expression interface. -func (gte *GreaterThanOrEqual) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := gte.Left().TransformUp(f) - if err != nil { - return nil, err - } - - right, err := gte.Right().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (gte *GreaterThanOrEqual) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(gte, len(children), 2) } - - return f(NewGreaterThanOrEqual(left, right)) + return NewGreaterThanOrEqual(children[0], children[1]), nil } func (gte *GreaterThanOrEqual) String() string { @@ -452,19 +417,12 @@ func (lte *LessThanOrEqual) Eval(ctx *sql.Context, row sql.Row) (interface{}, er return result < 1, nil } -// TransformUp implements the Expression interface. -func (lte *LessThanOrEqual) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := lte.Left().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (lte *LessThanOrEqual) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(lte, len(children), 2) } - - right, err := lte.Right().TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewLessThanOrEqual(left, right)) + return NewLessThanOrEqual(children[0], children[1]), nil } func (lte *LessThanOrEqual) String() string { @@ -544,19 +502,12 @@ func (in *In) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } } -// TransformUp implements the Expression interface. -func (in *In) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := in.Left().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (in *In) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(in, len(children), 2) } - - right, err := in.Right().TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewIn(left, right)) + return NewIn(children[0], children[1]), nil } func (in *In) String() string { @@ -632,19 +583,12 @@ func (in *NotIn) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } } -// TransformUp implements the Expression interface. -func (in *NotIn) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := in.Left().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (in *NotIn) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(in, len(children), 2) } - - right, err := in.Right().TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewNotIn(left, right)) + return NewNotIn(children[0], children[1]), nil } func (in *NotIn) String() string { diff --git a/sql/expression/convert.go b/sql/expression/convert.go index d40640354..bcfe778e0 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -8,8 +8,8 @@ import ( "time" "github.com/spf13/cast" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" ) // ErrConvertExpression is returned when a conversion is not possible. @@ -90,14 +90,12 @@ func (c *Convert) String() string { return fmt.Sprintf("convert(%v, %v)", c.Child, c.castToType) } -// TransformUp implements the Expression interface. -func (c *Convert) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := c.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (c *Convert) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) } - - return f(NewConvert(child, c.castToType)) + return NewConvert(children[0], c.castToType), nil } // Eval implements the Expression interface. diff --git a/sql/expression/default.go b/sql/expression/default.go index 5757a6cff..82c5cb9e7 100644 --- a/sql/expression/default.go +++ b/sql/expression/default.go @@ -53,8 +53,10 @@ func (*DefaultColumn) Eval(ctx *sql.Context, r sql.Row) (interface{}, error) { panic("default column is a placeholder node, but Eval was called") } -// TransformUp implements the sql.Expression interface. -func (c *DefaultColumn) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - n := *c - return f(&n) +// WithChildren implements the Expression interface. +func (c *DefaultColumn) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0) + } + return c, nil } diff --git a/sql/expression/function/aggregation/avg.go b/sql/expression/function/aggregation/avg.go index 57153f12a..5e6b2ff49 100644 --- a/sql/expression/function/aggregation/avg.go +++ b/sql/expression/function/aggregation/avg.go @@ -53,13 +53,12 @@ func (a *Avg) Eval(ctx *sql.Context, buffer sql.Row) (interface{}, error) { return sum / float64(rows), nil } -// TransformUp implements AggregationExpression interface. -func (a *Avg) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := a.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (a *Avg) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 1) } - return f(NewAvg(child)) + return NewAvg(children[0]), nil } // NewBuffer implements AggregationExpression interface. (AggregationExpression) diff --git a/sql/expression/function/aggregation/count.go b/sql/expression/function/aggregation/count.go index 96406b11f..b3f900b4d 100644 --- a/sql/expression/function/aggregation/count.go +++ b/sql/expression/function/aggregation/count.go @@ -45,13 +45,12 @@ func (c *Count) String() string { return fmt.Sprintf("COUNT(%s)", c.Child) } -// TransformUp implements the Expression interface. -func (c *Count) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := c.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (c *Count) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) } - return f(NewCount(child)) + return NewCount(children[0]), nil } // Update implements the Aggregation interface. diff --git a/sql/expression/function/aggregation/first.go b/sql/expression/function/aggregation/first.go index b59ac8b12..20ef8af6c 100644 --- a/sql/expression/function/aggregation/first.go +++ b/sql/expression/function/aggregation/first.go @@ -27,13 +27,12 @@ func (f *First) String() string { return fmt.Sprintf("FIRST(%s)", f.Child) } -// TransformUp implements the Transformable interface. -func (f *First) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - child, err := f.Child.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the sql.Expression interface. +func (f *First) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 1) } - return fn(NewFirst(child)) + return NewFirst(children[0]), nil } // NewBuffer creates a new buffer to compute the result. diff --git a/sql/expression/function/aggregation/last.go b/sql/expression/function/aggregation/last.go index b76ca4a58..55457a5e5 100644 --- a/sql/expression/function/aggregation/last.go +++ b/sql/expression/function/aggregation/last.go @@ -27,13 +27,12 @@ func (l *Last) String() string { return fmt.Sprintf("LAST(%s)", l.Child) } -// TransformUp implements the Transformable interface. -func (l *Last) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - child, err := l.Child.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the sql.Expression interface. +func (l *Last) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1) } - return fn(NewLast(child)) + return NewLast(children[0]), nil } // NewBuffer creates a new buffer to compute the result. diff --git a/sql/expression/function/aggregation/max.go b/sql/expression/function/aggregation/max.go index 40ec86399..e47211f2c 100644 --- a/sql/expression/function/aggregation/max.go +++ b/sql/expression/function/aggregation/max.go @@ -38,13 +38,12 @@ func (m *Max) IsNullable() bool { return false } -// TransformUp implements the Transformable interface. -func (m *Max) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := m.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (m *Max) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1) } - return f(NewMax(child)) + return NewMax(children[0]), nil } // NewBuffer creates a new buffer to compute the result. diff --git a/sql/expression/function/aggregation/min.go b/sql/expression/function/aggregation/min.go index 19c08c296..8e73e0812 100644 --- a/sql/expression/function/aggregation/min.go +++ b/sql/expression/function/aggregation/min.go @@ -38,13 +38,12 @@ func (m *Min) IsNullable() bool { return true } -// TransformUp implements the Transformable interface. -func (m *Min) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := m.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (m *Min) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1) } - return f(NewMin(child)) + return NewMin(children[0]), nil } // NewBuffer creates a new buffer to compute the result. diff --git a/sql/expression/function/aggregation/sum.go b/sql/expression/function/aggregation/sum.go index dd62e0dfe..09df362be 100644 --- a/sql/expression/function/aggregation/sum.go +++ b/sql/expression/function/aggregation/sum.go @@ -27,13 +27,12 @@ func (m *Sum) String() string { return fmt.Sprintf("SUM(%s)", m.Child) } -// TransformUp implements the Transformable interface. -func (m *Sum) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := m.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (m *Sum) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1) } - return f(NewSum(child)) + return NewSum(children[0]), nil } // NewBuffer creates a new buffer to compute the result. diff --git a/sql/expression/function/arraylength.go b/sql/expression/function/arraylength.go index 3002f2df9..00d10cfd2 100644 --- a/sql/expression/function/arraylength.go +++ b/sql/expression/function/arraylength.go @@ -26,13 +26,12 @@ func (f *ArrayLength) String() string { return fmt.Sprintf("array_length(%s)", f.Child) } -// TransformUp implements the Expression interface. -func (f *ArrayLength) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - child, err := f.Child.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (f *ArrayLength) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 1) } - return fn(NewArrayLength(child)) + return NewArrayLength(children[0]), nil } // Eval implements the Expression interface. diff --git a/sql/expression/function/ceil_round_floor.go b/sql/expression/function/ceil_round_floor.go index 1aafa66f7..4c8c1d757 100644 --- a/sql/expression/function/ceil_round_floor.go +++ b/sql/expression/function/ceil_round_floor.go @@ -32,13 +32,12 @@ func (c *Ceil) String() string { return fmt.Sprintf("CEIL(%s)", c.Child) } -// TransformUp implements the Expression interface. -func (c *Ceil) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - child, err := c.Child.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (c *Ceil) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) } - return fn(NewCeil(child)) + return NewCeil(children[0]), nil } // Eval implements the Expression interface. @@ -99,13 +98,12 @@ func (f *Floor) String() string { return fmt.Sprintf("FLOOR(%s)", f.Child) } -// TransformUp implements the Expression interface. -func (f *Floor) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - child, err := f.Child.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (f *Floor) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 1) } - return fn(NewFloor(child)) + return NewFloor(children[0]), nil } // Eval implements the Expression interface. @@ -269,30 +267,7 @@ func (r *Round) Type() sql.Type { return sql.Int32 } -// TransformUp implements the Expression interface. -func (r *Round) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - var args = make([]sql.Expression, 2) - - arg, err := r.Left.TransformUp(f) - if err != nil { - return nil, err - } - args[0] = arg - - args[1] = nil - if r.Right != nil { - var arg sql.Expression - arg, err = r.Right.TransformUp(f) - if err != nil { - return nil, err - } - args[1] = arg - } - - expr, err := NewRound(args...) - if err != nil { - return nil, err - } - - return f(expr) +// WithChildren implements the Expression interface. +func (r *Round) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewRound(children...) } diff --git a/sql/expression/function/coalesce.go b/sql/expression/function/coalesce.go index 529a8c455..07f7f64e6 100644 --- a/sql/expression/function/coalesce.go +++ b/sql/expression/function/coalesce.go @@ -59,29 +59,9 @@ func (c *Coalesce) String() string { return fmt.Sprintf("coalesce(%s)", strings.Join(args, ", ")) } -// TransformUp implements the sql.Expression interface. -func (c *Coalesce) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - var ( - args = make([]sql.Expression, len(c.args)) - err error - ) - - for i, arg := range c.args { - if arg != nil { - arg, err = arg.TransformUp(fn) - if err != nil { - return nil, err - } - } - args[i] = arg - } - - expr, err := NewCoalesce(args...) - if err != nil { - return nil, err - } - - return fn(expr) +// WithChildren implements the Expression interface. +func (*Coalesce) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewCoalesce(children...) } // Resolved implements the sql.Expression interface. diff --git a/sql/expression/function/concat.go b/sql/expression/function/concat.go index 77a384a43..56e7bcbab 100644 --- a/sql/expression/function/concat.go +++ b/sql/expression/function/concat.go @@ -4,8 +4,8 @@ import ( "fmt" "strings" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" ) // Concat joins several strings together. @@ -63,23 +63,9 @@ func (f *Concat) String() string { return fmt.Sprintf("concat(%s)", strings.Join(args, ", ")) } -// TransformUp implements the Expression interface. -func (f *Concat) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - var args = make([]sql.Expression, len(f.args)) - for i, arg := range f.args { - a, err := arg.TransformUp(fn) - if err != nil { - return nil, err - } - args[i] = a - } - - expr, err := NewConcat(args...) - if err != nil { - return nil, err - } - - return fn(expr) +// WithChildren implements the Expression interface. +func (*Concat) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewConcat(children...) } // Resolved implements the Expression interface. diff --git a/sql/expression/function/concat_ws.go b/sql/expression/function/concat_ws.go index 7ceffc195..c1e2dacc1 100644 --- a/sql/expression/function/concat_ws.go +++ b/sql/expression/function/concat_ws.go @@ -61,23 +61,9 @@ func (f *ConcatWithSeparator) String() string { return fmt.Sprintf("concat_ws(%s)", strings.Join(args, ", ")) } -// TransformUp implements the Expression interface. -func (f *ConcatWithSeparator) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - var args = make([]sql.Expression, len(f.args)) - for i, arg := range f.args { - a, err := arg.TransformUp(fn) - if err != nil { - return nil, err - } - args[i] = a - } - - expr, err := NewConcatWithSeparator(args...) - if err != nil { - return nil, err - } - - return fn(expr) +// WithChildren implements the Expression interface. +func (*ConcatWithSeparator) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewConcatWithSeparator(children...) } // Resolved implements the Expression interface. diff --git a/sql/expression/function/connection_id.go b/sql/expression/function/connection_id.go index a34ba020e..7ee96ef58 100644 --- a/sql/expression/function/connection_id.go +++ b/sql/expression/function/connection_id.go @@ -19,9 +19,12 @@ func (ConnectionID) Type() sql.Type { return sql.Uint32 } // Resolved implements the sql.Expression interface. func (ConnectionID) Resolved() bool { return true } -// TransformUp implements the sql.Expression interface. -func (ConnectionID) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - return f(ConnectionID{}) +// WithChildren implements the Expression interface. +func (c ConnectionID) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0) + } + return c, nil } // IsNullable implements the sql.Expression interface. diff --git a/sql/expression/function/database.go b/sql/expression/function/database.go index f9d6ac5e2..1246e488c 100644 --- a/sql/expression/function/database.go +++ b/sql/expression/function/database.go @@ -29,9 +29,12 @@ func (*Database) String() string { return "DATABASE()" } -// TransformUp implements the sql.Expression interface. -func (db *Database) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - return fn(db) +// WithChildren implements the Expression interface. +func (d *Database) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 0) + } + return NewDatabase(d.catalog)(), nil } // Resolved implements the sql.Expression interface. diff --git a/sql/expression/function/date.go b/sql/expression/function/date.go index cdca25692..919775a42 100644 --- a/sql/expression/function/date.go +++ b/sql/expression/function/date.go @@ -46,18 +46,9 @@ func (d *DateAdd) IsNullable() bool { // Type implements the sql.Expression interface. func (d *DateAdd) Type() sql.Type { return sql.Date } -// TransformUp implements the sql.Expression interface. -func (d *DateAdd) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - date, err := d.Date.TransformUp(f) - if err != nil { - return nil, err - } - interval, err := d.Interval.TransformUp(f) - if err != nil { - return nil, err - } - - return &DateAdd{date, interval.(*expression.Interval)}, nil +// WithChildren implements the Expression interface. +func (d *DateAdd) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewDateAdd(children...) } // Eval implements the sql.Expression interface. @@ -130,18 +121,9 @@ func (d *DateSub) IsNullable() bool { // Type implements the sql.Expression interface. func (d *DateSub) Type() sql.Type { return sql.Date } -// TransformUp implements the sql.Expression interface. -func (d *DateSub) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - date, err := d.Date.TransformUp(f) - if err != nil { - return nil, err - } - interval, err := d.Interval.TransformUp(f) - if err != nil { - return nil, err - } - - return &DateSub{date, interval.(*expression.Interval)}, nil +// WithChildren implements the Expression interface. +func (d *DateSub) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewDateSub(children...) } // Eval implements the sql.Expression interface. diff --git a/sql/expression/function/explode.go b/sql/expression/function/explode.go index 64e7d85a2..51cd2b66b 100644 --- a/sql/expression/function/explode.go +++ b/sql/expression/function/explode.go @@ -40,14 +40,12 @@ func (e *Explode) String() string { return fmt.Sprintf("EXPLODE(%s)", e.Child) } -// TransformUp implements the sql.Expression interface. -func (e *Explode) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - c, err := e.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *Explode) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) } - - return f(NewExplode(c)) + return NewExplode(children[0]), nil } // Generate is a function that generates a row for each value of its child. @@ -84,12 +82,10 @@ func (e *Generate) String() string { return fmt.Sprintf("EXPLODE(%s)", e.Child) } -// TransformUp implements the sql.Expression interface. -func (e *Generate) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - c, err := e.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *Generate) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) } - - return f(NewGenerate(c)) + return NewGenerate(children[0]), nil } diff --git a/sql/expression/function/greatest_least.go b/sql/expression/function/greatest_least.go index 0dcc67905..9e822d8ed 100644 --- a/sql/expression/function/greatest_least.go +++ b/sql/expression/function/greatest_least.go @@ -5,8 +5,8 @@ import ( "strconv" "strings" - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" ) var ErrUintOverflow = errors.NewKind( @@ -194,23 +194,9 @@ func (f *Greatest) String() string { return fmt.Sprintf("greatest(%s)", strings.Join(args, ", ")) } -// TransformUp implements the Expression interface. -func (f *Greatest) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - var args = make([]sql.Expression, len(f.Args)) - for i, arg := range f.Args { - a, err := arg.TransformUp(fn) - if err != nil { - return nil, err - } - args[i] = a - } - - expr, err := NewGreatest(args...) - if err != nil { - return nil, err - } - - return fn(expr) +// WithChildren implements the Expression interface. +func (f *Greatest) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewGreatest(children...) } // Resolved implements the Expression interface. @@ -298,23 +284,9 @@ func (f *Least) String() string { return fmt.Sprintf("least(%s)", strings.Join(args, ", ")) } -// TransformUp implements the Expression interface. -func (f *Least) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - var args = make([]sql.Expression, len(f.Args)) - for i, arg := range f.Args { - a, err := arg.TransformUp(fn) - if err != nil { - return nil, err - } - args[i] = a - } - - expr, err := NewLeast(args...) - if err != nil { - return nil, err - } - - return fn(expr) +// WithChildren implements the Expression interface. +func (f *Least) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewLeast(children...) } // Resolved implements the Expression interface. diff --git a/sql/expression/function/ifnull.go b/sql/expression/function/ifnull.go index f607da7f1..566d62ea6 100644 --- a/sql/expression/function/ifnull.go +++ b/sql/expression/function/ifnull.go @@ -65,17 +65,10 @@ func (f *IfNull) String() string { return fmt.Sprintf("ifnull(%s, %s)", f.Left, f.Right) } -// TransformUp implements the Expression interface. -func (f *IfNull) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - left, err := f.Left.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (f *IfNull) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 2) } - - right, err := f.Right.TransformUp(fn) - if err != nil { - return nil, err - } - - return fn(NewIfNull(left, right)) + return NewIfNull(children[0], children[1]), nil } diff --git a/sql/expression/function/isbinary.go b/sql/expression/function/isbinary.go index 150048356..e3edf74d6 100644 --- a/sql/expression/function/isbinary.go +++ b/sql/expression/function/isbinary.go @@ -45,13 +45,12 @@ func (ib *IsBinary) String() string { return fmt.Sprintf("IS_BINARY(%s)", ib.Child) } -// TransformUp implements the Expression interface. -func (ib *IsBinary) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := ib.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (ib *IsBinary) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(ib, len(children), 1) } - return f(NewIsBinary(child)) + return NewIsBinary(children[0]), nil } // Type implements the Expression interface. diff --git a/sql/expression/function/json_extract.go b/sql/expression/function/json_extract.go index 1dbdb758b..c3f8f65eb 100644 --- a/sql/expression/function/json_extract.go +++ b/sql/expression/function/json_extract.go @@ -108,22 +108,9 @@ func (j *JSONExtract) Children() []sql.Expression { return append([]sql.Expression{j.JSON}, j.Paths...) } -// TransformUp implements the sql.Expression interface. -func (j *JSONExtract) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - json, err := j.JSON.TransformUp(f) - if err != nil { - return nil, err - } - - paths := make([]sql.Expression, len(j.Paths)) - for i, p := range j.Paths { - paths[i], err = p.TransformUp(f) - if err != nil { - return nil, err - } - } - - return f(&JSONExtract{json, paths}) +// WithChildren implements the Expression interface. +func (j *JSONExtract) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewJSONExtract(children...) } func (j *JSONExtract) String() string { diff --git a/sql/expression/function/json_unquote.go b/sql/expression/function/json_unquote.go index 4b5715c4e..8a0c42de3 100644 --- a/sql/expression/function/json_unquote.go +++ b/sql/expression/function/json_unquote.go @@ -33,13 +33,12 @@ func (*JSONUnquote) Type() sql.Type { return sql.Text } -// TransformUp implements the Expression interface. -func (js *JSONUnquote) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - json, err := js.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (js *JSONUnquote) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(js, len(children), 1) } - return f(NewJSONUnquote(json)) + return NewJSONUnquote(children[0]), nil } // Eval implements the Expression interface. diff --git a/sql/expression/function/length.go b/sql/expression/function/length.go index 5e3d7c8c8..49d46aaf8 100644 --- a/sql/expression/function/length.go +++ b/sql/expression/function/length.go @@ -35,16 +35,13 @@ func NewCharLength(e sql.Expression) sql.Expression { return &Length{expression.UnaryExpression{Child: e}, NumChars} } -// TransformUp implements the sql.Expression interface. -func (l *Length) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := l.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (l *Length) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1) } - nl := *l - nl.Child = child - return &nl, nil + return &Length{expression.UnaryExpression{Child: children[0]}, l.CountType}, nil } // Type implements the sql.Expression interface. diff --git a/sql/expression/function/logarithm.go b/sql/expression/function/logarithm.go index 61afc2d81..eb6355a83 100644 --- a/sql/expression/function/logarithm.go +++ b/sql/expression/function/logarithm.go @@ -5,9 +5,9 @@ import ( "math" "reflect" - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" + "gopkg.in/src-d/go-errors.v1" ) // ErrInvalidArgumentForLogarithm is returned when an invalid argument value is passed to a @@ -45,13 +45,12 @@ func (l *LogBase) String() string { } } -// TransformUp implements the Expression interface. -func (l *LogBase) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := l.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (l *LogBase) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1) } - return f(NewLogBase(l.base, child)) + return NewLogBase(l.base, children[0]), nil } // Type returns the resultant type of the function. @@ -108,26 +107,9 @@ func (l *Log) String() string { return fmt.Sprintf("log(%s, %s)", l.Left, l.Right) } -// TransformUp implements the Expression interface. -func (l *Log) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - var args = make([]sql.Expression, 2) - arg, err := l.Left.TransformUp(f) - if err != nil { - return nil, err - } - args[0] = arg - - arg, err = l.Right.TransformUp(f) - if err != nil { - return nil, err - } - args[1] = arg - expr, err := NewLog(args...) - if err != nil { - return nil, err - } - - return f(expr) +// WithChildren implements the Expression interface. +func (l *Log) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewLog(children...) } // Children implements the Expression interface. diff --git a/sql/expression/function/lower_upper.go b/sql/expression/function/lower_upper.go index 2c2b09928..f0d6cdb9d 100644 --- a/sql/expression/function/lower_upper.go +++ b/sql/expression/function/lower_upper.go @@ -2,9 +2,10 @@ package function import ( "fmt" + "strings" + "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" - "strings" ) // Lower is a function that returns the lowercase of the text provided. @@ -43,13 +44,12 @@ func (l *Lower) String() string { return fmt.Sprintf("LOWER(%s)", l.Child) } -// TransformUp implements the Expression interface. -func (l *Lower) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := l.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (l *Lower) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1) } - return f(NewLower(child)) + return NewLower(children[0]), nil } // Type implements the Expression interface. @@ -93,13 +93,12 @@ func (u *Upper) String() string { return fmt.Sprintf("UPPER(%s)", u.Child) } -// TransformUp implements the Expression interface. -func (u *Upper) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := u.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (u *Upper) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1) } - return f(NewUpper(child)) + return NewUpper(children[0]), nil } // Type implements the Expression interface. diff --git a/sql/expression/function/nullif.go b/sql/expression/function/nullif.go index fc08b98ce..49b5a5d9d 100644 --- a/sql/expression/function/nullif.go +++ b/sql/expression/function/nullif.go @@ -57,17 +57,10 @@ func (f *NullIf) String() string { return fmt.Sprintf("nullif(%s, %s)", f.Left, f.Right) } -// TransformUp implements the Expression interface. -func (f *NullIf) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - left, err := f.Left.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (f *NullIf) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 2) } - - right, err := f.Right.TransformUp(fn) - if err != nil { - return nil, err - } - - return fn(NewNullIf(left, right)) + return NewNullIf(children[0], children[1]), nil } diff --git a/sql/expression/function/reverse_repeat_replace.go b/sql/expression/function/reverse_repeat_replace.go index 5ec196fc3..cef9e9691 100644 --- a/sql/expression/function/reverse_repeat_replace.go +++ b/sql/expression/function/reverse_repeat_replace.go @@ -4,8 +4,8 @@ import ( "fmt" "strings" - "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "gopkg.in/src-d/go-errors.v1" ) @@ -39,7 +39,7 @@ func (r *Reverse) Eval( func reverseString(s string) string { r := []rune(s) - for i, j := 0, len(r) - 1; i < j; i, j = i+1, j-1 { + for i, j := 0, len(r)-1; i < j; i, j = i+1, j-1 { r[i], r[j] = r[j], r[i] } return string(r) @@ -49,13 +49,12 @@ func (r *Reverse) String() string { return fmt.Sprintf("reverse(%s)", r.Child) } -// TransformUp implements the Expression interface. -func (r *Reverse) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := r.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (r *Reverse) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 1) } - return f(NewReverse(child)) + return NewReverse(children[0]), nil } // Type implements the Expression interface. @@ -84,18 +83,12 @@ func (r *Repeat) Type() sql.Type { return sql.Text } -// TransformUp implements the Expression interface. -func (r *Repeat) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := r.Left.TransformUp(f) - if err != nil { - return nil, err - } - - right, err := r.Right.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (r *Repeat) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 2) } - return f(NewRepeat(left, right)) + return NewRepeat(children[0], children[1]), nil } // Eval implements the Expression interface. @@ -165,23 +158,12 @@ func (r *Replace) Type() sql.Type { return sql.Text } -// TransformUp implements the Expression interface. -func (r *Replace) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - str, err := r.str.TransformUp(f) - if err != nil { - return nil, err - } - - fromStr, err := r.fromStr.TransformUp(f) - if err != nil { - return nil, err - } - - toStr, err := r.toStr.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (r *Replace) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 3 { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 3) } - return f(NewReplace(str, fromStr, toStr)) + return NewReplace(children[0], children[1], children[2]), nil } // Eval implements the Expression interface. diff --git a/sql/expression/function/rpad_lpad.go b/sql/expression/function/rpad_lpad.go index 74c1d762c..12b33695b 100644 --- a/sql/expression/function/rpad_lpad.go +++ b/sql/expression/function/rpad_lpad.go @@ -5,8 +5,8 @@ import ( "reflect" "strings" - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" ) var ErrDivisionByZero = errors.NewKind("division by zero") @@ -68,24 +68,9 @@ func (p *Pad) String() string { return fmt.Sprintf("rpad(%s, %s, %s)", p.str, p.length, p.padStr) } -// TransformUp implements the Expression interface. -func (p *Pad) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - str, err := p.str.TransformUp(f) - if err != nil { - return nil, err - } - - len, err := p.length.TransformUp(f) - if err != nil { - return nil, err - } - - padStr, err := p.padStr.TransformUp(f) - if err != nil { - return nil, err - } - padded, _ := NewPad(p.padType, str, len, padStr) - return f(padded) +// WithChildren implements the Expression interface. +func (p *Pad) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewPad(p.padType, children...) } // Eval implements the Expression interface. diff --git a/sql/expression/function/sleep.go b/sql/expression/function/sleep.go index c8119bcbc..2c672464b 100644 --- a/sql/expression/function/sleep.go +++ b/sql/expression/function/sleep.go @@ -37,7 +37,7 @@ func (s *Sleep) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - time.Sleep(time.Duration(child.(float64) * 1000) * time.Millisecond) + time.Sleep(time.Duration(child.(float64)*1000) * time.Millisecond) return 0, nil } @@ -51,13 +51,12 @@ func (s *Sleep) IsNullable() bool { return false } -// TransformUp implements the Expression interface. -func (s *Sleep) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := s.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (s *Sleep) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) } - return f(NewSleep(child)) + return NewSleep(children[0]), nil } // Type implements the Expression interface. diff --git a/sql/expression/function/soundex.go b/sql/expression/function/soundex.go index 1f33767f1..37774228e 100644 --- a/sql/expression/function/soundex.go +++ b/sql/expression/function/soundex.go @@ -87,13 +87,12 @@ func (s *Soundex) String() string { return fmt.Sprintf("SOUNDEX(%s)", s.Child) } -// TransformUp implements the Expression interface. -func (s *Soundex) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := s.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (s *Soundex) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) } - return f(NewSoundex(child)) + return NewSoundex(children[0]), nil } // Type implements the Expression interface. diff --git a/sql/expression/function/split.go b/sql/expression/function/split.go index 57765c7b6..20e2a49f9 100644 --- a/sql/expression/function/split.go +++ b/sql/expression/function/split.go @@ -76,17 +76,10 @@ func (f *Split) String() string { return fmt.Sprintf("split(%s, %s)", f.Left, f.Right) } -// TransformUp implements the Expression interface. -func (f *Split) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - left, err := f.Left.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (f *Split) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 2) } - - right, err := f.Right.TransformUp(fn) - if err != nil { - return nil, err - } - - return fn(NewSplit(left, right)) + return NewSplit(children[0], children[1]), nil } diff --git a/sql/expression/function/sqrt_power.go b/sql/expression/function/sqrt_power.go index c5bdf630a..020c8b7bf 100644 --- a/sql/expression/function/sqrt_power.go +++ b/sql/expression/function/sqrt_power.go @@ -4,8 +4,8 @@ import ( "fmt" "math" - "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" ) // Sqrt is a function that returns the square value of the number provided. @@ -32,13 +32,12 @@ func (s *Sqrt) IsNullable() bool { return s.Child.IsNullable() } -// TransformUp implements the Expression interface. -func (s *Sqrt) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - child, err := s.Child.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (s *Sqrt) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) } - return fn(NewSqrt(child)) + return NewSqrt(children[0]), nil } // Eval implements the Expression interface. @@ -86,19 +85,12 @@ func (p *Power) String() string { return fmt.Sprintf("power(%s, %s)", p.Left, p.Right) } -// TransformUp implements the Expression interface. -func (p *Power) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - left, err := p.Left.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (p *Power) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 2) } - - right, err := p.Right.TransformUp(fn) - if err != nil { - return nil, err - } - - return fn(NewPower(left, right)) + return NewPower(children[0], children[0]), nil } // Eval implements the Expression interface. diff --git a/sql/expression/function/substring.go b/sql/expression/function/substring.go index bc1337d88..c5227b9bc 100644 --- a/sql/expression/function/substring.go +++ b/sql/expression/function/substring.go @@ -142,32 +142,9 @@ func (s *Substring) Resolved() bool { // Type implements the Expression interface. func (*Substring) Type() sql.Type { return sql.Text } -// TransformUp implements the Expression interface. -func (s *Substring) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - str, err := s.str.TransformUp(f) - if err != nil { - return nil, err - } - - start, err := s.start.TransformUp(f) - if err != nil { - return nil, err - } - - // It is safe to omit the errors of NewSubstring here because to be able to call - // this method, you need a valid instance of Substring, so the arity must be correct - // and that's the only error NewSubstring can return. - var sub sql.Expression - if s.len != nil { - len, err := s.len.TransformUp(f) - if err != nil { - return nil, err - } - sub, _ = NewSubstring(str, start, len) - } else { - sub, _ = NewSubstring(str, start) - } - return f(sub) +/// WithChildren implements the Expression interface. +func (*Substring) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewSubstring(children...) } // SubstringIndex returns the substring from string str before count occurrences of the delimiter delim. @@ -273,22 +250,10 @@ func (s *SubstringIndex) Resolved() bool { // Type implements the Expression interface. func (*SubstringIndex) Type() sql.Type { return sql.Text } -// TransformUp implements the Expression interface. -func (s *SubstringIndex) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - str, err := s.str.TransformUp(f) - if err != nil { - return nil, err - } - - delim, err := s.delim.TransformUp(f) - if err != nil { - return nil, err - } - - count, err := s.count.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (s *SubstringIndex) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 3 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 3) } - - return f(NewSubstringIndex(str, delim, count)) + return NewSubstringIndex(children[0], children[1], children[2]), nil } diff --git a/sql/expression/function/time.go b/sql/expression/function/time.go index 345c5b98d..2385aec0d 100644 --- a/sql/expression/function/time.go +++ b/sql/expression/function/time.go @@ -66,14 +66,12 @@ func (y *Year) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, y.UnaryExpression, row, year) } -// TransformUp implements the Expression interface. -func (y *Year) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := y.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (y *Year) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(y, len(children), 1) } - - return f(NewYear(child)) + return NewYear(children[0]), nil } // Month is a function that returns the month of a date. @@ -96,14 +94,12 @@ func (m *Month) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, m.UnaryExpression, row, month) } -// TransformUp implements the Expression interface. -func (m *Month) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := m.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (m *Month) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1) } - - return f(NewMonth(child)) + return NewMonth(children[0]), nil } // Day is a function that returns the day of a date. @@ -126,14 +122,12 @@ func (d *Day) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, d.UnaryExpression, row, day) } -// TransformUp implements the Expression interface. -func (d *Day) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (d *Day) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - - return f(NewDay(child)) + return NewDay(children[0]), nil } // Weekday is a function that returns the weekday of a date where 0 = Monday, @@ -157,14 +151,12 @@ func (d *Weekday) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, d.UnaryExpression, row, weekday) } -// TransformUp implements the Expression interface. -func (d *Weekday) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (d *Weekday) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - - return f(NewWeekday(child)) + return NewWeekday(children[0]), nil } // Hour is a function that returns the hour of a date. @@ -187,14 +179,12 @@ func (h *Hour) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, h.UnaryExpression, row, hour) } -// TransformUp implements the Expression interface. -func (h *Hour) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := h.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (h *Hour) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(h, len(children), 1) } - - return f(NewHour(child)) + return NewHour(children[0]), nil } // Minute is a function that returns the minute of a date. @@ -217,14 +207,12 @@ func (m *Minute) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, m.UnaryExpression, row, minute) } -// TransformUp implements the Expression interface. -func (m *Minute) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := m.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (m *Minute) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1) } - - return f(NewMinute(child)) + return NewMinute(children[0]), nil } // Second is a function that returns the second of a date. @@ -247,14 +235,12 @@ func (s *Second) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, s.UnaryExpression, row, second) } -// TransformUp implements the Expression interface. -func (s *Second) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := s.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (s *Second) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) } - - return f(NewSecond(child)) + return NewSecond(children[0]), nil } // DayOfWeek is a function that returns the day of the week from a date where @@ -278,14 +264,12 @@ func (d *DayOfWeek) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, d.UnaryExpression, row, dayOfWeek) } -// TransformUp implements the Expression interface. -func (d *DayOfWeek) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (d *DayOfWeek) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - - return f(NewDayOfWeek(child)) + return NewDayOfWeek(children[0]), nil } // DayOfYear is a function that returns the day of the year from a date. @@ -308,14 +292,12 @@ func (d *DayOfYear) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, d.UnaryExpression, row, dayOfYear) } -// TransformUp implements the Expression interface. -func (d *DayOfYear) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (d *DayOfYear) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - - return f(NewDayOfYear(child)) + return NewDayOfYear(children[0]), nil } func datePartFunc(fn func(time.Time) int) func(interface{}) interface{} { @@ -403,23 +385,9 @@ func (d *YearWeek) IsNullable() bool { return d.date.IsNullable() } -// TransformUp implements the Expression interface. -func (d *YearWeek) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - date, err := d.date.TransformUp(f) - if err != nil { - return nil, err - } - - mode, err := d.mode.TransformUp(f) - if err != nil { - return nil, err - } - - yw, err := NewYearWeek(date, mode) - if err != nil { - return nil, err - } - return f(yw) +// WithChildren implements the Expression interface. +func (*YearWeek) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewYearWeek(children...) } // Following solution of YearWeek was taken from tidb: https://github.com/pingcap/tidb/blob/master/types/mytime.go @@ -567,9 +535,12 @@ func (n *Now) Eval(*sql.Context, sql.Row) (interface{}, error) { return n.clock(), nil } -// TransformUp implements the sql.Expression interface. -func (n *Now) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - return f(n) +// WithChildren implements the Expression interface. +func (n *Now) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 0) + } + return n, nil } // Date a function takes the DATE part out from a datetime expression. @@ -598,12 +569,10 @@ func (d *Date) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { }) } -// TransformUp implements the sql.Expression interface. -func (d *Date) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (d *Date) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - - return f(NewDate(child)) + return NewDate(children[0]), nil } diff --git a/sql/expression/function/tobase64_frombase64.go b/sql/expression/function/tobase64_frombase64.go index 543d482f7..f3c638983 100644 --- a/sql/expression/function/tobase64_frombase64.go +++ b/sql/expression/function/tobase64_frombase64.go @@ -72,13 +72,12 @@ func (t *ToBase64) IsNullable() bool { return t.Child.IsNullable() } -// TransformUp implements the Expression interface. -func (t *ToBase64) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := t.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (t *ToBase64) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1) } - return f(NewToBase64(child)) + return NewToBase64(children[0]), nil } // Type implements the Expression interface. @@ -86,7 +85,6 @@ func (t *ToBase64) Type() sql.Type { return sql.Text } - // FromBase64 is a function to decode a Base64-formatted string // using the same dialect that MySQL's FROM_BASE64 uses type FromBase64 struct { @@ -133,13 +131,12 @@ func (t *FromBase64) IsNullable() bool { return t.Child.IsNullable() } -// TransformUp implements the Expression interface. -func (t *FromBase64) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := t.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (t *FromBase64) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1) } - return f(NewFromBase64(child)) + return NewFromBase64(children[0]), nil } // Type implements the Expression interface. diff --git a/sql/expression/function/trim_ltrim_rtrim.go b/sql/expression/function/trim_ltrim_rtrim.go index 9aacacb9b..b08704dfb 100644 --- a/sql/expression/function/trim_ltrim_rtrim.go +++ b/sql/expression/function/trim_ltrim_rtrim.go @@ -6,11 +6,12 @@ import ( "strings" "unicode" - "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" ) type trimType rune + const ( lTrimType trimType = 'l' rTrimType trimType = 'r' @@ -54,14 +55,12 @@ func (t *Trim) IsNullable() bool { return t.Child.IsNullable() } -// TransformUp implements the Expression interface. -func (t *Trim) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - str, err := t.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (t *Trim) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1) } - - return f(NewTrim(t.trimType, str)) + return NewTrim(t.trimType, children[0]), nil } // Eval implements the Expression interface. diff --git a/sql/expression/function/version.go b/sql/expression/function/version.go index b100128f6..eafeb4454 100644 --- a/sql/expression/function/version.go +++ b/sql/expression/function/version.go @@ -30,9 +30,12 @@ func (f Version) String() string { return "VERSION()" } -// TransformUp implements the Expression interface. -func (f Version) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - return fn(f) +// WithChildren implements the Expression interface. +func (f Version) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 0) + } + return f, nil } // Resolved implements the Expression interface. diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index cb8ac9aae..e6a5884dc 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -3,8 +3,8 @@ package expression import ( "fmt" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" ) // GetField is an expression to get the field of a table. @@ -72,10 +72,12 @@ func (p *GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return row[p.fieldIndex], nil } -// TransformUp implements the Expression interface. -func (p *GetField) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - n := *p - return f(&n) +// WithChildren implements the Expression interface. +func (p *GetField) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } + return p, nil } func (p *GetField) String() string { @@ -124,7 +126,10 @@ func (f *GetSessionField) Resolved() bool { return true } // String implements the sql.Expression interface. func (f *GetSessionField) String() string { return "@@" + f.name } -// TransformUp implements the sql.Expression interface. -func (f *GetSessionField) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - return fn(f) +// WithChildren implements the Expression interface. +func (f *GetSessionField) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 0) + } + return f, nil } diff --git a/sql/expression/interval.go b/sql/expression/interval.go index 6ec5cfa91..a175d55b5 100644 --- a/sql/expression/interval.go +++ b/sql/expression/interval.go @@ -7,8 +7,8 @@ import ( "strings" "time" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" ) // Interval defines a time duration. @@ -148,14 +148,12 @@ func (i *Interval) EvalDelta(ctx *sql.Context, row sql.Row) (*TimeDelta, error) return &td, nil } -// TransformUp implements the sql.Expression interface. -func (i *Interval) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := i.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (i *Interval) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 1) } - - return NewInterval(child, i.Unit), nil + return NewInterval(children[0], i.Unit), nil } func (i *Interval) String() string { diff --git a/sql/expression/isnull.go b/sql/expression/isnull.go index 252b95f70..a9ae575d5 100644 --- a/sql/expression/isnull.go +++ b/sql/expression/isnull.go @@ -36,11 +36,10 @@ func (e IsNull) String() string { return e.Child.String() + " IS NULL" } -// TransformUp implements the Expression interface. -func (e *IsNull) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := e.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *IsNull) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) } - return f(NewIsNull(child)) + return NewIsNull(children[0]), nil } diff --git a/sql/expression/like.go b/sql/expression/like.go index 6c361b22c..1fcf08bf8 100644 --- a/sql/expression/like.go +++ b/sql/expression/like.go @@ -107,19 +107,12 @@ func (l *Like) String() string { return fmt.Sprintf("%s LIKE %s", l.Left, l.Right) } -// TransformUp implements the sql.Expression interface. -func (l *Like) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := l.Left.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (l *Like) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 2) } - - right, err := l.Right.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewLike(left, right)) + return NewLike(children[0], children[1]), nil } func patternToRegex(pattern string) string { diff --git a/sql/expression/literal.go b/sql/expression/literal.go index df9f77be7..02e88e0a9 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -53,10 +53,12 @@ func (p *Literal) String() string { } } -// TransformUp implements the Expression interface. -func (p *Literal) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - n := *p - return f(&n) +// WithChildren implements the Expression interface. +func (p *Literal) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } + return p, nil } // Children implements the Expression interface. diff --git a/sql/expression/logic.go b/sql/expression/logic.go index 6551c793a..08a31c087 100644 --- a/sql/expression/logic.go +++ b/sql/expression/logic.go @@ -68,19 +68,12 @@ func (a *And) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return true, nil } -// TransformUp implements the Expression interface. -func (a *And) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := a.Left.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (a *And) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 2) } - - right, err := a.Right.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewAnd(left, right)) + return NewAnd(children[0], children[1]), nil } // Or checks whether one of the two given expressions is true. @@ -125,17 +118,10 @@ func (o *Or) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return rval == true, nil } -// TransformUp implements the Expression interface. -func (o *Or) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := o.Left.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (o *Or) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(o, len(children), 2) } - - right, err := o.Right.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewOr(left, right)) + return NewOr(children[0], children[1]), nil } diff --git a/sql/expression/star.go b/sql/expression/star.go index c2ce8691f..5d6603be1 100644 --- a/sql/expression/star.go +++ b/sql/expression/star.go @@ -55,8 +55,10 @@ func (*Star) Eval(ctx *sql.Context, r sql.Row) (interface{}, error) { panic("star is just a placeholder node, but Eval was called") } -// TransformUp implements the Expression interface. -func (s *Star) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - n := *s - return f(&n) +// WithChildren implements the Expression interface. +func (s *Star) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } + return s, nil } diff --git a/sql/expression/transform.go b/sql/expression/transform.go new file mode 100644 index 000000000..ccf3d4276 --- /dev/null +++ b/sql/expression/transform.go @@ -0,0 +1,28 @@ +package expression + +import "github.com/src-d/go-mysql-server/sql" + +// TransformUp applies a transformation function to the given expression from the +// bottom up. +func TransformUp(e sql.Expression, f sql.TransformExprFunc) (sql.Expression, error) { + children := e.Children() + if len(children) == 0 { + return f(e) + } + + newChildren := make([]sql.Expression, len(children)) + for i, c := range children { + c, err := TransformUp(c, f) + if err != nil { + return nil, err + } + newChildren[i] = c + } + + e, err := e.WithChildren(newChildren...) + if err != nil { + return nil, err + } + + return f(e) +} diff --git a/sql/expression/tuple.go b/sql/expression/tuple.go index 05af0367a..11d35e1f5 100644 --- a/sql/expression/tuple.go +++ b/sql/expression/tuple.go @@ -77,18 +77,12 @@ func (t Tuple) Type() sql.Type { return sql.Tuple(types...) } -// TransformUp implements the Expression interface. -func (t Tuple) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - var exprs = make([]sql.Expression, len(t)) - for i, e := range t { - var err error - exprs[i], err = f(e) - if err != nil { - return nil, err - } +// WithChildren implements the Expression interface. +func (t Tuple) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != len(t) { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), len(t)) } - - return f(Tuple(exprs)) + return NewTuple(children...), nil } // Children implements the Expression interface. diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index 276367cd7..6580655d2 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -64,10 +64,12 @@ func (*UnresolvedColumn) Eval(ctx *sql.Context, r sql.Row) (interface{}, error) panic("unresolved column is a placeholder node, but Eval was called") } -// TransformUp implements the Expression interface. -func (uc *UnresolvedColumn) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - n := *uc - return f(&n) +// WithChildren implements the Expression interface. +func (uc *UnresolvedColumn) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(uc, len(children), 0) + } + return uc, nil } // UnresolvedFunction represents a function that is not yet resolved. @@ -126,16 +128,10 @@ func (*UnresolvedFunction) Eval(ctx *sql.Context, r sql.Row) (interface{}, error panic("unresolved function is a placeholder node, but Eval was called") } -// TransformUp implements the Expression interface. -func (uf *UnresolvedFunction) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - var rc []sql.Expression - for _, c := range uf.Arguments { - ct, err := c.TransformUp(f) - if err != nil { - return nil, err - } - rc = append(rc, ct) +// WithChildren implements the Expression interface. +func (uf *UnresolvedFunction) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != len(uf.Arguments) { + return nil, sql.ErrInvalidChildrenNumber.New(uf, len(children), len(uf.Arguments)) } - - return f(NewUnresolvedFunction(uf.name, uf.IsAggregate, rc...)) + return NewUnresolvedFunction(uf.name, uf.IsAggregate, children...), nil } diff --git a/sql/index_test.go b/sql/index_test.go index da42c2464..1e631cba5 100644 --- a/sql/index_test.go +++ b/sql/index_test.go @@ -401,8 +401,8 @@ var _ Expression = (*dummyExpr)(nil) func (dummyExpr) Children() []Expression { return nil } func (dummyExpr) Eval(*Context, Row) (interface{}, error) { panic("not implemented") } -func (e dummyExpr) TransformUp(fn TransformExprFunc) (Expression, error) { - return fn(e) +func (e dummyExpr) WithChildren(children ...Expression) (Expression, error) { + return e, nil } func (e dummyExpr) String() string { return fmt.Sprintf("dummyExpr{%d, %s}", e.index, e.colName) } func (dummyExpr) IsNullable() bool { return false } diff --git a/sql/parse/parse.go b/sql/parse/parse.go index efb332dc0..6dd82e426 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -168,7 +168,7 @@ func convertSet(ctx *sql.Context, n *sqlparser.Set) (sql.Node, error) { } name := strings.TrimSpace(e.Name.Lowered()) - if expr, err = expr.TransformUp(func(e sql.Expression) (sql.Expression, error) { + if expr, err = expression.TransformUp(expr, func(e sql.Expression) (sql.Expression, error) { if _, ok := e.(*expression.DefaultColumn); ok { return e, nil } diff --git a/sql/plan/common.go b/sql/plan/common.go index 061a9f13a..beec46177 100644 --- a/sql/plan/common.go +++ b/sql/plan/common.go @@ -57,20 +57,3 @@ func expressionsResolved(exprs ...sql.Expression) bool { return true } - -func transformExpressionsUp( - f sql.TransformExprFunc, - exprs []sql.Expression, -) ([]sql.Expression, error) { - - var es []sql.Expression - for _, e := range exprs { - te, err := e.TransformUp(f) - if err != nil { - return nil, err - } - es = append(es, te) - } - - return es, nil -} diff --git a/sql/plan/create_index.go b/sql/plan/create_index.go index ce170bbd4..f45de2f7e 100644 --- a/sql/plan/create_index.go +++ b/sql/plan/create_index.go @@ -8,9 +8,9 @@ import ( opentracing "github.com/opentracing/opentracing-go" otlog "github.com/opentracing/opentracing-go/log" "github.com/sirupsen/logrus" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" + errors "gopkg.in/src-d/go-errors.v1" ) var ( @@ -260,58 +260,28 @@ func (c *CreateIndex) Expressions() []sql.Expression { return c.Exprs } -// TransformExpressions implements the Expressioner interface. -func (c *CreateIndex) TransformExpressions(fn sql.TransformExprFunc) (sql.Node, error) { - var exprs = make([]sql.Expression, len(c.Exprs)) - var err error - for i, e := range c.Exprs { - exprs[i], err = e.TransformUp(fn) - if err != nil { - return nil, err - } +// WithExpressions implements the Expressioner interface. +func (c *CreateIndex) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != len(c.Exprs) { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(exprs), len(c.Exprs)) } nc := *c nc.Exprs = exprs - return &nc, nil } -// TransformExpressionsUp implements the Node interface. -func (c *CreateIndex) TransformExpressionsUp(fn sql.TransformExprFunc) (sql.Node, error) { - table, err := c.Table.TransformExpressionsUp(fn) - if err != nil { - return nil, err - } - - var exprs = make([]sql.Expression, len(c.Exprs)) - for i, e := range c.Exprs { - exprs[i], err = e.TransformUp(fn) - if err != nil { - return nil, err - } +// WithChildren implements the Node interface. +func (c *CreateIndex) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) } nc := *c - nc.Table = table - nc.Exprs = exprs - + nc.Table = children[0] return &nc, nil } -// TransformUp implements the Node interface. -func (c *CreateIndex) TransformUp(fn sql.TransformNodeFunc) (sql.Node, error) { - table, err := c.Table.TransformUp(fn) - if err != nil { - return nil, err - } - - nc := *c - nc.Table = table - - return fn(&nc) -} - // getColumnsAndPrepareExpressions extracts the unique columns required by all // those expressions and fixes the indexes of the GetFields in the expressions // to match a row with only the returned columns in that same order. @@ -323,7 +293,7 @@ func getColumnsAndPrepareExpressions( var expressions = make([]sql.Expression, len(exprs)) for i, e := range exprs { - ex, err := e.TransformUp(func(e sql.Expression) (sql.Expression, error) { + ex, err := expression.TransformUp(e, func(e sql.Expression) (sql.Expression, error) { gf, ok := e.(*expression.GetField) if !ok { return e, nil diff --git a/sql/plan/cross_join.go b/sql/plan/cross_join.go index 6943719b9..f7dc1d80d 100644 --- a/sql/plan/cross_join.go +++ b/sql/plan/cross_join.go @@ -66,34 +66,13 @@ func (p *CrossJoin) RowIter(ctx *sql.Context) (sql.RowIter, error) { }), nil } -// TransformUp implements the Transformable interface. -func (p *CrossJoin) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - left, err := p.Left.TransformUp(f) - if err != nil { - return nil, err - } - - right, err := p.Right.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewCrossJoin(left, right)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (p *CrossJoin) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - left, err := p.Left.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - right, err := p.Right.TransformExpressionsUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *CrossJoin) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 2) } - return NewCrossJoin(left, right), nil + return NewCrossJoin(children[0], children[1]), nil } func (p *CrossJoin) String() string { diff --git a/sql/plan/ddl.go b/sql/plan/ddl.go index b92075e52..3eb97331e 100644 --- a/sql/plan/ddl.go +++ b/sql/plan/ddl.go @@ -1,8 +1,8 @@ package plan import ( - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" ) // ErrCreateTable is thrown when the database doesn't support table creation @@ -64,13 +64,11 @@ func (c *CreateTable) Schema() sql.Schema { return nil } // Children implements the Node interface. func (c *CreateTable) Children() []sql.Node { return nil } -// TransformUp implements the Transformable interface. -func (c *CreateTable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(NewCreateTable(c.db, c.name, c.schema)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (c *CreateTable) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { +// WithChildren implements the Node interface. +func (c *CreateTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0) + } return c, nil } diff --git a/sql/plan/describe.go b/sql/plan/describe.go index aed8aeaf5..e84cdc8a0 100644 --- a/sql/plan/describe.go +++ b/sql/plan/describe.go @@ -33,22 +33,13 @@ func (d *Describe) RowIter(ctx *sql.Context) (sql.RowIter, error) { return &describeIter{schema: d.Child.Schema()}, nil } -// TransformUp implements the Transformable interface. -func (d *Describe) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (d *Describe) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - return f(NewDescribe(child)) -} -// TransformExpressionsUp implements the Transformable interface. -func (d *Describe) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := d.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - return NewDescribe(child), nil + return NewDescribe(children[0]), nil } func (d Describe) String() string { @@ -116,22 +107,11 @@ func (d *DescribeQuery) String() string { return pr.String() } -// TransformUp implements the Node interface. -func (d *DescribeQuery) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewDescribeQuery(d.Format, child)) -} - -// TransformExpressionsUp implements the Node interface. -func (d *DescribeQuery) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := d.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (d *DescribeQuery) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - return NewDescribeQuery(d.Format, child), nil + return NewDescribeQuery(d.Format, children[0]), nil } diff --git a/sql/plan/distinct.go b/sql/plan/distinct.go index 556c1377f..e14aee5fb 100644 --- a/sql/plan/distinct.go +++ b/sql/plan/distinct.go @@ -37,22 +37,13 @@ func (d *Distinct) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, newDistinctIter(it)), nil } -// TransformUp implements the Transformable interface. -func (d *Distinct) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (d *Distinct) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - return f(NewDistinct(child)) -} -// TransformExpressionsUp implements the Transformable interface. -func (d *Distinct) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := d.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - return NewDistinct(child), nil + return NewDistinct(children[0]), nil } func (d Distinct) String() string { @@ -135,22 +126,13 @@ func (d *OrderedDistinct) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, newOrderedDistinctIter(it, d.Child.Schema())), nil } -// TransformUp implements the Transformable interface. -func (d *OrderedDistinct) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (d *OrderedDistinct) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - return f(NewOrderedDistinct(child)) -} -// TransformExpressionsUp implements the Transformable interface. -func (d *OrderedDistinct) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := d.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - return NewOrderedDistinct(child), nil + return NewOrderedDistinct(children[0]), nil } func (d OrderedDistinct) String() string { diff --git a/sql/plan/drop_index.go b/sql/plan/drop_index.go index e6917ec04..f08caf711 100644 --- a/sql/plan/drop_index.go +++ b/sql/plan/drop_index.go @@ -1,9 +1,9 @@ package plan import ( - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/internal/similartext" "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" ) var ( @@ -104,26 +104,13 @@ func (d *DropIndex) String() string { return pr.String() } -// TransformExpressionsUp implements the Node interface. -func (d *DropIndex) TransformExpressionsUp(fn sql.TransformExprFunc) (sql.Node, error) { - t, err := d.Table.TransformExpressionsUp(fn) - if err != nil { - return nil, err - } - - nc := *d - nc.Table = t - return &nc, nil -} - -// TransformUp implements the Node interface. -func (d *DropIndex) TransformUp(fn sql.TransformNodeFunc) (sql.Node, error) { - t, err := d.Table.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (d *DropIndex) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - nc := *d - nc.Table = t - return fn(&nc) + nd := *d + nd.Table = children[0] + return &nd, nil } diff --git a/sql/plan/empty_table.go b/sql/plan/empty_table.go index 9a10ebe73..198cef41d 100644 --- a/sql/plan/empty_table.go +++ b/sql/plan/empty_table.go @@ -16,12 +16,11 @@ func (emptyTable) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.RowsToRowIter(), nil } -// TransformUp implements the Transformable interface. -func (e *emptyTable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(e) -} +// WithChildren implements the Node interface. +func (e *emptyTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 0) + } -// TransformExpressionsUp implements the Transformable interface. -func (e *emptyTable) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return e, nil } diff --git a/sql/plan/exchange.go b/sql/plan/exchange.go index 9c4347369..8ff51ee6f 100644 --- a/sql/plan/exchange.go +++ b/sql/plan/exchange.go @@ -61,24 +61,13 @@ func (e *Exchange) String() string { return p.String() } -// TransformUp implements the sql.Node interface. -func (e *Exchange) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := e.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (e *Exchange) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) } - return f(NewExchange(e.Parallelism, child)) -} - -// TransformExpressionsUp implements the sql.Node interface. -func (e *Exchange) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := e.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - return NewExchange(e.Parallelism, child), nil + return NewExchange(e.Parallelism, children[0]), nil } type exchangeRowIter struct { @@ -208,7 +197,7 @@ func (it *exchangeRowIter) iterPartitions(ch chan<- sql.Partition) { } func (it *exchangeRowIter) iterPartition(p sql.Partition) { - node, err := it.tree.TransformUp(func(n sql.Node) (sql.Node, error) { + node, err := TransformUp(it.tree, func(n sql.Node) (sql.Node, error) { if t, ok := n.(sql.Table); ok { return &exchangePartition{p, t}, nil } @@ -310,10 +299,11 @@ func (p *exchangePartition) Schema() sql.Schema { return p.table.Schema() } -func (p *exchangePartition) TransformExpressionsUp(sql.TransformExprFunc) (sql.Node, error) { - return p, nil -} +// WithChildren implements the Node interface. +func (p *exchangePartition) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } -func (p *exchangePartition) TransformUp(sql.TransformNodeFunc) (sql.Node, error) { return p, nil } diff --git a/sql/plan/exchange_test.go b/sql/plan/exchange_test.go index eca284216..5a8e8e317 100644 --- a/sql/plan/exchange_test.go +++ b/sql/plan/exchange_test.go @@ -6,9 +6,9 @@ import ( "io" "testing" - "github.com/stretchr/testify/require" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" ) func TestExchange(t *testing.T) { @@ -106,11 +106,12 @@ type partitionable struct { rowsPerPartition int } -func (p *partitionable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(p) -} +// WithChildren implements the Node interface. +func (p *partitionable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } -func (p *partitionable) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return p, nil } diff --git a/sql/plan/filter.go b/sql/plan/filter.go index a86c51eab..f160b3f66 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -36,28 +36,22 @@ func (p *Filter) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, NewFilterIter(ctx, p.Expression, i)), nil } -// TransformUp implements the Transformable interface. -func (p *Filter) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := p.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *Filter) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) } - return f(NewFilter(p.Expression, child)) + + return NewFilter(p.Expression, children[0]), nil } -// TransformExpressionsUp implements the Transformable interface. -func (p *Filter) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - expr, err := p.Expression.TransformUp(f) - if err != nil { - return nil, err +// WithExpressions implements the Expressioner interface. +func (p *Filter) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(exprs), 1) } - child, err := p.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - return NewFilter(expr, child), nil + return NewFilter(exprs[0], p.Child), nil } func (p *Filter) String() string { @@ -72,16 +66,6 @@ func (p *Filter) Expressions() []sql.Expression { return []sql.Expression{p.Expression} } -// TransformExpressions implements the Expressioner interface. -func (p *Filter) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - e, err := p.Expression.TransformUp(f) - if err != nil { - return nil, err - } - - return NewFilter(e, p.Child), nil -} - // FilterIter is an iterator that filters another iterator and skips rows that // don't match the given condition. type FilterIter struct { diff --git a/sql/plan/generate.go b/sql/plan/generate.go index 841a45890..c259b8d8d 100644 --- a/sql/plan/generate.go +++ b/sql/plan/generate.go @@ -46,48 +46,30 @@ func (g *Generate) RowIter(ctx *sql.Context) (sql.RowIter, error) { }), nil } -func (g *Generate) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - col, err := g.Column.TransformUp(f) - if err != nil { - return nil, err - } - - field, ok := col.(*expression.GetField) - if !ok { - return nil, fmt.Errorf("column of Generate node transformed into %T, must be GetField", col) - } - - return NewGenerate(g.Child, field), nil -} +// Expressions implements the Expressioner interface. +func (g *Generate) Expressions() []sql.Expression { return []sql.Expression{g.Column} } -// TransformUp implements the sql.Node interface. -func (g *Generate) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := g.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (g *Generate) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(g, len(children), 1) } - return f(NewGenerate(child, g.Column)) + return NewGenerate(children[0], g.Column), nil } -// TransformExpressionsUp implements the sql.Node interface. -func (g *Generate) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := g.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - col, err := g.Column.TransformUp(f) - if err != nil { - return nil, err +// WithExpressions implements the Expressioner interface. +func (g *Generate) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(g, len(exprs), 1) } - field, ok := col.(*expression.GetField) + gf, ok := exprs[0].(*expression.GetField) if !ok { - return nil, fmt.Errorf("column of Generate node transformed into %T, must be GetField", col) + return nil, fmt.Errorf("Generate expects child to be expression.GetField, but is %T", exprs[0]) } - return NewGenerate(child, field), nil + return NewGenerate(g.Child, gf), nil } func (g *Generate) String() string { diff --git a/sql/plan/generate_test.go b/sql/plan/generate_test.go index ba35f33cf..7f32db68c 100644 --- a/sql/plan/generate_test.go +++ b/sql/plan/generate_test.go @@ -3,9 +3,9 @@ package plan import ( "testing" - "github.com/stretchr/testify/require" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" ) func TestGenerateRowIter(t *testing.T) { @@ -80,9 +80,6 @@ func (n *fakeNode) Resolved() bool { return true } func (n *fakeNode) Schema() sql.Schema { return n.schema } func (n *fakeNode) RowIter(*sql.Context) (sql.RowIter, error) { return n.iter, nil } func (n *fakeNode) String() string { return "fakeNode" } -func (n *fakeNode) TransformUp(sql.TransformNodeFunc) (sql.Node, error) { - panic("placeholder") -} -func (n *fakeNode) TransformExpressionsUp(sql.TransformExprFunc) (sql.Node, error) { +func (*fakeNode) WithChildren(children ...sql.Node) (sql.Node, error) { panic("placeholder") } diff --git a/sql/plan/group_by.go b/sql/plan/group_by.go index a61b9204a..620167568 100644 --- a/sql/plan/group_by.go +++ b/sql/plan/group_by.go @@ -7,9 +7,9 @@ import ( "strings" opentracing "github.com/opentracing/opentracing-go" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" + errors "gopkg.in/src-d/go-errors.v1" ) // ErrGroupBy is returned when the aggregation is not supported. @@ -93,33 +93,34 @@ func (p *GroupBy) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, iter), nil } -// TransformUp implements the Transformable interface. -func (p *GroupBy) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := p.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *GroupBy) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) } - return f(NewGroupBy(p.Aggregate, p.Grouping, child)) + + return NewGroupBy(p.Aggregate, p.Grouping, children[0]), nil } -// TransformExpressionsUp implements the Transformable interface. -func (p *GroupBy) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - aggregate, err := transformExpressionsUp(f, p.Aggregate) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *GroupBy) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + expected := len(p.Aggregate) + len(p.Grouping) + if len(exprs) != expected { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(exprs), expected) } - grouping, err := transformExpressionsUp(f, p.Grouping) - if err != nil { - return nil, err + var agg = make([]sql.Expression, len(p.Aggregate)) + for i := 0; i < len(p.Aggregate); i++ { + agg[i] = exprs[i] } - child, err := p.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err + var grouping = make([]sql.Expression, len(p.Grouping)) + offset := len(p.Aggregate) + for i := 0; i < len(p.Grouping); i++ { + grouping[i] = exprs[i+offset] } - return NewGroupBy(aggregate, grouping, child), nil + return NewGroupBy(agg, grouping, p.Child), nil } func (p *GroupBy) String() string { @@ -152,21 +153,6 @@ func (p *GroupBy) Expressions() []sql.Expression { return exprs } -// TransformExpressions implements the Expressioner interface. -func (p *GroupBy) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - agg, err := transformExpressionsUp(f, p.Aggregate) - if err != nil { - return nil, err - } - - group, err := transformExpressionsUp(f, p.Grouping) - if err != nil { - return nil, err - } - - return NewGroupBy(agg, group, p.Child), nil -} - type groupByIter struct { aggregate []sql.Expression child sql.RowIter diff --git a/sql/plan/having.go b/sql/plan/having.go index 4a7f81b85..48a8bde25 100644 --- a/sql/plan/having.go +++ b/sql/plan/having.go @@ -24,39 +24,22 @@ func (h *Having) Resolved() bool { return h.Cond.Resolved() && h.Child.Resolved( // Expressions implements the sql.Expressioner interface. func (h *Having) Expressions() []sql.Expression { return []sql.Expression{h.Cond} } -// TransformExpressions implements the sql.Expressioner interface. -func (h *Having) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - e, err := h.Cond.TransformUp(f) - if err != nil { - return nil, err - } - - return &Having{h.UnaryNode, e}, nil -} - -// TransformExpressionsUp implements the sql.Node interface. -func (h *Having) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := h.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (h *Having) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(h, len(children), 1) } - e, err := h.Cond.TransformUp(f) - if err != nil { - return nil, err - } - - return &Having{UnaryNode{child}, e}, nil + return NewHaving(h.Cond, children[0]), nil } -// TransformUp implements the sql.Node interface. -func (h *Having) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := h.Child.TransformUp(f) - if err != nil { - return nil, err +// WithExpressions implements the Expressioner interface. +func (h *Having) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(h, len(exprs), 1) } - return f(&Having{UnaryNode{child}, h.Cond}) + return NewHaving(exprs[0], h.Child), nil } // RowIter implements the sql.Node interface. diff --git a/sql/plan/insert.go b/sql/plan/insert.go index ebd6482d3..538fc160a 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -4,9 +4,9 @@ import ( "io" "strings" - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" + "gopkg.in/src-d/go-errors.v1" ) // ErrInsertIntoNotSupported is thrown when a table doesn't support inserts @@ -123,34 +123,13 @@ func (p *InsertInto) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.RowsToRowIter(sql.NewRow(int64(n))), nil } -// TransformUp implements the Transformable interface. -func (p *InsertInto) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - left, err := p.Left.TransformUp(f) - if err != nil { - return nil, err - } - - right, err := p.Right.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewInsertInto(left, right, p.Columns)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (p *InsertInto) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - left, err := p.Left.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - right, err := p.Right.TransformExpressionsUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *InsertInto) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 2) } - return NewInsertInto(left, right, p.Columns), nil + return NewInsertInto(children[0], children[1], p.Columns), nil } func (p InsertInto) String() string { diff --git a/sql/plan/join.go b/sql/plan/join.go index 1b3475908..a287a424a 100644 --- a/sql/plan/join.go +++ b/sql/plan/join.go @@ -83,39 +83,22 @@ func (j *InnerJoin) RowIter(ctx *sql.Context) (sql.RowIter, error) { return joinRowIter(ctx, innerJoin, j.Left, j.Right, j.Cond) } -// TransformUp implements the Transformable interface. -func (j *InnerJoin) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - left, err := j.Left.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (j *InnerJoin) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(children), 2) } - right, err := j.Right.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewInnerJoin(left, right, j.Cond)) + return NewInnerJoin(children[0], children[1], j.Cond), nil } -// TransformExpressionsUp implements the Transformable interface. -func (j *InnerJoin) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - left, err := j.Left.TransformExpressionsUp(f) - if err != nil { - return nil, err +// WithExpressions implements the Expressioner interface. +func (j *InnerJoin) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(exprs), 1) } - right, err := j.Right.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - cond, err := j.Cond.TransformUp(f) - if err != nil { - return nil, err - } - - return NewInnerJoin(left, right, cond), nil + return NewInnerJoin(j.Left, j.Right, exprs[0]), nil } func (j *InnerJoin) String() string { @@ -130,16 +113,6 @@ func (j *InnerJoin) Expressions() []sql.Expression { return []sql.Expression{j.Cond} } -// TransformExpressions implements the Expressioner interface. -func (j *InnerJoin) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - cond, err := j.Cond.TransformUp(f) - if err != nil { - return nil, err - } - - return NewInnerJoin(j.Left, j.Right, cond), nil -} - // LeftJoin is a left join between two tables. type LeftJoin struct { BinaryNode @@ -172,39 +145,22 @@ func (j *LeftJoin) RowIter(ctx *sql.Context) (sql.RowIter, error) { return joinRowIter(ctx, leftJoin, j.Left, j.Right, j.Cond) } -// TransformUp implements the Transformable interface. -func (j *LeftJoin) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - left, err := j.Left.TransformUp(f) - if err != nil { - return nil, err - } - - right, err := j.Right.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (j *LeftJoin) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(children), 1) } - return f(NewLeftJoin(left, right, j.Cond)) + return NewLeftJoin(children[0], children[1], j.Cond), nil } -// TransformExpressionsUp implements the Transformable interface. -func (j *LeftJoin) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - left, err := j.Left.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - right, err := j.Right.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - cond, err := j.Cond.TransformUp(f) - if err != nil { - return nil, err +// WithExpressions implements the Expressioner interface. +func (j *LeftJoin) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(exprs), 1) } - return NewLeftJoin(left, right, cond), nil + return NewLeftJoin(j.Left, j.Right, exprs[0]), nil } func (j *LeftJoin) String() string { @@ -219,16 +175,6 @@ func (j *LeftJoin) Expressions() []sql.Expression { return []sql.Expression{j.Cond} } -// TransformExpressions implements the Expressioner interface. -func (j *LeftJoin) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - cond, err := j.Cond.TransformUp(f) - if err != nil { - return nil, err - } - - return NewLeftJoin(j.Left, j.Right, cond), nil -} - // RightJoin is a left join between two tables. type RightJoin struct { BinaryNode @@ -261,39 +207,22 @@ func (j *RightJoin) RowIter(ctx *sql.Context) (sql.RowIter, error) { return joinRowIter(ctx, rightJoin, j.Left, j.Right, j.Cond) } -// TransformUp implements the Transformable interface. -func (j *RightJoin) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - left, err := j.Left.TransformUp(f) - if err != nil { - return nil, err - } - - right, err := j.Right.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (j *RightJoin) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(children), 2) } - return f(NewRightJoin(left, right, j.Cond)) + return NewRightJoin(children[0], children[1], j.Cond), nil } -// TransformExpressionsUp implements the Transformable interface. -func (j *RightJoin) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - left, err := j.Left.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - right, err := j.Right.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - cond, err := j.Cond.TransformUp(f) - if err != nil { - return nil, err +// WithExpressions implements the Expressioner interface. +func (j *RightJoin) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(exprs), 1) } - return NewRightJoin(left, right, cond), nil + return NewRightJoin(j.Left, j.Right, exprs[0]), nil } func (j *RightJoin) String() string { @@ -308,16 +237,6 @@ func (j *RightJoin) Expressions() []sql.Expression { return []sql.Expression{j.Cond} } -// TransformExpressions implements the Expressioner interface. -func (j *RightJoin) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - cond, err := j.Cond.TransformUp(f) - if err != nil { - return nil, err - } - - return NewRightJoin(j.Left, j.Right, cond), nil -} - type joinType byte const ( diff --git a/sql/plan/limit.go b/sql/plan/limit.go index 212ccdd93..9ec72805e 100644 --- a/sql/plan/limit.go +++ b/sql/plan/limit.go @@ -7,8 +7,6 @@ import ( "github.com/src-d/go-mysql-server/sql" ) -var _ sql.Node = &Limit{} - // Limit is a node that only allows up to N rows to be retrieved. type Limit struct { UnaryNode @@ -40,22 +38,12 @@ func (l *Limit) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, &limitIter{l, 0, li}), nil } -// TransformUp implements the Transformable interface. -func (l *Limit) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := l.Child.TransformUp(f) - if err != nil { - return nil, err - } - return f(NewLimit(l.Limit, child)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (l *Limit) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := l.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (l *Limit) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1) } - return NewLimit(l.Limit, child), nil + return NewLimit(l.Limit, children[0]), nil } func (l Limit) String() string { diff --git a/sql/plan/lock.go b/sql/plan/lock.go index d7582c947..8e51edec6 100644 --- a/sql/plan/lock.go +++ b/sql/plan/lock.go @@ -3,8 +3,8 @@ package plan import ( "fmt" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" ) // TableLock is a read or write lock on a table. @@ -25,8 +25,6 @@ func NewLockTables(locks []*TableLock) *LockTables { return &LockTables{Locks: locks} } -var _ sql.Node = (*LockTables)(nil) - // Children implements the sql.Node interface. func (t *LockTables) Children() []sql.Node { var children = make([]sql.Node, len(t.Locks)) @@ -89,25 +87,21 @@ func (t *LockTables) String() string { return p.String() } -// TransformUp implements the sql.Node interface. -func (t *LockTables) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - var children = make([]*TableLock, len(t.Locks)) - for i, l := range t.Locks { - node, err := l.Table.TransformUp(f) - if err != nil { - return nil, err - } - children[i] = &TableLock{node, l.Write} +// WithChildren implements the Node interface. +func (t *LockTables) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != len(t.Locks) { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), len(t.Locks)) } - nt := *t - nt.Locks = children - return f(&nt) -} + var locks = make([]*TableLock, len(t.Locks)) + for i, n := range children { + locks[i] = &TableLock{ + Table: n, + Write: t.Locks[i].Write, + } + } -// TransformExpressionsUp implements the sql.Node interface. -func (t *LockTables) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - return t, nil + return &LockTables{t.Catalog, locks}, nil } // ErrTableNotLockable is returned whenever a lockable table can't be found. @@ -143,8 +137,6 @@ func NewUnlockTables() *UnlockTables { return new(UnlockTables) } -var _ sql.Node = (*UnlockTables)(nil) - // Children implements the sql.Node interface. func (t *UnlockTables) Children() []sql.Node { return nil } @@ -172,12 +164,11 @@ func (t *UnlockTables) String() string { return p.String() } -// TransformUp implements the sql.Node interface. -func (t *UnlockTables) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(t) -} +// WithChildren implements the Node interface. +func (t *UnlockTables) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 0) + } -// TransformExpressionsUp implements the sql.Node interface. -func (t *UnlockTables) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return t, nil } diff --git a/sql/plan/naturaljoin.go b/sql/plan/naturaljoin.go index fe8ec7a8b..6ccf0182b 100644 --- a/sql/plan/naturaljoin.go +++ b/sql/plan/naturaljoin.go @@ -35,32 +35,11 @@ func (j NaturalJoin) String() string { return pr.String() } -// TransformUp implements the Node interface. -func (j *NaturalJoin) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - left, err := j.Left.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (j *NaturalJoin) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(children), 2) } - right, err := j.Right.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewNaturalJoin(left, right)) -} - -// TransformExpressionsUp implements the Node interface. -func (j *NaturalJoin) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - left, err := j.Left.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - right, err := j.Right.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - return NewNaturalJoin(left, right), nil + return NewNaturalJoin(children[0], children[1]), nil } diff --git a/sql/plan/nothing.go b/sql/plan/nothing.go index 70d12c748..43792405d 100644 --- a/sql/plan/nothing.go +++ b/sql/plan/nothing.go @@ -7,8 +7,6 @@ var Nothing nothing type nothing struct{} -var _ sql.Node = nothing{} - func (nothing) String() string { return "NOTHING" } func (nothing) Resolved() bool { return true } func (nothing) Schema() sql.Schema { return nil } @@ -16,9 +14,12 @@ func (nothing) Children() []sql.Node { return nil } func (nothing) RowIter(*sql.Context) (sql.RowIter, error) { return sql.RowsToRowIter(), nil } -func (nothing) TransformUp(sql.TransformNodeFunc) (sql.Node, error) { - return Nothing, nil -} -func (nothing) TransformExpressionsUp(sql.TransformExprFunc) (sql.Node, error) { + +// WithChildren implements the Node interface. +func (n nothing) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 0) + } + return Nothing, nil } diff --git a/sql/plan/offset.go b/sql/plan/offset.go index a76cffacb..527b2c7cf 100644 --- a/sql/plan/offset.go +++ b/sql/plan/offset.go @@ -36,22 +36,12 @@ func (o *Offset) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, &offsetIter{o.Offset, it}), nil } -// TransformUp implements the Transformable interface. -func (o *Offset) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := o.Child.TransformUp(f) - if err != nil { - return nil, err - } - return f(NewOffset(o.Offset, child)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (o *Offset) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := o.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (o *Offset) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(o, len(children), 1) } - return NewOffset(o.Offset, child), nil + return NewOffset(o.Offset, children[0]), nil } func (o Offset) String() string { diff --git a/sql/plan/process.go b/sql/plan/process.go index e565e425f..32a8452d6 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -21,28 +21,13 @@ func NewQueryProcess(node sql.Node, notify NotifyFunc) *QueryProcess { return &QueryProcess{UnaryNode{Child: node}, notify} } -// TransformUp implements the sql.Node interface. -func (p *QueryProcess) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - n, err := p.Child.TransformUp(f) - if err != nil { - return nil, err - } - - np := *p - np.Child = n - return &np, nil -} - -// TransformExpressionsUp implements the sql.Node interface. -func (p *QueryProcess) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - n, err := p.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *QueryProcess) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) } - np := *p - np.Child = n - return &np, nil + return NewQueryProcess(children[0], p.Notify), nil } // RowIter implements the sql.Node interface. diff --git a/sql/plan/processlist.go b/sql/plan/processlist.go index 8a362a65a..bc9f4d18c 100644 --- a/sql/plan/processlist.go +++ b/sql/plan/processlist.go @@ -58,13 +58,12 @@ func (p *ShowProcessList) Children() []sql.Node { return nil } // Resolved implements the Node interface. func (p *ShowProcessList) Resolved() bool { return true } -// TransformUp implements the Node interface. -func (p *ShowProcessList) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(p) -} +// WithChildren implements the Node interface. +func (p *ShowProcessList) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } -// TransformExpressionsUp implements the Node interface. -func (p *ShowProcessList) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return p, nil } diff --git a/sql/plan/project.go b/sql/plan/project.go index 99779e86b..8b166d449 100644 --- a/sql/plan/project.go +++ b/sql/plan/project.go @@ -69,30 +69,6 @@ func (p *Project) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, &iter{p, i, ctx}), nil } -// TransformUp implements the Transformable interface. -func (p *Project) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := p.Child.TransformUp(f) - if err != nil { - return nil, err - } - return f(NewProject(p.Projections, child)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (p *Project) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - exprs, err := transformExpressionsUp(f, p.Projections) - if err != nil { - return nil, err - } - - child, err := p.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - return NewProject(exprs, child), nil -} - func (p *Project) String() string { pr := sql.NewTreePrinter() var exprs = make([]string, len(p.Projections)) @@ -109,14 +85,22 @@ func (p *Project) Expressions() []sql.Expression { return p.Projections } -// TransformExpressions implements the Expressioner interface. -func (p *Project) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - projects, err := transformExpressionsUp(f, p.Projections) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *Project) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) + } + + return NewProject(p.Projections, children[0]), nil +} + +// WithExpressions implements the Expressioner interface. +func (p *Project) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != len(p.Projections) { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(exprs), len(p.Projections)) } - return NewProject(projects, p.Child), nil + return NewProject(exprs, p.Child), nil } type iter struct { diff --git a/sql/plan/resolved_table.go b/sql/plan/resolved_table.go index 2706739d2..dbbd689a1 100644 --- a/sql/plan/resolved_table.go +++ b/sql/plan/resolved_table.go @@ -44,13 +44,12 @@ func (t *ResolvedTable) RowIter(ctx *sql.Context) (sql.RowIter, error) { }), nil } -// TransformUp implements the Transformable interface. -func (t *ResolvedTable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(NewResolvedTable(t.Table)) -} +// WithChildren implements the Node interface. +func (t *ResolvedTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 0) + } -// TransformExpressionsUp implements the Transformable interface. -func (t *ResolvedTable) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return t, nil } diff --git a/sql/plan/set.go b/sql/plan/set.go index 08f994663..9b8101675 100644 --- a/sql/plan/set.go +++ b/sql/plan/set.go @@ -42,41 +42,41 @@ func (s *Set) Resolved() bool { // Children implements the sql.Node interface. func (s *Set) Children() []sql.Node { return nil } -// TransformUp implements the sql.Node interface. -func (s *Set) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(s) -} +// WithChildren implements the Node interface. +func (s *Set) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } -// TransformExpressions implements sql.Expressioner interface. -func (s *Set) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - return s.TransformExpressionsUp(f) + return s, nil } -// Expressions implements the sql.Expressioner interface. -func (s *Set) Expressions() []sql.Expression { - var exprs = make([]sql.Expression, len(s.Variables)) - for i, v := range s.Variables { - exprs[i] = v.Value +// WithExpressions implements the Expressioner interface. +func (s *Set) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != len(s.Variables) { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(exprs), len(s.Variables)) } - return exprs -} -// TransformExpressionsUp implements the sql.Node interface. -func (s *Set) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { var vars = make([]SetVariable, len(s.Variables)) for i, v := range s.Variables { - val, err := v.Value.TransformUp(f) - if err != nil { - return nil, err + vars[i] = SetVariable{ + Name: v.Name, + Value: exprs[i], } - - vars[i] = v - vars[i].Value = val } return NewSet(vars...), nil } +// Expressions implements the sql.Expressioner interface. +func (s *Set) Expressions() []sql.Expression { + var exprs = make([]sql.Expression, len(s.Variables)) + for i, v := range s.Variables { + exprs[i] = v.Value + } + return exprs +} + // RowIter implements the sql.Node interface. func (s *Set) RowIter(ctx *sql.Context) (sql.RowIter, error) { span, ctx := ctx.Span("plan.Set") diff --git a/sql/plan/show_collation.go b/sql/plan/show_collation.go index 4b01febb3..3c833b246 100644 --- a/sql/plan/show_collation.go +++ b/sql/plan/show_collation.go @@ -42,12 +42,11 @@ func (ShowCollation) RowIter(ctx *sql.Context) (sql.RowIter, error) { // Schema implements the sql.Node interface. func (ShowCollation) Schema() sql.Schema { return collationSchema } -// TransformUp implements the sql.Node interface. -func (ShowCollation) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(ShowCollation{}) -} +// WithChildren implements the Node interface. +func (s ShowCollation) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } -// TransformExpressionsUp implements the sql.Node interface. -func (ShowCollation) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - return ShowCollation{}, nil + return s, nil } diff --git a/sql/plan/show_create_database.go b/sql/plan/show_create_database.go index 57ca9764b..96de06175 100644 --- a/sql/plan/show_create_database.go +++ b/sql/plan/show_create_database.go @@ -82,12 +82,11 @@ func (s *ShowCreateDatabase) Resolved() bool { return !ok } -// TransformExpressionsUp implements the sql.Node interface. -func (s *ShowCreateDatabase) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - return s, nil -} +// WithChildren implements the Node interface. +func (s *ShowCreateDatabase) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } -// TransformUp implements the sql.Node interface. -func (s *ShowCreateDatabase) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(s) + return s, nil } diff --git a/sql/plan/show_create_table.go b/sql/plan/show_create_table.go index 89c27f6f4..7b627b54a 100644 --- a/sql/plan/show_create_table.go +++ b/sql/plan/show_create_table.go @@ -25,14 +25,13 @@ func (n *ShowCreateTable) Schema() sql.Schema { } } -// TransformExpressionsUp implements the Transformable interface. -func (n *ShowCreateTable) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - return n, nil -} +// WithChildren implements the Node interface. +func (n *ShowCreateTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 0) + } -// TransformUp implements the Transformable interface. -func (n *ShowCreateTable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(NewShowCreateTable(n.CurrentDatabase, n.Catalog, n.Table)) + return n, nil } // RowIter implements the Node interface diff --git a/sql/plan/show_indexes.go b/sql/plan/show_indexes.go index b28a4cbc3..d427509ea 100644 --- a/sql/plan/show_indexes.go +++ b/sql/plan/show_indexes.go @@ -39,13 +39,12 @@ func (n *ShowIndexes) Resolved() bool { return !ok } -// TransformUp implements the Transformable interface. -func (n *ShowIndexes) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(NewShowIndexes(n.db, n.Table, n.Registry)) -} +// WithChildren implements the Node interface. +func (n *ShowIndexes) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 0) + } -// TransformExpressionsUp implements the Transformable interface. -func (n *ShowIndexes) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return n, nil } diff --git a/sql/plan/show_tables.go b/sql/plan/show_tables.go index 7d44f73ae..930cb20b4 100644 --- a/sql/plan/show_tables.go +++ b/sql/plan/show_tables.go @@ -84,13 +84,12 @@ func (p *ShowTables) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.RowsToRowIter(rows...), nil } -// TransformUp implements the Transformable interface. -func (p *ShowTables) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(NewShowTables(p.db, p.Full)) -} +// WithChildren implements the Node interface. +func (p *ShowTables) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } -// TransformExpressionsUp implements the Transformable interface. -func (p *ShowTables) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return p, nil } diff --git a/sql/plan/showcolumns.go b/sql/plan/showcolumns.go index 88a6a63b6..d2094fe3f 100644 --- a/sql/plan/showcolumns.go +++ b/sql/plan/showcolumns.go @@ -104,24 +104,13 @@ func (s *ShowColumns) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, sql.RowsToRowIter(rows...)), nil } -// TransformUp creates a new ShowColumns node. -func (s *ShowColumns) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := s.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (s *ShowColumns) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) } - return f(NewShowColumns(s.Full, child)) -} - -// TransformExpressionsUp creates a new ShowColumns node. -func (s *ShowColumns) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := s.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - return NewShowColumns(s.Full, child), nil + return NewShowColumns(s.Full, children[0]), nil } func (s *ShowColumns) String() string { diff --git a/sql/plan/showdatabases.go b/sql/plan/showdatabases.go index f370d7eff..41873c7be 100644 --- a/sql/plan/showdatabases.go +++ b/sql/plan/showdatabases.go @@ -53,16 +53,15 @@ func (p *ShowDatabases) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.RowsToRowIter(rows...), nil } -// TransformUp implements the Transformable interface. -func (p *ShowDatabases) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - np := *p - return f(&np) -} +// WithChildren implements the Node interface. +func (p *ShowDatabases) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } -// TransformExpressionsUp implements the Transformable interface. -func (p *ShowDatabases) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return p, nil } + func (p ShowDatabases) String() string { return "ShowDatabases" } diff --git a/sql/plan/showtablestatus.go b/sql/plan/showtablestatus.go index b30bf912f..158a1d65c 100644 --- a/sql/plan/showtablestatus.go +++ b/sql/plan/showtablestatus.go @@ -86,13 +86,12 @@ func (s *ShowTableStatus) String() string { return fmt.Sprintf("ShowTableStatus(%s)", strings.Join(s.Databases, ", ")) } -// TransformUp implements the sql.Node interface. -func (s *ShowTableStatus) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(s) -} +// WithChildren implements the Node interface. +func (s *ShowTableStatus) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } -// TransformExpressionsUp implements the sql.Node interface. -func (s *ShowTableStatus) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return s, nil } diff --git a/sql/plan/showvariables.go b/sql/plan/showvariables.go index 7c4648d5e..8cb7fa0c1 100644 --- a/sql/plan/showvariables.go +++ b/sql/plan/showvariables.go @@ -28,13 +28,12 @@ func (sv *ShowVariables) Resolved() bool { return true } -// TransformUp implements the sq.Transformable interface. -func (sv *ShowVariables) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(NewShowVariables(sv.config, sv.pattern)) -} +// WithChildren implements the Node interface. +func (sv *ShowVariables) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(sv, len(children), 0) + } -// TransformExpressionsUp implements the sql.Transformable interface. -func (sv *ShowVariables) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return sv, nil } diff --git a/sql/plan/showwarnings.go b/sql/plan/showwarnings.go index 5de5fb1f7..c990bfc81 100644 --- a/sql/plan/showwarnings.go +++ b/sql/plan/showwarnings.go @@ -12,13 +12,12 @@ func (ShowWarnings) Resolved() bool { return true } -// TransformUp implements the sq.Transformable interface. -func (sw ShowWarnings) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(sw) -} +// WithChildren implements the Node interface. +func (sw ShowWarnings) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(sw, len(children), 0) + } -// TransformExpressionsUp implements the sql.Transformable interface. -func (sw ShowWarnings) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return sw, nil } diff --git a/sql/plan/sort.go b/sql/plan/sort.go index d6836ba40..8235b2386 100644 --- a/sql/plan/sort.go +++ b/sql/plan/sort.go @@ -6,8 +6,8 @@ import ( "sort" "strings" - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" ) // ErrUnableSort is thrown when something happens on sorting @@ -91,34 +91,6 @@ func (s *Sort) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, newSortIter(s, i)), nil } -// TransformUp implements the Transformable interface. -func (s *Sort) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := s.Child.TransformUp(f) - if err != nil { - return nil, err - } - return f(NewSort(s.SortFields, child)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (s *Sort) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - var sfs = make([]SortField, len(s.SortFields)) - for i, sf := range s.SortFields { - col, err := sf.Column.TransformUp(f) - if err != nil { - return nil, err - } - sfs[i] = SortField{col, sf.Order, sf.NullOrdering} - } - - child, err := s.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - return NewSort(sfs, child), nil -} - func (s *Sort) String() string { pr := sql.NewTreePrinter() var fields = make([]string, len(s.SortFields)) @@ -139,22 +111,31 @@ func (s *Sort) Expressions() []sql.Expression { return exprs } -// TransformExpressions implements the Expressioner interface. -func (s *Sort) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - var sortFields = make([]SortField, len(s.SortFields)) - for i, field := range s.SortFields { - transformed, err := field.Column.TransformUp(f) - if err != nil { - return nil, err - } - sortFields[i] = SortField{ - Column: transformed, - Order: field.Order, - NullOrdering: field.NullOrdering, +// WithChildren implements the Node interface. +func (s *Sort) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) + } + + return NewSort(s.SortFields, children[0]), nil +} + +// WithExpressions implements the Expressioner interface. +func (s *Sort) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != len(s.SortFields) { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(exprs), len(s.SortFields)) + } + + var fields = make([]SortField, len(s.SortFields)) + for i, expr := range exprs { + fields[i] = SortField{ + Column: expr, + NullOrdering: s.SortFields[i].NullOrdering, + Order: s.SortFields[i].Order, } } - return NewSort(sortFields, s.Child), nil + return NewSort(fields, s.Child), nil } type sortIter struct { diff --git a/sql/plan/subqueryalias.go b/sql/plan/subqueryalias.go index 70233a075..da1264c88 100644 --- a/sql/plan/subqueryalias.go +++ b/sql/plan/subqueryalias.go @@ -45,16 +45,22 @@ func (n *SubqueryAlias) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, iter), nil } -// TransformUp implements the Node interface. -func (n *SubqueryAlias) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(n) -} +// WithChildren implements the Node interface. +func (n *SubqueryAlias) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 1) + } -// TransformExpressionsUp implements the Node interface. -func (n *SubqueryAlias) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { + nn := *n + nn.Child = children[0] return n, nil } +// Opaque implements the OpaqueNode interface. +func (n *SubqueryAlias) Opaque() bool { + return true +} + func (n SubqueryAlias) String() string { pr := sql.NewTreePrinter() _ = pr.WriteNode("SubqueryAlias(%s)", n.name) diff --git a/sql/plan/tablealias.go b/sql/plan/tablealias.go index 02795b177..be37fb109 100644 --- a/sql/plan/tablealias.go +++ b/sql/plan/tablealias.go @@ -23,22 +23,13 @@ func (t *TableAlias) Name() string { return t.name } -// TransformUp implements the Transformable interface. -func (t *TableAlias) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := t.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (t *TableAlias) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1) } - return f(NewTableAlias(t.name, child)) -} -// TransformExpressionsUp implements the Transformable interface. -func (t *TableAlias) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := t.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - return NewTableAlias(t.name, child), nil + return NewTableAlias(t.name, children[0]), nil } // RowIter implements the Node interface. diff --git a/sql/plan/transaction.go b/sql/plan/transaction.go index 6565cffb2..3d09366ef 100644 --- a/sql/plan/transaction.go +++ b/sql/plan/transaction.go @@ -15,13 +15,12 @@ func (*Rollback) RowIter(*sql.Context) (sql.RowIter, error) { func (*Rollback) String() string { return "ROLLBACK" } -// TransformUp implements the sql.Node interface. -func (r *Rollback) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(r) -} +// WithChildren implements the Node interface. +func (r *Rollback) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 0) + } -// TransformExpressionsUp implements the sql.Node interface. -func (r *Rollback) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return r, nil } diff --git a/sql/plan/transform.go b/sql/plan/transform.go new file mode 100644 index 000000000..e437327ed --- /dev/null +++ b/sql/plan/transform.go @@ -0,0 +1,89 @@ +package plan + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// TransformUp applies a transformation function to the given tree from the +// bottom up. +func TransformUp(node sql.Node, f sql.TransformNodeFunc) (sql.Node, error) { + if o, ok := node.(sql.OpaqueNode); ok && o.Opaque() { + return f(node) + } + + children := node.Children() + if len(children) == 0 { + return f(node) + } + + newChildren := make([]sql.Node, len(children)) + for i, c := range children { + c, err := TransformUp(c, f) + if err != nil { + return nil, err + } + newChildren[i] = c + } + + node, err := node.WithChildren(newChildren...) + if err != nil { + return nil, err + } + + return f(node) +} + +// TransformExpressionsUp applies a transformation function to all expressions +// on the given tree from the bottom up. +func TransformExpressionsUp(node sql.Node, f sql.TransformExprFunc) (sql.Node, error) { + if o, ok := node.(sql.OpaqueNode); ok && o.Opaque() { + return TransformExpressions(node, f) + } + + children := node.Children() + if len(children) == 0 { + return TransformExpressions(node, f) + } + + newChildren := make([]sql.Node, len(children)) + for i, c := range children { + c, err := TransformExpressionsUp(c, f) + if err != nil { + return nil, err + } + newChildren[i] = c + } + + node, err := node.WithChildren(newChildren...) + if err != nil { + return nil, err + } + + return TransformExpressions(node, f) +} + +// TransformExpressions applies a transformation function to all expressions +// on the given node. +func TransformExpressions(node sql.Node, f sql.TransformExprFunc) (sql.Node, error) { + e, ok := node.(sql.Expressioner) + if !ok { + return node, nil + } + + exprs := e.Expressions() + if len(exprs) == 0 { + return node, nil + } + + newExprs := make([]sql.Expression, len(exprs)) + for i, e := range exprs { + e, err := expression.TransformUp(e, f) + if err != nil { + return nil, err + } + newExprs[i] = e + } + + return e.WithExpressions(newExprs...) +} diff --git a/sql/plan/transform_test.go b/sql/plan/transform_test.go index 88a6cc89a..730ab128f 100644 --- a/sql/plan/transform_test.go +++ b/sql/plan/transform_test.go @@ -24,7 +24,7 @@ func TestTransformUp(t *testing.T) { } table := mem.NewTable("resolved", schema) - pt, err := p.TransformUp(func(n sql.Node) (sql.Node, error) { + pt, err := TransformUp(p, func(n sql.Node) (sql.Node, error) { switch n.(type) { case *UnresolvedTable: return NewResolvedTable(table), nil diff --git a/sql/plan/unresolved.go b/sql/plan/unresolved.go index 84160804d..9bf70a639 100644 --- a/sql/plan/unresolved.go +++ b/sql/plan/unresolved.go @@ -3,8 +3,8 @@ package plan import ( "fmt" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" ) // ErrUnresolvedTable is thrown when a table cannot be resolved @@ -42,13 +42,12 @@ func (*UnresolvedTable) RowIter(ctx *sql.Context) (sql.RowIter, error) { return nil, ErrUnresolvedTable.New() } -// TransformUp implements the Transformable interface. -func (t *UnresolvedTable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(NewUnresolvedTable(t.name, t.Database)) -} +// WithChildren implements the Node interface. +func (t *UnresolvedTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 0) + } -// TransformExpressionsUp implements the Transformable interface. -func (t *UnresolvedTable) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return t, nil } diff --git a/sql/plan/use.go b/sql/plan/use.go index fb721f8a4..b7eb517b7 100644 --- a/sql/plan/use.go +++ b/sql/plan/use.go @@ -50,13 +50,12 @@ func (u *Use) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.RowsToRowIter(), nil } -// TransformUp implements the sql.Node interface. -func (u *Use) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(u) -} +// WithChildren implements the Node interface. +func (u *Use) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1) + } -// TransformExpressionsUp implements the sql.Node interface. -func (u *Use) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return u, nil } diff --git a/sql/plan/values.go b/sql/plan/values.go index d07eb5bff..ea1d5be39 100644 --- a/sql/plan/values.go +++ b/sql/plan/values.go @@ -76,25 +76,6 @@ func (p *Values) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.RowsToRowIter(rows...), nil } -// TransformUp implements the Transformable interface. -func (p *Values) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(p) -} - -// TransformExpressionsUp implements the Transformable interface. -func (p *Values) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - ets := make([][]sql.Expression, len(p.ExpressionTuples)) - var err error - for i, et := range p.ExpressionTuples { - ets[i], err = transformExpressionsUp(f, et) - if err != nil { - return nil, err - } - } - - return NewValues(ets), nil -} - func (p *Values) String() string { return fmt.Sprintf("Values(%d tuples)", len(p.ExpressionTuples)) } @@ -108,15 +89,33 @@ func (p *Values) Expressions() []sql.Expression { return exprs } -// TransformExpressions implements the Expressioner interface. -func (p *Values) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - tuples := [][]sql.Expression{} - for _, tuple := range p.ExpressionTuples { - transformed, err := transformExpressionsUp(f, tuple) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *Values) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } + + return p, nil +} + +// WithExpressions implements the Expressioner interface. +func (p *Values) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + var expected int + for _, t := range p.ExpressionTuples { + expected += len(t) + } + + if len(exprs) != expected { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(exprs), expected) + } + + var offset int + var tuples = make([][]sql.Expression, len(p.ExpressionTuples)) + for i, t := range p.ExpressionTuples { + for range t { + tuples[i] = append(tuples[i], exprs[offset]) + offset++ } - tuples = append(tuples, transformed) } return NewValues(tuples), nil diff --git a/sql/session_test.go b/sql/session_test.go index 3a109779f..315015ec5 100644 --- a/sql/session_test.go +++ b/sql/session_test.go @@ -51,27 +51,22 @@ func TestHasDefaultValue(t *testing.T) { type testNode struct{} -func (t *testNode) Resolved() bool { +func (*testNode) Resolved() bool { panic("not implemented") } - -func (t *testNode) TransformUp(func(Node) Node) Node { - panic("not implemented") -} - -func (t *testNode) TransformExpressionsUp(func(Expression) Expression) Node { +func (*testNode) WithChildren(...Node) (Node, error) { panic("not implemented") } -func (t *testNode) Schema() Schema { +func (*testNode) Schema() Schema { panic("not implemented") } -func (t *testNode) Children() []Node { +func (*testNode) Children() []Node { panic("not implemented") } -func (t *testNode) RowIter(ctx *Context) (RowIter, error) { +func (*testNode) RowIter(ctx *Context) (RowIter, error) { return newTestNodeIterator(ctx), nil }