Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions enginetest/insert_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,18 @@ var InsertQueries = []WriteQueryTest{
{int64(14), "third row new"},
},
},
{
"INSERT INTO mytable (i,s) values (1, 'hello') ON DUPLICATE KEY UPDATE s='hello'",
[]sql.Row{{sql.NewOkResult(2)}},
"SELECT * FROM mytable WHERE i = 1",
[]sql.Row{{int64(1), "hello"}},
},
{
"INSERT INTO mytable (i,s) values (10, 'hello') ON DUPLICATE KEY UPDATE s='hello'",
[]sql.Row{{sql.NewOkResult(1)}},
"SELECT * FROM mytable WHERE i = 10",
[]sql.Row{{int64(10), "hello"}},
},
}

var InsertErrorTests = []GenericErrorQueryTest{
Expand Down
6 changes: 4 additions & 2 deletions sql/parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -844,8 +844,9 @@ func convertDropView(ctx *sql.Context, c *sqlparser.DDL) (sql.Node, error) {
}

func convertInsert(ctx *sql.Context, i *sqlparser.Insert) (sql.Node, error) {
if len(i.OnDup) > 0 {
return nil, ErrUnsupportedFeature.New("ON DUPLICATE KEY")
onDupExprs, err := updateExprsToExpressions(ctx, sqlparser.UpdateExprs(i.OnDup))
if err != nil {
return nil, err
}

if len(i.Ignore) > 0 {
Expand All @@ -864,6 +865,7 @@ func convertInsert(ctx *sql.Context, i *sqlparser.Insert) (sql.Node, error) {
src,
isReplace,
columnsToStrings(i.Columns),
onDupExprs,
), nil
}

Expand Down
3 changes: 3 additions & 0 deletions sql/parse/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1010,6 +1010,7 @@ var fixtures = map[string]sql.Node{
}}),
false,
[]string{"col1", "col2"},
[]sql.Expression{},
),
`REPLACE INTO t1 (col1, col2) VALUES ('a', 1)`: plan.NewInsertInto(
plan.NewUnresolvedTable("t1", ""),
Expand All @@ -1019,6 +1020,7 @@ var fixtures = map[string]sql.Node{
}}),
true,
[]string{"col1", "col2"},
[]sql.Expression{},
),
`SHOW TABLES`: plan.NewShowTables(sql.UnresolvedDatabase(""), false, nil),
`SHOW FULL TABLES`: plan.NewShowTables(sql.UnresolvedDatabase(""), true, nil),
Expand Down Expand Up @@ -2151,6 +2153,7 @@ var fixtures = map[string]sql.Node{
),
false,
[]string{"a", "b"},
[]sql.Expression{},
),
},
), "",
Expand Down
78 changes: 74 additions & 4 deletions sql/plan/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@ type InsertInto struct {
BinaryNode
ColumnNames []string
IsReplace bool
OnDupExprs []sql.Expression
}

// NewInsertInto creates an InsertInto node.
func NewInsertInto(dst, src sql.Node, isReplace bool, cols []string) *InsertInto {
func NewInsertInto(dst, src sql.Node, isReplace bool, cols []string, onDupExprs []sql.Expression) *InsertInto {
return &InsertInto{
BinaryNode: BinaryNode{Left: dst, Right: src},
ColumnNames: cols,
IsReplace: isReplace,
OnDupExprs: onDupExprs,
}
}

Expand Down Expand Up @@ -186,8 +188,35 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
}
} else {
if err := inserter.Insert(ctx, row); err != nil {
_ = iter.Close()
return i, err
if !sql.ErrUniqueKeyViolation.Is(err) || len(p.OnDupExprs) <= 0 {
_ = iter.Close()
return i, err
}

// ON DUPLICATE KEY UPDATE ...
// build expression for filtering the update node
var pkExpression sql.Expression
for i, colName := range p.ColumnNames {
for index, col := range p.Left.Schema() {
if col.Name == colName && col.PrimaryKey {
if v, ok := p.Right.(*Values); ok {
value := v.Expressions()[i]
exp := expression.NewEquals(expression.NewGetField(index, col.Type, col.Name, col.Nullable), value)
if pkExpression != nil {
pkExpression = expression.NewAnd(pkExpression, exp)
} else {
pkExpression = exp
}
}
}
}
}

update := NewUpdate(NewFilter(pkExpression, p.Left), p.OnDupExprs)
_, i, err = update.Execute(ctx)
if err != nil {
return i, err
}
}
}
i++
Expand Down Expand Up @@ -243,7 +272,7 @@ func (p *InsertInto) WithChildren(children ...sql.Node) (sql.Node, error) {
return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 2)
}

return NewInsertInto(children[0], children[1], p.IsReplace, p.ColumnNames), nil
return NewInsertInto(children[0], children[1], p.IsReplace, p.ColumnNames, p.OnDupExprs), nil
}

func (p InsertInto) String() string {
Expand Down Expand Up @@ -341,3 +370,44 @@ func assertCompatibleSchemas(projExprs []sql.Expression, schema sql.Schema) erro
}
return nil
}

func (p *InsertInto) applyUpdates(ctx *sql.Context, row sql.Row) (sql.Row, error) {
var ok bool
prev := row
for _, updateExpr := range p.OnDupExprs {
val, err := updateExpr.Eval(ctx, prev)
if err != nil {
return nil, err
}
prev, ok = val.(sql.Row)
if !ok {
return nil, ErrUpdateUnexpectedSetResult.New(val)
}
}
return prev, nil
}

func (p *InsertInto) Expressions() []sql.Expression {
return p.OnDupExprs
}

func (p *InsertInto) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) {
if len(newExprs) != len(p.OnDupExprs) {
return nil, sql.ErrInvalidChildrenNumber.New(p, len(p.OnDupExprs), 1)
}

return NewInsertInto(p.Left, p.Right, p.IsReplace, p.ColumnNames, newExprs), nil
}

// Resolved implements the Resolvable interface.
func (p *InsertInto) Resolved() bool {
if !p.Left.Resolved() {
return false
}
for _, updateExpr := range p.OnDupExprs {
if !updateExpr.Resolved() {
return false
}
}
return true
}