Skip to content
Closed
7 changes: 7 additions & 0 deletions enginetest/enginetests.go
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,13 @@ func TestInsertScriptsPrepared(t *testing.T, harness Harness) {
}
}

func TestGeneratedColumns(t *testing.T, harness Harness) {
harness.Setup(setup.MydbData)
for _, script := range queries.GeneratedColumnTests {
TestScriptPrepared(t, harness, script)
}
}

func TestComplexIndexQueriesPrepared(t *testing.T, harness Harness) {
harness.Setup(setup.ComplexIndexSetup...)
e := mustNewEngine(t, harness)
Expand Down
2 changes: 0 additions & 2 deletions enginetest/evaluation.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,8 +415,6 @@ func injectBindVarsAndPrepare(

buf := sqlparser.NewTrackedBuffer(nil)
parsed.Format(buf)
println(q)
println(buf.String())
e.PreparedDataCache.CacheStmt(ctx.Session.ID(), buf.String(), parsed)

_, isDatabaser := resPlan.(sql.Databaser)
Expand Down
61 changes: 61 additions & 0 deletions enginetest/queries/generated_columns.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package queries

import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)

var GeneratedColumnTests = []ScriptTest{
{
Name: "stored generated column",
SetUpScript: []string{
"create table t1 (a int primary key, b int as (a + 1) stored)",
},
Assertions: []ScriptTestAssertion{
{
Query: "show create table t1",
// TODO: double parens here is a bug
Expected: []sql.Row{{"t1",
"CREATE TABLE `t1` (\n" +
" `a` int NOT NULL,\n" +
" `b` int GENERATED ALWAYS AS ((a + 1)) STORED,\n" +
" PRIMARY KEY (`a`)\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
Query: "insert into t1 values (1,2)",
ExpectedErr: sql.ErrGeneratedColumnValue,
},
{
Query: "insert into t1(a,b) values (1,2)",
ExpectedErr: sql.ErrGeneratedColumnValue,
},
{
Query: "select * from t1 order by a",
Expected: []sql.Row{},
},
{
Query: "insert into t1(a) values (1), (2), (3)",
Expected: []sql.Row{{types.NewOkResult(3)}},
},
{
Query: "select * from t1 order by a",
Expected: []sql.Row{{1, 2}, {2, 3}, {3, 4}},
},
},
},
}
16,004 changes: 7,703 additions & 8,301 deletions enginetest/queries/query_plans.go

Large diffs are not rendered by default.

3 changes: 0 additions & 3 deletions sql/analyzer/experimental_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ import (
// indexes.
func fixupAuxiliaryExprs(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
return transform.NodeWithOpaque(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
if _, ok := n.(*plan.Union); ok {
print("")
}
switch n := n.(type) {
default:
ret, same1, err := fixidx.FixFieldIndexesForExpressions(ctx, a.LogFn(), n, scope)
Expand Down
31 changes: 23 additions & 8 deletions sql/analyzer/inserts.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package analyzer

import (
"fmt"
"strings"

"github.com/dolthub/go-mysql-server/sql"
Expand All @@ -28,23 +29,33 @@ import (
func setInsertColumns(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
// We capture all INSERTs along the tree, such as those inside of block statements.
return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
// TODO: put load data here too?
ii, ok := n.(*plan.InsertInto)
if !ok {
return n, transform.SameTree, nil
}

if !ii.Destination.Resolved() {
destination := ii.Destination
if !destination.Resolved() {
return n, transform.SameTree, nil
}

schema := ii.Destination.Schema()
nameable, ok := destination.(sql.Nameable)
if !ok {
return n, transform.SameTree, fmt.Errorf("expected a sql.Nameable, got %T", destination)
}

schema := destination.Schema()

// If no column names were specified in the query, go ahead and fill
// them all in now that the destination is resolved.
// TODO: setting the plan field directly is not great
if len(ii.ColumnNames) == 0 {
colNames := make([]string, len(schema))
for i, col := range schema {
// Tables with any generated columns must specify a column list, so this is always an error
if col.Generated != nil {
return nil, transform.SameTree, sql.ErrGeneratedColumnValue.New(col.Name, nameable.Name())
}
colNames[i] = col.Name
}
ii.ColumnNames = colNames
Expand Down Expand Up @@ -117,7 +128,7 @@ func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sc
columnNames[i] = f.Name
}
} else {
err = validateColumns(columnNames, dstSchema)
err = validateColumns(table.Name(), columnNames, dstSchema)
if err != nil {
return nil, transform.SameTree, err
}
Expand Down Expand Up @@ -227,16 +238,20 @@ func wrapRowSource(ctx *sql.Context, scope *plan.Scope, logFn func(string, ...an
return plan.NewProject(projExprs, insertSource), nil
}

func validateColumns(columnNames []string, dstSchema sql.Schema) error {
dstColNames := make(map[string]struct{})
func validateColumns(tableName string, columnNames []string, dstSchema sql.Schema) error {
dstColNames := make(map[string]*sql.Column)
for _, dstCol := range dstSchema {
dstColNames[strings.ToLower(dstCol.Name)] = struct{}{}
dstColNames[strings.ToLower(dstCol.Name)] = dstCol
}
usedNames := make(map[string]struct{})
for _, columnName := range columnNames {
if _, exists := dstColNames[columnName]; !exists {
dstCol, exists := dstColNames[columnName]
if !exists {
return plan.ErrInsertIntoNonexistentColumn.New(columnName)
}
if dstCol.Generated != nil {
return sql.ErrGeneratedColumnValue.New(dstCol.Name, tableName)
}
if _, exists := usedNames[columnName]; !exists {
usedNames[columnName] = struct{}{}
} else {
Expand Down
6 changes: 6 additions & 0 deletions sql/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ type Column struct {
// Type is the data type of the column.
Type Type
// Default contains the default value of the column or nil if it was not explicitly defined. A nil instance is valid, thus calls do not error.
// TODO: figure out where table values are getting filled in
// TODO: can a column have both a default and generated value?
Default *ColumnDefaultValue
// AutoIncrement is true if the column auto-increments.
AutoIncrement bool
Expand All @@ -47,6 +49,10 @@ type Column struct {
Comment string
// Extra contains any additional information to put in the `extra` column under `information_schema.columns`.
Extra string
// Generated is non-nil if the column is defined with a generated value
Generated *ColumnDefaultValue
// Virtual is true if the column is defined as a virtual column. Generated must be non-nil in this case.
Virtual bool
}

// Check ensures the value is correct for this column.
Expand Down
3 changes: 3 additions & 0 deletions sql/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,9 @@ var (

// ErrFullTextInvalidColumnType is returned when a Full-Text index is declared on a non-text column.
ErrFullTextInvalidColumnType = errors.NewKind("all Full-Text columns must be declared on a non-binary text type")

// ErrGeneratedColumnValue is returned when a value is provided for a generated column
ErrGeneratedColumnValue = errors.NewKind("The value specified for generated column %q in table %q is not allowed.")
)

// CastSQLError returns a *mysql.SQLError with the error code and in some cases, also a SQL state, populated for the
Expand Down
13 changes: 12 additions & 1 deletion sql/expression/alias.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,15 @@ type Alias struct {
UnaryExpression
name string
unreferencable bool
id sql.ColumnId
}

var _ sql.Expression = (*Alias)(nil)
var _ sql.CollationCoercible = (*Alias)(nil)

// NewAlias returns a new Alias node.
func NewAlias(name string, expr sql.Expression) *Alias {
return &Alias{UnaryExpression{expr}, name, false}
return &Alias{UnaryExpression{expr}, name, false, 0}
}

// AsUnreferencable marks the alias outside of scope referencing
Expand All @@ -104,6 +105,16 @@ func (e *Alias) Unreferencable() bool {
return e.unreferencable
}

func (e *Alias) WithId(id sql.ColumnId) *Alias {
ret := *e
ret.id = id
return &ret
}

func (e *Alias) Id() sql.ColumnId {
return e.id
}

// Type returns the type of the expression.
func (e *Alias) Type() sql.Type {
return e.Child.Type()
Expand Down
19 changes: 16 additions & 3 deletions sql/plan/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,21 +118,30 @@ func (ii *InsertInto) WithDatabase(database sql.Database) (sql.Node, error) {
return &nc, nil
}

func (ii InsertInto) WithColumnNames(cols []string) *InsertInto {
ii.ColumnNames = cols
return &ii
}

// InsertDestination is a wrapper for a table to be used with InsertInto.Destination that allows the schema to be
// overridden. This is useful when the table in question has late-resolving column defaults.
type InsertDestination struct {
UnaryNode
Sch sql.Schema
DestinationName string
Sch sql.Schema
}

var _ sql.Node = (*InsertDestination)(nil)
var _ sql.Nameable = (*InsertDestination)(nil)
var _ sql.Expressioner = (*InsertDestination)(nil)
var _ sql.CollationCoercible = (*InsertDestination)(nil)

func NewInsertDestination(schema sql.Schema, node sql.Node) *InsertDestination {
nameable := node.(sql.Nameable)
return &InsertDestination{
UnaryNode: UnaryNode{Child: node},
Sch: schema,
UnaryNode: UnaryNode{Child: node},
Sch: schema,
DestinationName: nameable.Name(),
}
}

Expand All @@ -149,6 +158,10 @@ func (id InsertDestination) WithExpressions(exprs ...sql.Expression) (sql.Node,
return &id, nil
}

func (id *InsertDestination) Name() string {
return id.DestinationName
}

func (id *InsertDestination) String() string {
return id.UnaryNode.Child.String()
}
Expand Down
53 changes: 30 additions & 23 deletions sql/planbuilder/aggregates.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,18 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s
}
}
var aliases []sql.Expression
for _, e := range projScope.cols {
for _, col := range projScope.cols {
// eval aliases in project scope
switch e := col.scalar.(type) {
case *expression.Alias:
if !e.Unreferencable() {
aliases = append(aliases, e.WithId(sql.ColumnId(col.id)))
}
default:
}

// projection dependencies -> table cols needed above
transform.InspectExpr(e.scalar, func(e sql.Expression) bool {
transform.InspectExpr(col.scalar, func(e sql.Expression) bool {
switch e := e.(type) {
case *expression.GetField:
colName := strings.ToLower(e.Name())
Expand All @@ -174,10 +183,7 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s
selectGfs = append(selectGfs, e)
selectStr[colName] = true
}
case *expression.Alias:
if !e.Unreferencable() {
aliases = append(aliases, e)
}
default:
}
return false
})
Expand Down Expand Up @@ -467,9 +473,18 @@ func (b *Builder) buildWindow(fromScope, projScope *scope) *scope {
}
}
var aliases []sql.Expression
for _, e := range projScope.cols {
for _, col := range projScope.cols {
// eval aliases in project scope
switch e := col.scalar.(type) {
case *expression.Alias:
if !e.Unreferencable() {
aliases = append(aliases, e.WithId(sql.ColumnId(col.id)))
}
default:
}

// projection dependencies -> table cols needed above
transform.InspectExpr(e.scalar, func(e sql.Expression) bool {
transform.InspectExpr(col.scalar, func(e sql.Expression) bool {
switch e := e.(type) {
case *expression.GetField:
colName := strings.ToLower(e.Name())
Expand All @@ -478,11 +493,7 @@ func (b *Builder) buildWindow(fromScope, projScope *scope) *scope {
selectStr[colName] = true
selectGfs = append(selectGfs, e)
}
case *expression.Alias:
// selection aliases need to be projected
if !e.Unreferencable() {
aliases = append(aliases, e)
}
default:
}
return false
})
Expand Down Expand Up @@ -698,17 +709,13 @@ func (b *Builder) buildInnerProj(fromScope, projScope *scope) *scope {
}

// eval aliases in project scope
for _, e := range projScope.cols {
// selection aliases need to be projected
transform.InspectExpr(e.scalar, func(e sql.Expression) bool {
switch e := e.(type) {
case *expression.Alias:
if !e.Unreferencable() {
proj = append(proj, e)
}
for _, col := range projScope.cols {
switch e := col.scalar.(type) {
case *expression.Alias:
if !e.Unreferencable() {
proj = append(proj, e.WithId(sql.ColumnId(col.id)))
}
return false
})
}
}

if len(proj) > 0 {
Expand Down
Loading