diff --git a/engine_test.go b/engine_test.go index 5e9b9964b..e7a0f64d3 100644 --- a/engine_test.go +++ b/engine_test.go @@ -1036,6 +1036,22 @@ var queries = []struct { {string("first row"), int64(1)}, }, }, + { + "SELECT CONVERT('9999-12-31 23:59:59', DATETIME)", + []sql.Row{{time.Date(9999, time.December, 31, 23, 59, 59, 0, time.UTC)}}, + }, + { + "SELECT CONVERT('10000-12-31 23:59:59', DATETIME)", + []sql.Row{{nil}}, + }, + { + "SELECT '9999-12-31 23:59:59' + INTERVAL 1 DAY", + []sql.Row{{nil}}, + }, + { + "SELECT DATE_ADD('9999-12-31 23:59:59', INTERVAL 1 DAY)", + []sql.Row{{nil}}, + }, } func TestQueries(t *testing.T) { diff --git a/sql/analyzer/convert_dates.go b/sql/analyzer/convert_dates.go new file mode 100644 index 000000000..72b59f721 --- /dev/null +++ b/sql/analyzer/convert_dates.go @@ -0,0 +1,40 @@ +package analyzer + +import ( + "gopkg.in/src-d/go-mysql-server.v0/sql" + "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "gopkg.in/src-d/go-mysql-server.v0/sql/expression/function" +) + +// convertDates wraps all expressions of date and datetime type with converts +// to ensure the date range is validated. +func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + if !n.Resolved() { + return n, nil + } + + return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) { + // No need to wrap expressions that already validate times, such as + // convert, date_add, etc and those expressions whose Type method + // cannot be called because they are placeholders. + switch e.(type) { + case *expression.Convert, + *expression.Arithmetic, + *function.DateAdd, + *function.DateSub, + *expression.Star, + *expression.DefaultColumn, + *expression.Alias: + return e, nil + default: + switch e.Type() { + case sql.Date: + return expression.NewConvert(e, expression.ConvertToDate), nil + case sql.Timestamp: + return expression.NewConvert(e, expression.ConvertToDatetime), nil + default: + return e, nil + } + } + }) +} diff --git a/sql/analyzer/convert_dates_test.go b/sql/analyzer/convert_dates_test.go new file mode 100644 index 000000000..3989b5246 --- /dev/null +++ b/sql/analyzer/convert_dates_test.go @@ -0,0 +1,157 @@ +package analyzer + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-mysql-server.v0/mem" + "gopkg.in/src-d/go-mysql-server.v0/sql" + "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "gopkg.in/src-d/go-mysql-server.v0/sql/expression/function" + "gopkg.in/src-d/go-mysql-server.v0/sql/plan" +) + +func TestConvertDates(t *testing.T) { + testCases := []struct { + name string + in sql.Expression + out sql.Expression + }{ + { + "arithmetic with dates", + expression.NewPlus(expression.NewLiteral("", sql.Timestamp), expression.NewLiteral("", sql.Timestamp)), + expression.NewPlus( + expression.NewConvert( + expression.NewLiteral("", sql.Timestamp), + expression.ConvertToDatetime, + ), + expression.NewConvert( + expression.NewLiteral("", sql.Timestamp), + expression.ConvertToDatetime, + ), + ), + }, + { + "star", + expression.NewStar(), + expression.NewStar(), + }, + { + "default column", + expression.NewDefaultColumn("foo"), + expression.NewDefaultColumn("foo"), + }, + { + "convert to date", + expression.NewConvert( + expression.NewPlus( + expression.NewLiteral("", sql.Timestamp), + expression.NewLiteral("", sql.Timestamp), + ), + expression.ConvertToDatetime, + ), + expression.NewConvert( + expression.NewPlus( + expression.NewConvert( + expression.NewLiteral("", sql.Timestamp), + expression.ConvertToDatetime, + ), + expression.NewConvert( + expression.NewLiteral("", sql.Timestamp), + expression.ConvertToDatetime, + ), + ), + expression.ConvertToDatetime, + ), + }, + { + "convert to other type", + expression.NewConvert( + expression.NewLiteral("", sql.Text), + expression.ConvertToBinary, + ), + expression.NewConvert( + expression.NewLiteral("", sql.Text), + expression.ConvertToBinary, + ), + }, + { + "datetime col in alias", + expression.NewAlias( + expression.NewLiteral("", sql.Timestamp), + "foo", + ), + expression.NewAlias( + expression.NewConvert( + expression.NewLiteral("", sql.Timestamp), + expression.ConvertToDatetime, + ), + "foo", + ), + }, + { + "date col in alias", + expression.NewAlias( + expression.NewLiteral("", sql.Date), + "foo", + ), + expression.NewAlias( + expression.NewConvert( + expression.NewLiteral("", sql.Date), + expression.ConvertToDate, + ), + "foo", + ), + }, + { + "date add", + newDateAdd( + expression.NewLiteral("", sql.Timestamp), + expression.NewInterval(expression.NewLiteral(int64(1), sql.Int64), "DAY"), + ), + newDateAdd( + expression.NewConvert( + expression.NewLiteral("", sql.Timestamp), + expression.ConvertToDatetime, + ), + expression.NewInterval(expression.NewLiteral(int64(1), sql.Int64), "DAY"), + ), + }, + { + "date sub", + newDateSub( + expression.NewLiteral("", sql.Timestamp), + expression.NewInterval(expression.NewLiteral(int64(1), sql.Int64), "DAY"), + ), + newDateSub( + expression.NewConvert( + expression.NewLiteral("", sql.Timestamp), + expression.ConvertToDatetime, + ), + expression.NewInterval(expression.NewLiteral(int64(1), sql.Int64), "DAY"), + ), + }, + } + + table := plan.NewResolvedTable(mem.NewTable("t", nil)) + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + input := plan.NewProject([]sql.Expression{tt.in}, table) + expected := plan.NewProject([]sql.Expression{tt.out}, table) + result, err := convertDates(sql.NewEmptyContext(), nil, input) + require.NoError(t, err) + require.Equal(t, expected, result) + }) + } +} + +func newDateAdd(l, r sql.Expression) sql.Expression { + e, _ := function.NewDateAdd(l, r) + return e +} + +func newDateSub(l, r sql.Expression) sql.Expression { + e, _ := function.NewDateSub(l, r) + return e +} diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index 2b0b46aec..0a9112384 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -20,6 +20,7 @@ var DefaultRules = []Rule{ {"reorder_projection", reorderProjection}, {"move_join_conds_to_filter", moveJoinConditionsToFilter}, {"eval_filter", evalFilter}, + {"convert_dates", convertDates}, {"optimize_distinct", optimizeDistinct}, } diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index a0ccfe2e5..1c300b52d 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -89,6 +89,15 @@ func (a *Arithmetic) String() string { return fmt.Sprintf("%s %s %s", a.Left, a.Op, a.Right) } +// IsNullable implements the sql.Expression interface. +func (a *Arithmetic) IsNullable() bool { + if a.Type() == sql.Timestamp { + return true + } + + return a.BinaryExpression.IsNullable() +} + // Type returns the greatest type for given operation. func (a *Arithmetic) Type() sql.Type { switch a.Op { @@ -254,12 +263,12 @@ func plus(lval, rval interface{}) (interface{}, error) { case time.Time: switch r := rval.(type) { case *TimeDelta: - return r.Add(l), nil + return sql.ValidateTime(r.Add(l)), nil } case *TimeDelta: switch r := rval.(type) { case time.Time: - return l.Add(r), nil + return sql.ValidateTime(l.Add(r)), nil } } @@ -288,7 +297,7 @@ func minus(lval, rval interface{}) (interface{}, error) { case time.Time: switch r := rval.(type) { case *TimeDelta: - return r.Sub(l), nil + return sql.ValidateTime(r.Sub(l)), nil } } diff --git a/sql/expression/convert.go b/sql/expression/convert.go index 28012d5b4..0d2502133 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -51,6 +51,16 @@ func NewConvert(expr sql.Expression, castToType string) *Convert { } } +// IsNullable implements the Expression interface. +func (c *Convert) IsNullable() bool { + switch c.castToType { + case ConvertToDate, ConvertToDatetime: + return true + default: + return c.Child.IsNullable() + } +} + // Type implements the Expression interface. func (c *Convert) Type() sql.Type { switch c.castToType { @@ -58,8 +68,10 @@ func (c *Convert) Type() sql.Type { return sql.Blob case ConvertToChar, ConvertToNChar: return sql.Text - case ConvertToDate, ConvertToDatetime: + case ConvertToDate: return sql.Date + case ConvertToDatetime: + return sql.Timestamp case ConvertToDecimal: return sql.Float64 case ConvertToJSON: @@ -143,7 +155,7 @@ func convertValue(val interface{}, castTo string) (interface{}, error) { } } - return d, nil + return sql.ValidateTime(d.(time.Time)), nil case ConvertToDecimal: d, err := cast.ToFloat64E(val) if err != nil { diff --git a/sql/expression/function/date.go b/sql/expression/function/date.go index 116eaead6..fd1a9970c 100644 --- a/sql/expression/function/date.go +++ b/sql/expression/function/date.go @@ -40,7 +40,7 @@ func (d *DateAdd) Resolved() bool { // IsNullable implements the sql.Expression interface. func (d *DateAdd) IsNullable() bool { - return d.Date.IsNullable() || d.Interval.IsNullable() + return true } // Type implements the sql.Expression interface. @@ -85,7 +85,7 @@ func (d *DateAdd) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - return delta.Add(date.(time.Time)), nil + return sql.ValidateTime(delta.Add(date.(time.Time))), nil } func (d *DateAdd) String() string { @@ -124,7 +124,7 @@ func (d *DateSub) Resolved() bool { // IsNullable implements the sql.Expression interface. func (d *DateSub) IsNullable() bool { - return d.Date.IsNullable() || d.Interval.IsNullable() + return true } // Type implements the sql.Expression interface. @@ -169,7 +169,7 @@ func (d *DateSub) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - return delta.Sub(date.(time.Time)), nil + return sql.ValidateTime(delta.Sub(date.(time.Time))), nil } func (d *DateSub) String() string { diff --git a/sql/type.go b/sql/type.go index 7ad54e8d5..238d400dc 100644 --- a/sql/type.go +++ b/sql/type.go @@ -151,6 +151,17 @@ type Type interface { fmt.Stringer } +var maxTime = time.Date(9999, time.December, 31, 23, 59, 59, 0, time.UTC) + +// ValidateTime receives a time and returns either that time or nil if it's +// not a valid time. +func ValidateTime(t time.Time) interface{} { + if t.After(maxTime) { + return nil + } + return t +} + var ( // Null represents the null type. Null nullT