diff --git a/driver/e2e_test.go b/driver/e2e_test.go deleted file mode 100644 index cd9f3975e4..0000000000 --- a/driver/e2e_test.go +++ /dev/null @@ -1,106 +0,0 @@ -package driver_test - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestQuery(t *testing.T) { - mtb, records := personMemTable("db", "person") - db := sqlOpen(t, mtb, t.Name()+"?jsonAs=object") - - var name, email string - var numbers interface{} - var created time.Time - var count int - - cases := []struct { - Name, Query string - Pointers Pointers - Expect Records - }{ - {"Select All", "SELECT * FROM db.person", []V{&name, &email, &numbers, &created}, records}, - {"Select First", "SELECT * FROM db.person LIMIT 1", []V{&name, &email, &numbers, &created}, records.Rows(0)}, - {"Select Name", "SELECT name FROM db.person", []V{&name}, records.Columns(0)}, - {"Select Count", "SELECT COUNT(1) FROM db.person", []V{&count}, Records{{len(records)}}}, - - {"Insert", `INSERT INTO db.person VALUES ('foo', 'bar', '["baz"]', NOW())`, []V{}, Records{}}, - {"Select Inserted", "SELECT name, email, phone_numbers FROM db.person WHERE name = 'foo'", []V{&name, &email, &numbers}, Records{{"foo", "bar", []V{"baz"}}}}, - - {"Update", "UPDATE db.person SET name = 'asdf' WHERE name = 'foo'", []V{}, Records{}}, - {"Delete", "DELETE FROM db.person WHERE name = 'asdf'", []V{}, Records{}}, - } - - for _, c := range cases { - t.Run(c.Name, func(t *testing.T) { - rows, err := db.Query(c.Query) - require.NoError(t, err, "Query") - - var i int - for ; rows.Next(); i++ { - require.NoError(t, rows.Scan(c.Pointers...), "Scan") - values := c.Pointers.Values() - - if i >= len(c.Expect) { - t.Errorf("Got row %d, expected %d total: %v", i+1, len(c.Expect), values) - continue - } - - assert.EqualValues(t, c.Expect[i], values, "Values") - } - - require.NoError(t, rows.Err(), "Rows.Err") - - if i < len(c.Expect) { - t.Errorf("Expected %d row(s), got %d", len(c.Expect), i) - } - }) - } -} - -func TestExec(t *testing.T) { - mtb, records := personMemTable("db", "person") - db := sqlOpen(t, mtb, t.Name()) - - cases := []struct { - Name, Statement string - RowsAffected int - }{ - {"Insert", `INSERT INTO db.person VALUES ('asdf', 'qwer', '["zxcv"]', NOW())`, 1}, - {"Update", "UPDATE db.person SET name = 'foo' WHERE name = 'asdf'", 1}, - {"Delete", "DELETE FROM db.person WHERE name = 'foo'", 1}, - {"Delete All", "DELETE FROM db.person WHERE LENGTH(name) < 100", len(records)}, - } - - for _, c := range cases { - t.Run(c.Name, func(t *testing.T) { - res, err := db.Exec(c.Statement) - require.NoError(t, err, "Exec") - - count, err := res.RowsAffected() - require.NoError(t, err, "RowsAffected") - assert.EqualValues(t, c.RowsAffected, count, "RowsAffected") - }) - } - - errCases := []struct { - Name, Statement string - Error string - }{ - {"Select", "SELECT * FROM db.person", "no result"}, - } - - for _, c := range errCases { - t.Run(c.Name, func(t *testing.T) { - res, err := db.Exec(c.Statement) - require.NoError(t, err, "Exec") - - _, err = res.RowsAffected() - require.Error(t, err, "RowsAffected") - assert.Equal(t, c.Error, err.Error()) - }) - } -} diff --git a/driver/fixtures_test.go b/driver/fixtures_test.go deleted file mode 100644 index 3da368b08e..0000000000 --- a/driver/fixtures_test.go +++ /dev/null @@ -1,68 +0,0 @@ -package driver_test - -import ( - "sync" - "time" - - "github.com/dolthub/go-mysql-server/driver" - "github.com/dolthub/go-mysql-server/memory" - "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/information_schema" -) - -type memTable struct { - DatabaseName string - TableName string - Schema sql.PrimaryKeySchema - Records Records - - once sync.Once - dbProvider sql.DatabaseProvider -} - -func (f *memTable) Resolve(name string, _ *driver.Options) (string, sql.DatabaseProvider, error) { - f.once.Do(func() { - database := memory.NewDatabase(f.DatabaseName) - - table := memory.NewTable(f.TableName, f.Schema, database.GetForeignKeyCollection()) - - if f.Records != nil { - ctx := sql.NewEmptyContext() - for _, row := range f.Records { - table.Insert(ctx, sql.NewRow(row...)) - } - } - - database.AddTable(f.TableName, table) - - pro := memory.NewMemoryDBProvider( - database, - information_schema.NewInformationSchemaDatabase()) - f.dbProvider = pro - }) - - return name, f.dbProvider, nil -} - -func personMemTable(database, table string) (*memTable, Records) { - records := Records{ - []V{"John Doe", "john@doe.com", []V{"555-555-555"}, time.Now()}, - []V{"John Doe", "johnalt@doe.com", []V{}, time.Now()}, - []V{"Jane Doe", "jane@doe.com", []V{}, time.Now()}, - []V{"Evil Bob", "evilbob@gmail.com", []V{"555-666-555", "666-666-666"}, time.Now()}, - } - - mtb := &memTable{ - DatabaseName: database, - TableName: table, - Schema: sql.NewPrimaryKeySchema(sql.Schema{ - {Name: "name", Type: sql.Text, Nullable: false, Source: table}, - {Name: "email", Type: sql.Text, Nullable: false, Source: table}, - {Name: "phone_numbers", Type: sql.JSON, Nullable: false, Source: table}, - {Name: "created_at", Type: sql.Timestamp, Nullable: false, Source: table}, - }), - Records: records, - } - - return mtb, records -} diff --git a/driver/helpers_test.go b/driver/helpers_test.go deleted file mode 100644 index 5f796dc008..0000000000 --- a/driver/helpers_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package driver_test - -import ( - "database/sql" - "reflect" - "sync" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/dolthub/go-mysql-server/driver" -) - -type V = interface{} - -var driverMu sync.Mutex -var drivers = map[driver.Provider]*driver.Driver{} - -func sqlOpen(t *testing.T, provider driver.Provider, dsn string) *sql.DB { - driverMu.Lock() - drv, ok := drivers[provider] - if !ok { - drv = driver.New(provider, nil) - drivers[provider] = drv - } - driverMu.Unlock() - - conn, err := drv.OpenConnector(dsn) - require.NoError(t, err) - return sql.OpenDB(conn) -} - -type Pointers []V - -func (ptrs Pointers) Values() []V { - values := make([]V, len(ptrs)) - for i := range values { - values[i] = reflect.ValueOf(ptrs[i]).Elem().Interface() - } - return values -} - -type Records [][]V - -func (records Records) Rows(rows ...int) Records { - result := make(Records, len(rows)) - - for i := range rows { - result[i] = records[rows[i]] - } - - return result -} - -func (records Records) Columns(cols ...int) Records { - result := make(Records, len(records)) - - for i := range records { - result[i] = make([]V, len(cols)) - for j := range cols { - result[i][j] = records[i][cols[j]] - } - } - - return result -} diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index 0ac30ef706..18f5ab129c 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -17,6 +17,7 @@ package enginetest import ( "context" "fmt" + "io" "net" "strings" "testing" @@ -5392,6 +5393,115 @@ func TestPrepared(t *testing.T, harness Harness) { } } +func TestTypesOverWire(t *testing.T, h Harness, sessionBuilder server.SessionBuilder) { + harness, ok := h.(ClientHarness) + if !ok { + t.Skip("Cannot run TestTypesOverWire as the harness must implement ClientHarness") + } + harness.Setup(setup.MydbData) + + port := getEmptyPort(t) + for _, script := range queries.TypeWireTests { + t.Run(script.Name, func(t *testing.T) { + ctx := NewContextWithClient(harness, sql.Client{ + User: "root", + Address: "localhost", + }) + serverConfig := server.Config{ + Protocol: "tcp", + Address: fmt.Sprintf("localhost:%d", port), + MaxConnections: 1000, + } + + engine := mustNewEngine(t, harness) + defer engine.Close() + engine.Analyzer.Catalog.MySQLDb.AddRootAccount() + for _, statement := range script.SetUpScript { + if sh, ok := harness.(SkippingHarness); ok { + if sh.SkipQueryTest(statement) { + t.Skip() + } + } + RunQueryWithContext(t, engine, harness, ctx, statement) + } + + s, err := server.NewServer(serverConfig, engine, sessionBuilder, nil) + require.NoError(t, err) + go func() { + err := s.Start() + require.NoError(t, err) + }() + defer func() { + require.NoError(t, s.Close()) + }() + + conn, err := dbr.Open("mysql", fmt.Sprintf("root:@tcp(localhost:%d)/", port), nil) + require.NoError(t, err) + _, err = conn.Exec("USE mydb;") + require.NoError(t, err) + for queryIdx, query := range script.Queries { + r, err := conn.Query(query) + if assert.NoError(t, err) { + sch, engineIter, err := engine.Query(ctx, query) + require.NoError(t, err) + expectedRowSet := script.Results[queryIdx] + expectedRowIdx := 0 + var engineRow sql.Row + for engineRow, err = engineIter.Next(ctx); err == nil; engineRow, err = engineIter.Next(ctx) { + if !assert.True(t, r.Next()) { + break + } + expectedRow := expectedRowSet[expectedRowIdx] + expectedRowIdx++ + connRow := make([]*string, len(engineRow)) + interfaceRow := make([]any, len(connRow)) + for i := range connRow { + interfaceRow[i] = &connRow[i] + } + err = r.Scan(interfaceRow...) + if !assert.NoError(t, err) { + break + } + expectedEngineRow := make([]*string, len(engineRow)) + for i := range engineRow { + sqlVal, err := sch[i].Type.SQL(nil, engineRow[i]) + if !assert.NoError(t, err) { + break + } + if !sqlVal.IsNull() { + str := sqlVal.ToString() + expectedEngineRow[i] = &str + } + } + + for i := range expectedEngineRow { + expectedVal := expectedEngineRow[i] + connVal := connRow[i] + if !assert.Equal(t, expectedVal == nil, connVal == nil) { + continue + } + if expectedVal != nil { + assert.Equal(t, *expectedVal, *connVal) + if script.Name == "JSON" { + // Different integrators may return their JSON strings with different spacing, so we + // special case the test since the spacing is not significant + *connVal = strings.Replace(*connVal, `, `, `,`, -1) + *connVal = strings.Replace(*connVal, `: "`, `:"`, -1) + } + assert.Equal(t, expectedRow[i], *connVal) + } + } + } + assert.True(t, err == io.EOF) + assert.False(t, r.Next()) + require.NoError(t, r.Close()) + } + } + require.NoError(t, conn.Close()) + }) + } +} + type memoryPersister struct { users []*mysql_db.User roles []*mysql_db.RoleEdge diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 7656860b12..7d4d6f4a37 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -19,11 +19,11 @@ import ( "log" "testing" - "github.com/dolthub/go-mysql-server/enginetest/scriptgen/setup" - "github.com/dolthub/go-mysql-server/enginetest" "github.com/dolthub/go-mysql-server/enginetest/queries" + "github.com/dolthub/go-mysql-server/enginetest/scriptgen/setup" "github.com/dolthub/go-mysql-server/memory" + "github.com/dolthub/go-mysql-server/server" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/analyzer" "github.com/dolthub/go-mysql-server/sql/expression" @@ -475,12 +475,6 @@ func TestLoadDataPrepared(t *testing.T) { } func TestScriptsPrepared(t *testing.T) { - //TODO: when foreign keys are implemented in the memory table, we can do the following test - for i := len(queries.ScriptTests) - 1; i >= 0; i-- { - if queries.ScriptTests[i].Name == "failed statements data validation for DELETE, REPLACE" { - queries.ScriptTests = append(queries.ScriptTests[:i], queries.ScriptTests[i+1:]...) - } - } enginetest.TestScriptsPrepared(t, enginetest.NewMemoryHarness("default", 1, testNumPartitions, true, mergableIndexDriver)) } @@ -757,6 +751,10 @@ func TestKeylessUniqueIndex(t *testing.T) { enginetest.TestKeylessUniqueIndex(t, enginetest.NewDefaultMemoryHarness()) } +func TestTypesOverWire(t *testing.T) { + enginetest.TestTypesOverWire(t, enginetest.NewDefaultMemoryHarness(), server.DefaultSessionBuilder) +} + func mergableIndexDriver(dbs []sql.Database) sql.IndexDriver { return memory.NewIndexDriver("mydb", map[string][]sql.DriverIndex{ "mytable": { diff --git a/enginetest/mysqlshim/connection.go b/enginetest/mysqlshim/connection.go index adf7dd88fa..e987b18b69 100644 --- a/enginetest/mysqlshim/connection.go +++ b/enginetest/mysqlshim/connection.go @@ -40,6 +40,10 @@ func NewMySQLShim(user string, password string, host string, port int) (*MySQLSh if err != nil { return nil, err } + err = conn.Ping() + if err != nil { + return nil, err + } return &MySQLShim{conn, make(map[string]string)}, nil } diff --git a/enginetest/mysqlshim/iter.go b/enginetest/mysqlshim/iter.go index a074a19f5e..0fc7c4664c 100644 --- a/enginetest/mysqlshim/iter.go +++ b/enginetest/mysqlshim/iter.go @@ -45,12 +45,12 @@ func newMySQLIter(rows *dsql.Rows) mysqlIter { scanType = reflect.TypeOf("") case reflect.TypeOf(dsql.NullBool{}): scanType = reflect.TypeOf(true) - //case reflect.TypeOf(dsql.NullByte{}): // Not supported in go 1.15, need to upgrade to 1.17 - // scanType = reflect.TypeOf(byte(0)) + case reflect.TypeOf(dsql.NullByte{}): + scanType = reflect.TypeOf(byte(0)) case reflect.TypeOf(dsql.NullFloat64{}): scanType = reflect.TypeOf(float64(0)) - //case reflect.TypeOf(dsql.NullInt16{}): // Not supported in go 1.15, need to upgrade to 1.17 - // scanType = reflect.TypeOf(int16(0)) + case reflect.TypeOf(dsql.NullInt16{}): + scanType = reflect.TypeOf(int16(0)) case reflect.TypeOf(dsql.NullInt32{}): scanType = reflect.TypeOf(int32(0)) case reflect.TypeOf(dsql.NullInt64{}): diff --git a/enginetest/mysqlshim/mysql_harness.go b/enginetest/mysqlshim/mysql_harness.go index 96bc761fdd..f3bc954e11 100644 --- a/enginetest/mysqlshim/mysql_harness.go +++ b/enginetest/mysqlshim/mysql_harness.go @@ -15,6 +15,7 @@ package mysqlshim import ( + "context" "fmt" "strings" "testing" @@ -30,16 +31,36 @@ import ( type MySQLHarness struct { shim *MySQLShim skippedQueries map[string]struct{} + setupData []setup.SetupScript + session sql.Session } -func (m *MySQLHarness) Setup(source ...[]setup.SetupScript) { - //TODO implement me - panic("implement me") +//TODO: refactor to remove enginetest cycle +var _ enginetest.Harness = (*MySQLHarness)(nil) +var _ enginetest.IndexHarness = (*MySQLHarness)(nil) +var _ enginetest.ForeignKeyHarness = (*MySQLHarness)(nil) +var _ enginetest.KeylessTableHarness = (*MySQLHarness)(nil) +var _ enginetest.ClientHarness = (*MySQLHarness)(nil) +var _ enginetest.SkippingHarness = (*MySQLHarness)(nil) + +func (m *MySQLHarness) Setup(setupData ...[]setup.SetupScript) { + m.setupData = nil + for i := range setupData { + m.setupData = append(m.setupData, setupData[i]...) + } + return } func (m *MySQLHarness) NewEngine(t *testing.T) (*sqle.Engine, error) { - //TODO implement me - panic("implement me") + return enginetest.NewEngineWithProviderSetup(t, m, m.shim, m.setupData) +} + +func (m *MySQLHarness) NewContextWithClient(client sql.Client) *sql.Context { + session := sql.NewBaseSessionWithClientServer("address", client, 1) + return sql.NewContext( + context.Background(), + sql.WithSession(session), + ) } func (m *MySQLHarness) Cleanup() error { @@ -58,19 +79,13 @@ type MySQLTable struct { tableName string } -var _ enginetest.Harness = (*MySQLHarness)(nil) -var _ enginetest.SkippingHarness = (*MySQLHarness)(nil) -var _ enginetest.IndexHarness = (*MySQLHarness)(nil) -var _ enginetest.ForeignKeyHarness = (*MySQLHarness)(nil) -var _ enginetest.KeylessTableHarness = (*MySQLHarness)(nil) - // NewMySQLHarness returns a new MySQLHarness. func NewMySQLHarness(user string, password string, host string, port int) (*MySQLHarness, error) { shim, err := NewMySQLShim(user, password, host, port) if err != nil { return nil, err } - return &MySQLHarness{shim, make(map[string]struct{})}, nil + return &MySQLHarness{shim, make(map[string]struct{}), nil, nil}, nil } // Parallelism implements the interface Harness. @@ -126,7 +141,14 @@ func (m *MySQLHarness) NewTable(db sql.Database, name string, schema sql.Primary // NewContext implements the interface Harness. func (m *MySQLHarness) NewContext() *sql.Context { - return sql.NewEmptyContext() + if m.session == nil { + m.session = enginetest.NewBaseSession() + } + + return sql.NewContext( + context.Background(), + sql.WithSession(m.session), + ) } // SkipQueryTest implements the interface SkippingHarness. diff --git a/enginetest/queries/priv_auth_queries.go b/enginetest/queries/priv_auth_queries.go index fb40711c87..d964c43a58 100644 --- a/enginetest/queries/priv_auth_queries.go +++ b/enginetest/queries/priv_auth_queries.go @@ -311,36 +311,36 @@ var UserPrivTests = []UserPrivilegeTest{ { "localhost", // Host "root", // User - "Y", // Select_priv - "Y", // Insert_priv - "Y", // Update_priv - "Y", // Delete_priv - "Y", // Create_priv - "Y", // Drop_priv - "Y", // Reload_priv - "Y", // Shutdown_priv - "Y", // Process_priv - "Y", // File_priv - "Y", // Grant_priv - "Y", // References_priv - "Y", // Index_priv - "Y", // Alter_priv - "Y", // Show_db_priv - "Y", // Super_priv - "Y", // Create_tmp_table_priv - "Y", // Lock_tables_priv - "Y", // Execute_priv - "Y", // Repl_slave_priv - "Y", // Repl_client_priv - "Y", // Create_view_priv - "Y", // Show_view_priv - "Y", // Create_routine_priv - "Y", // Alter_routine_priv - "Y", // Create_user_priv - "Y", // Event_priv - "Y", // Trigger_priv - "Y", // Create_tablespace_priv - "", // ssl_type + uint16(2), // Select_priv + uint16(2), // Insert_priv + uint16(2), // Update_priv + uint16(2), // Delete_priv + uint16(2), // Create_priv + uint16(2), // Drop_priv + uint16(2), // Reload_priv + uint16(2), // Shutdown_priv + uint16(2), // Process_priv + uint16(2), // File_priv + uint16(2), // Grant_priv + uint16(2), // References_priv + uint16(2), // Index_priv + uint16(2), // Alter_priv + uint16(2), // Show_db_priv + uint16(2), // Super_priv + uint16(2), // Create_tmp_table_priv + uint16(2), // Lock_tables_priv + uint16(2), // Execute_priv + uint16(2), // Repl_slave_priv + uint16(2), // Repl_client_priv + uint16(2), // Create_view_priv + uint16(2), // Show_view_priv + uint16(2), // Create_routine_priv + uint16(2), // Alter_routine_priv + uint16(2), // Create_user_priv + uint16(2), // Event_priv + uint16(2), // Trigger_priv + uint16(2), // Create_tablespace_priv + uint16(1), // ssl_type "", // ssl_cipher "", // x509_issuer "", // x509_subject @@ -350,12 +350,12 @@ var UserPrivTests = []UserPrivilegeTest{ uint32(0), // max_user_connections "mysql_native_password", // plugin "", // authentication_string - "N", // password_expired + uint16(1), // password_expired time.Unix(1, 0).UTC(), // password_last_changed nil, // password_lifetime - "N", // account_locked - "Y", // Create_role_priv - "Y", // Drop_role_priv + uint16(1), // account_locked + uint16(2), // Create_role_priv + uint16(2), // Drop_role_priv nil, // Password_reuse_history nil, // Password_reuse_time nil, // Password_require_current @@ -461,7 +461,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.db;", - Expected: []sql.Row{{"localhost", "mydb", "tester", "Y", "N", "Y", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "Y", "N", "N"}}, + Expected: []sql.Row{{"localhost", "mydb", "tester", uint16(2), uint16(1), uint16(2), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(2), uint16(1), uint16(1)}}, }, { User: "root", @@ -473,7 +473,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.db;", - Expected: []sql.Row{{"localhost", "mydb", "tester", "Y", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "Y", "N", "N"}}, + Expected: []sql.Row{{"localhost", "mydb", "tester", uint16(2), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(2), uint16(1), uint16(1)}}, }, { User: "root", @@ -493,7 +493,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.db;", - Expected: []sql.Row{{"localhost", "mydb", "tester", "Y", "Y", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "Y", "N", "N"}}, + Expected: []sql.Row{{"localhost", "mydb", "tester", uint16(2), uint16(2), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(2), uint16(1), uint16(1)}}, }, }, }, @@ -514,7 +514,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.tables_priv;", - Expected: []sql.Row{{"localhost", "mydb", "tester", "test", "", time.Unix(1, 0).UTC(), "Select,Delete,Drop", ""}}, + Expected: []sql.Row{{"localhost", "mydb", "tester", "test", "", time.Unix(1, 0).UTC(), uint64(0b101001), uint64(0)}}, }, { User: "root", @@ -526,7 +526,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.tables_priv;", - Expected: []sql.Row{{"localhost", "mydb", "tester", "test", "", time.Unix(1, 0).UTC(), "Select,Drop", ""}}, + Expected: []sql.Row{{"localhost", "mydb", "tester", "test", "", time.Unix(1, 0).UTC(), uint64(0b100001), uint64(0)}}, }, { User: "root", @@ -546,7 +546,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.tables_priv;", - Expected: []sql.Row{{"localhost", "mydb", "tester", "test", "", time.Unix(1, 0).UTC(), "References,Index", ""}}, + Expected: []sql.Row{{"localhost", "mydb", "tester", "test", "", time.Unix(1, 0).UTC(), uint64(0b110000000), uint64(0)}}, }, }, }, @@ -569,7 +569,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT User, Host, Select_priv FROM mysql.user WHERE User = 'tester';", - Expected: []sql.Row{{"tester", "localhost", "Y"}}, + Expected: []sql.Row{{"tester", "localhost", uint16(2)}}, }, { User: "root", @@ -587,7 +587,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT User, Host, Select_priv FROM mysql.user WHERE User = 'tester';", - Expected: []sql.Row{{"tester", "localhost", "N"}}, + Expected: []sql.Row{{"tester", "localhost", uint16(1)}}, }, }, }, @@ -616,7 +616,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT User, Host, Select_priv, Insert_priv FROM mysql.user WHERE User = 'tester';", - Expected: []sql.Row{{"tester", "localhost", "Y", "Y"}}, + Expected: []sql.Row{{"tester", "localhost", uint16(2), uint16(2)}}, }, { User: "root", @@ -640,7 +640,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT User, Host, Select_priv, Insert_priv FROM mysql.user WHERE User = 'tester';", - Expected: []sql.Row{{"tester", "localhost", "N", "N"}}, + Expected: []sql.Row{{"tester", "localhost", uint16(1), uint16(1)}}, }, }, }, @@ -654,7 +654,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT User, Host, account_locked FROM mysql.user WHERE User = 'test_role';", - Expected: []sql.Row{{"test_role", "%", "Y"}}, + Expected: []sql.Row{{"test_role", "%", uint16(2)}}, }, }, }, @@ -691,7 +691,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.role_edges;", - Expected: []sql.Row{{"%", "test_role", "localhost", "tester", "N"}}, + Expected: []sql.Row{{"%", "test_role", "localhost", "tester", uint16(1)}}, }, { User: "tester", @@ -703,7 +703,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT User, Host, Select_priv FROM mysql.user WHERE User = 'tester';", - Expected: []sql.Row{{"tester", "localhost", "N"}}, + Expected: []sql.Row{{"tester", "localhost", uint16(1)}}, }, }, }, @@ -729,7 +729,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.role_edges;", - Expected: []sql.Row{{"%", "test_role", "localhost", "tester", "N"}}, + Expected: []sql.Row{{"%", "test_role", "localhost", "tester", uint16(1)}}, }, { User: "root", @@ -785,7 +785,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.role_edges;", - Expected: []sql.Row{{"%", "test_role", "localhost", "tester", "N"}}, + Expected: []sql.Row{{"%", "test_role", "localhost", "tester", uint16(1)}}, }, { User: "root", @@ -847,7 +847,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.role_edges;", - Expected: []sql.Row{{"%", "test_role", "localhost", "tester", "N"}}, + Expected: []sql.Row{{"%", "test_role", "localhost", "tester", uint16(1)}}, }, { User: "root", diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 9751f23ca0..dc5a153a5b 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -1359,7 +1359,7 @@ var ScriptTests = []ScriptTest{ Assertions: []ScriptTestAssertion{ { Query: "SELECT * FROM test;", - Expected: []sql.Row{{1, "b", "b"}, {2, "a", "a"}}, + Expected: []sql.Row{{1, uint16(2), uint64(2)}, {2, uint16(1), uint64(1)}}, }, { Query: "UPDATE test SET v1 = 3 WHERE v1 = 2;", @@ -1367,7 +1367,7 @@ var ScriptTests = []ScriptTest{ }, { Query: "SELECT * FROM test;", - Expected: []sql.Row{{1, "c", "b"}, {2, "a", "a"}}, + Expected: []sql.Row{{1, uint16(3), uint64(2)}, {2, uint16(1), uint64(1)}}, }, { Query: "UPDATE test SET v2 = 3 WHERE 2 = v2;", @@ -1375,7 +1375,7 @@ var ScriptTests = []ScriptTest{ }, { Query: "SELECT * FROM test;", - Expected: []sql.Row{{1, "c", "a,b"}, {2, "a", "a"}}, + Expected: []sql.Row{{1, uint16(3), uint64(3)}, {2, uint16(1), uint64(1)}}, }, }, }, diff --git a/enginetest/queries/type_wire_queries.go b/enginetest/queries/type_wire_queries.go new file mode 100644 index 0000000000..32aee3de28 --- /dev/null +++ b/enginetest/queries/type_wire_queries.go @@ -0,0 +1,727 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queries + +import "github.com/dolthub/go-mysql-server/sql" + +// TypeWireTest is used to ensure that types are properly represented over the wire (vs being directly returned from the +// engine). +type TypeWireTest struct { + Name string + SetUpScript []string + Queries []string + Results [][]sql.Row +} + +// TypeWireTests are used to ensure that types are properly represented over the wire (vs being directly returned from +// the engine). +var TypeWireTests = []TypeWireTest{ + { + Name: "TINYINT", + SetUpScript: []string{ + `CREATE TABLE test (pk TINYINT PRIMARY KEY, v1 TINYINT);`, + `INSERT INTO test VALUES (-75, "-25"), (0, 0), (107.2, 0025), (120, -120);`, + `UPDATE test SET v1 = v1 - 1 WHERE pk < 0;`, + `DELETE FROM test WHERE pk > "119";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + `SELECT v1*2, pk+1 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"-75", "-26"}, {"0", "0"}, {"107", "25"}}, + {{"-26", "-75"}, {"0", "0"}, {"25", "107"}}, + {{"-52", "-74"}, {"0", "1"}, {"50", "108"}}, + }, + }, + { + Name: "SMALLINT", + SetUpScript: []string{ + `CREATE TABLE test (pk SMALLINT PRIMARY KEY, v1 SMALLINT);`, + `INSERT INTO test VALUES (-75, "-2531"), (0, 0), (2547.2, 03325), (9999, 9999);`, + `UPDATE test SET v1 = v1 - 1 WHERE pk < 0;`, + `DELETE FROM test WHERE pk >= "9999";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + `SELECT v1*2, pk+1 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"-75", "-2532"}, {"0", "0"}, {"2547", "3325"}}, + {{"-2532", "-75"}, {"0", "0"}, {"3325", "2547"}}, + {{"-5064", "-74"}, {"0", "1"}, {"6650", "2548"}}, + }, + }, + { + Name: "MEDIUMINT", + SetUpScript: []string{ + `CREATE TABLE test (pk MEDIUMINT PRIMARY KEY, v1 MEDIUMINT);`, + `INSERT INTO test VALUES (0, 0), (2547.2, 03325), (999999, 999999);`, + `UPDATE test SET v1 = v1 - 1 WHERE pk < 0;`, + `DELETE FROM test WHERE pk > "99999";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + `SELECT v1*2, pk+1 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"0", "0"}, {"2547", "3325"}}, + {{"0", "0"}, {"3325", "2547"}}, + {{"0", "1"}, {"6650", "2548"}}, + }, + }, + { + Name: "INT", + SetUpScript: []string{ + `CREATE TABLE test (pk INT PRIMARY KEY, v1 INT);`, + `INSERT INTO test VALUES (-75, "-2531"), (0, 0), (2547.2, 03325), (999999, 999999);`, + `UPDATE test SET v1 = v1 - 1 WHERE pk < 0;`, + `DELETE FROM test WHERE pk > "99999";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + `SELECT v1*2, pk+1 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"-75", "-2532"}, {"0", "0"}, {"2547", "3325"}}, + {{"-2532", "-75"}, {"0", "0"}, {"3325", "2547"}}, + {{"-5064", "-74"}, {"0", "1"}, {"6650", "2548"}}, + }, + }, + { + Name: "BIGINT", + SetUpScript: []string{ + `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 BIGINT);`, + `INSERT INTO test VALUES (-75, "-2531"), (0, 0), (2547.2, 03325), (999999, 999999);`, + `UPDATE test SET v1 = v1 - 1 WHERE pk < 0;`, + `DELETE FROM test WHERE pk > "99999";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + `SELECT v1*2, pk+1 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"-75", "-2532"}, {"0", "0"}, {"2547", "3325"}}, + {{"-2532", "-75"}, {"0", "0"}, {"3325", "2547"}}, + {{"-5064", "-74"}, {"0", "1"}, {"6650", "2548"}}, + }, + }, + { + Name: "TINYINT UNSIGNED", + SetUpScript: []string{ + `CREATE TABLE test (pk TINYINT UNSIGNED PRIMARY KEY, v1 TINYINT UNSIGNED);`, + `INSERT INTO test VALUES (0, 0), (25, "26"), (32.1, 0126), (255, 255);`, + `UPDATE test SET v1 = v1 - 1 WHERE pk > 0 AND pk < 30;`, + `DELETE FROM test WHERE pk >= "255";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + `SELECT v1*2, pk+1 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"0", "0"}, {"25", "25"}, {"32", "126"}}, + {{"0", "0"}, {"25", "25"}, {"126", "32"}}, + {{"0", "1"}, {"50", "26"}, {"252", "33"}}, + }, + }, + { + Name: "SMALLINT UNSIGNED", + SetUpScript: []string{ + `CREATE TABLE test (pk SMALLINT UNSIGNED PRIMARY KEY, v1 SMALLINT UNSIGNED);`, + `INSERT INTO test VALUES (0, 0), (25, "2531"), (2547.2, 03325), (9999, 9999);`, + `UPDATE test SET v1 = v1 - 1 WHERE pk > 0 AND pk < 100;`, + `DELETE FROM test WHERE pk >= "9999";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + `SELECT v1*2, pk+1 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"0", "0"}, {"25", "2530"}, {"2547", "3325"}}, + {{"0", "0"}, {"2530", "25"}, {"3325", "2547"}}, + {{"0", "1"}, {"5060", "26"}, {"6650", "2548"}}, + }, + }, + { + Name: "MEDIUMINT UNSIGNED", + SetUpScript: []string{ + `CREATE TABLE test (pk MEDIUMINT UNSIGNED PRIMARY KEY, v1 MEDIUMINT UNSIGNED);`, + `INSERT INTO test VALUES (75, "2531"), (0, 0), (2547.2, 03325), (999999, 999999);`, + `UPDATE test SET v1 = v1 + 1 WHERE pk < 100;`, + `DELETE FROM test WHERE pk > "99999";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + `SELECT v1*2, pk+1 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"0", "1"}, {"75", "2532"}, {"2547", "3325"}}, + {{"1", "0"}, {"2532", "75"}, {"3325", "2547"}}, + {{"2", "1"}, {"5064", "76"}, {"6650", "2548"}}, + }, + }, + { + Name: "INT UNSIGNED", + SetUpScript: []string{ + `CREATE TABLE test (pk INT UNSIGNED PRIMARY KEY, v1 INT UNSIGNED);`, + `INSERT INTO test VALUES (75, "2531"), (0, 0), (2547.2, 03325), (999999, 999999);`, + `UPDATE test SET v1 = v1 + 1 WHERE pk < 100;`, + `DELETE FROM test WHERE pk > "99999";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + `SELECT v1*2, pk+1 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"0", "1"}, {"75", "2532"}, {"2547", "3325"}}, + {{"1", "0"}, {"2532", "75"}, {"3325", "2547"}}, + {{"2", "1"}, {"5064", "76"}, {"6650", "2548"}}, + }, + }, + { + Name: "BIGINT UNSIGNED", + SetUpScript: []string{ + `CREATE TABLE test (pk BIGINT UNSIGNED PRIMARY KEY, v1 BIGINT UNSIGNED);`, + `INSERT INTO test VALUES (75, "2531"), (0, 0), (2547.2, 03325), (999999, 999999);`, + `UPDATE test SET v1 = v1 + 1 WHERE pk < 100;`, + `DELETE FROM test WHERE pk > "99999";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + `SELECT v1*2, pk+1 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"0", "1"}, {"75", "2532"}, {"2547", "3325"}}, + {{"1", "0"}, {"2532", "75"}, {"3325", "2547"}}, + {{"2", "1"}, {"5064", "76"}, {"6650", "2548"}}, + }, + }, + { + Name: "FLOAT", + SetUpScript: []string{ + `CREATE TABLE test (pk FLOAT PRIMARY KEY, v1 FLOAT);`, + `INSERT INTO test VALUES (-75.11, "-2531"), (0, 0), ("2547.2", 03325), (999999, 999999);`, + `UPDATE test SET v1 = v1 - 1 WHERE pk < 0;`, + `DELETE FROM test WHERE pk > "99999";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + `SELECT v1*2, pk+1 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"-75.11", "-2532"}, {"0", "0"}, {"2547.2", "3325"}}, + {{"-2532", "-75.11"}, {"0", "0"}, {"3325", "2547.2"}}, + {{"-5064", "-74.11000061035156"}, {"0", "1"}, {"6650", "2548.199951171875"}}, + }, + }, + { + Name: "DOUBLE", + SetUpScript: []string{ + `CREATE TABLE test (pk DOUBLE PRIMARY KEY, v1 DOUBLE);`, + `INSERT INTO test VALUES (-75.11, "-2531"), (0, 0), ("2547.2", 03325), (999999, 999999);`, + `UPDATE test SET v1 = v1 - 1 WHERE pk < 0;`, + `DELETE FROM test WHERE pk > "99999";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + `SELECT v1*2, pk+1 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"-75.11", "-2532"}, {"0", "0"}, {"2547.2", "3325"}}, + {{"-2532", "-75.11"}, {"0", "0"}, {"3325", "2547.2"}}, + {{"-5064", "-74.11"}, {"0", "1"}, {"6650", "2548.2"}}, + }, + }, + { + Name: "DECIMAL", + SetUpScript: []string{ + `CREATE TABLE test (pk DECIMAL(5,0) PRIMARY KEY, v1 DECIMAL(25,5));`, + `INSERT INTO test VALUES (-75, "-2531.356"), (0, 0), (2547.2, 03325), (99999, 999999);`, + `UPDATE test SET v1 = v1 - 1 WHERE pk < 0;`, + `DELETE FROM test WHERE pk >= "99999";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + `SELECT v1*2, pk+1 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"-75", "-2532.35600"}, {"0", "0.00000"}, {"2547", "3325.00000"}}, + {{"-2532.35600", "-75"}, {"0.00000", "0"}, {"3325.00000", "2547"}}, + {{"-5064.712", "-74"}, {"0", "1"}, {"6650", "2548"}}, + }, + }, + { + Name: "BIT", + SetUpScript: []string{ + `CREATE TABLE test (pk BIT(55) PRIMARY KEY, v1 BIT(1), v2 BIT(24));`, + `INSERT INTO test VALUES (75, 0, "21"), (0, 0, 0), (2547.2, 1, 03325), (999999, 1, 999999);`, + `UPDATE test SET v2 = v2 - 1 WHERE pk > 0 AND pk < 100;`, + `DELETE FROM test WHERE pk > 99999;`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT v2, v1, pk FROM test ORDER BY pk;`, + `SELECT v1*1, pk/10, v2+1 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"\x00\x00\x00\x00\x00\x00\x00", "\x00", "\x00\x00\x00"}, {"\x00\x00\x00\x00\x00\x00K", "\x00", "\x0020"}, {"\x00\x00\x00\x00\x00\t\xf3", "", "\x00 \xfd"}}, + {{"\x00\x00\x00", "\x00", "\x00\x00\x00\x00\x00\x00\x00"}, {"\x0020", "\x00", "\x00\x00\x00\x00\x00\x00K"}, {"\x00 \xfd", "", "\x00\x00\x00\x00\x00\t\xf3"}}, + {{"0", "0", "1"}, {"0", "7.5", "12849"}, {"1", "254.7", "3326"}}, + }, + }, + { + Name: "YEAR", + SetUpScript: []string{ + `CREATE TABLE test (pk YEAR PRIMARY KEY, v1 YEAR);`, + `INSERT INTO test VALUES (1901, 1901), (1950, "1950"), (1979.2, 01986), (2122, 2122);`, + `UPDATE test SET v1 = v1 + 1 WHERE pk < 1975;`, + `DELETE FROM test WHERE pk > "2100";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + `SELECT v1+3, pk+2 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"1901", "1902"}, {"1950", "1951"}, {"1979", "1986"}}, + {{"1902", "1901"}, {"1951", "1950"}, {"1986", "1979"}}, + {{"1905", "1903"}, {"1954", "1952"}, {"1989", "1981"}}, + }, + }, + { + Name: "TIMESTAMP", + SetUpScript: []string{ + `CREATE TABLE test (pk TIMESTAMP PRIMARY KEY, v1 TIMESTAMP);`, + `INSERT INTO test VALUES ("1980-04-12 12:02:11", "1986-08-02 17:04:22"), ("1999-11-28 13:06:33", "2022-01-14 15:08:44"), ("2020-05-06 18:10:55", "1975-09-15 11:12:16");`, + `UPDATE test SET v1 = "2000-01-01 00:00:00" WHERE pk < "1990-01-01 00:00:00";`, + `DELETE FROM test WHERE pk > "2015-01-01 00:00:00";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v1 FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"1980-04-12 12:02:11", "2000-01-01 00:00:00"}, {"1999-11-28 13:06:33", "2022-01-14 15:08:44"}}, + {{"1980-04-12 12:02:11", "2000-01-01 00:00:00"}, {"1999-11-28 13:06:33", "2022-01-14 15:08:44"}}, + {{"2000-01-01 00:00:00", "1980-04-12 12:02:11"}, {"2022-01-14 15:08:44", "1999-11-28 13:06:33"}}, + }, + }, + { + Name: "DATETIME", + SetUpScript: []string{ + `CREATE TABLE test (pk DATETIME PRIMARY KEY, v1 DATETIME);`, + `INSERT INTO test VALUES ("1000-04-12 12:02:11", "1986-08-02 17:04:22"), ("1999-11-28 13:06:33", "2022-01-14 15:08:44"), ("5020-05-06 18:10:55", "1975-09-15 11:12:16");`, + `UPDATE test SET v1 = "2000-01-01 00:00:00" WHERE pk < "1990-01-01 00:00:00";`, + `DELETE FROM test WHERE pk > "5000-01-01 00:00:00";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v1 FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"1000-04-12 12:02:11", "2000-01-01 00:00:00"}, {"1999-11-28 13:06:33", "2022-01-14 15:08:44"}}, + {{"1000-04-12 12:02:11", "2000-01-01 00:00:00"}, {"1999-11-28 13:06:33", "2022-01-14 15:08:44"}}, + {{"2000-01-01 00:00:00", "1000-04-12 12:02:11"}, {"2022-01-14 15:08:44", "1999-11-28 13:06:33"}}, + }, + }, + { + Name: "DATE", + SetUpScript: []string{ + `CREATE TABLE test (pk DATE PRIMARY KEY, v1 DATE);`, + `INSERT INTO test VALUES ("1000-04-12", "1986-08-02"), ("1999-11-28", "2022-01-14"), ("5020-05-06", "1975-09-15");`, + `UPDATE test SET v1 = "2000-01-01" WHERE pk < "1990-01-01";`, + `DELETE FROM test WHERE pk > "5000-01-01";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v1 FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"1000-04-12", "2000-01-01"}, {"1999-11-28", "2022-01-14"}}, + {{"1000-04-12", "2000-01-01"}, {"1999-11-28", "2022-01-14"}}, + {{"2000-01-01", "1000-04-12"}, {"2022-01-14", "1999-11-28"}}, + }, + }, + { + Name: "TIME", + SetUpScript: []string{ + `CREATE TABLE test (pk TIME PRIMARY KEY, v1 TIME);`, + `INSERT INTO test VALUES ("-800:00:00", "-20:21:22"), ("00:00:00", "00:00:00"), ("10:26:57", "30:53:14"), ("700:23:51", "300:25:52");`, + `UPDATE test SET v1 = "-120:12:20" WHERE pk < "00:00:00";`, + `DELETE FROM test WHERE pk > "600:00:00";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v1 FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"-800:00:00", "-120:12:20"}, {"00:00:00", "00:00:00"}, {"10:26:57", "30:53:14"}}, + {{"-800:00:00", "-120:12:20"}, {"00:00:00", "00:00:00"}, {"10:26:57", "30:53:14"}}, + {{"-120:12:20", "-800:00:00"}, {"00:00:00", "00:00:00"}, {"30:53:14", "10:26:57"}}, + }, + }, + { + Name: "CHAR", + SetUpScript: []string{ + `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 CHAR(5), v2 CHAR(10));`, + `INSERT INTO test VALUES (1, "abc", "def"), (2, "c-a", "123"), (3, "__2", 456), (4, "?hi?", "\\n");`, + `UPDATE test SET v1 = "a-c" WHERE pk = 2;`, + `DELETE FROM test WHERE pk = 4;`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v2, v1 FROM test ORDER BY pk;`, + `SELECT CONCAT(v1, "r"), pk, v2 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"1", "abc", "def"}, {"2", "a-c", "123"}, {"3", "__2", "456"}}, + {{"1", "def", "abc"}, {"2", "123", "a-c"}, {"3", "456", "__2"}}, + {{"abcr", "1", "def"}, {"a-cr", "2", "123"}, {"__2r", "3", "456"}}, + }, + }, + { + Name: "VARCHAR", + SetUpScript: []string{ + `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 VARCHAR(5), v2 VARCHAR(10));`, + `INSERT INTO test VALUES (1, "abc", "def"), (2, "c-a", "123"), (3, "__2", 456), (4, "?hi?", "\\n");`, + `UPDATE test SET v1 = CONCAT(v1, "x") WHERE pk = 2;`, + `DELETE FROM test WHERE pk = 4;`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v2, v1 FROM test ORDER BY pk;`, + `SELECT CONCAT(v1, "r"), pk, v2 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"1", "abc", "def"}, {"2", "c-ax", "123"}, {"3", "__2", "456"}}, + {{"1", "def", "abc"}, {"2", "123", "c-ax"}, {"3", "456", "__2"}}, + {{"abcr", "1", "def"}, {"c-axr", "2", "123"}, {"__2r", "3", "456"}}, + }, + }, + { + Name: "BINARY", + SetUpScript: []string{ + `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 BINARY(5), v2 BINARY(10));`, + `INSERT INTO test VALUES (1, "abc", "def"), (2, "c-a", "123"), (3, "__2", 456), (4, "?hi?", "\\n");`, + `UPDATE test SET v1 = "a-c" WHERE pk = 2;`, + `DELETE FROM test WHERE pk = 4;`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v2, v1 FROM test ORDER BY pk;`, + `SELECT CONCAT(v1, "r"), pk, v2 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"1", "abc\x00\x00", "def\x00\x00\x00\x00\x00\x00\x00"}, {"2", "a-c\x00\x00", "123\x00\x00\x00\x00\x00\x00\x00"}, {"3", "__2\x00\x00", "456\x00\x00\x00\x00\x00\x00\x00"}}, + {{"1", "def\x00\x00\x00\x00\x00\x00\x00", "abc\x00\x00"}, {"2", "123\x00\x00\x00\x00\x00\x00\x00", "a-c\x00\x00"}, {"3", "456\x00\x00\x00\x00\x00\x00\x00", "__2\x00\x00"}}, + {{"abc\x00\x00r", "1", "def\x00\x00\x00\x00\x00\x00\x00"}, {"a-c\x00\x00r", "2", "123\x00\x00\x00\x00\x00\x00\x00"}, {"__2\x00\x00r", "3", "456\x00\x00\x00\x00\x00\x00\x00"}}, + }, + }, + { + Name: "VARBINARY", + SetUpScript: []string{ + `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 VARBINARY(5), v2 VARBINARY(10));`, + `INSERT INTO test VALUES (1, "abc", "def"), (2, "c-a", "123"), (3, "__2", 456), (4, "?hi?", "\\n");`, + `UPDATE test SET v1 = CONCAT(v1, "x") WHERE pk = 2;`, + `DELETE FROM test WHERE pk = 4;`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v2, v1 FROM test ORDER BY pk;`, + `SELECT CONCAT(v1, "r"), pk, v2 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"1", "abc", "def"}, {"2", "c-ax", "123"}, {"3", "__2", "456"}}, + {{"1", "def", "abc"}, {"2", "123", "c-ax"}, {"3", "456", "__2"}}, + {{"abcr", "1", "def"}, {"c-axr", "2", "123"}, {"__2r", "3", "456"}}, + }, + }, + { + Name: "TINYTEXT", + SetUpScript: []string{ + `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 TINYTEXT, v2 TINYTEXT);`, + `INSERT INTO test VALUES (1, "abc", "def"), (2, "c-a", "123"), (3, "__2", 456), (4, "?hi?", "\\n");`, + `UPDATE test SET v1 = CONCAT(v1, "x") WHERE pk = 2;`, + `DELETE FROM test WHERE pk = 4;`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v2, v1 FROM test ORDER BY pk;`, + `SELECT CONCAT(v1, "r"), pk, v2 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"1", "abc", "def"}, {"2", "c-ax", "123"}, {"3", "__2", "456"}}, + {{"1", "def", "abc"}, {"2", "123", "c-ax"}, {"3", "456", "__2"}}, + {{"abcr", "1", "def"}, {"c-axr", "2", "123"}, {"__2r", "3", "456"}}, + }, + }, + { + Name: "TEXT", + SetUpScript: []string{ + `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 TEXT, v2 TEXT);`, + `INSERT INTO test VALUES (1, "abc", "def"), (2, "c-a", "123"), (3, "__2", 456), (4, "?hi?", "\\n");`, + `UPDATE test SET v1 = CONCAT(v1, "x") WHERE pk = 2;`, + `DELETE FROM test WHERE pk = 4;`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v2, v1 FROM test ORDER BY pk;`, + `SELECT CONCAT(v1, "r"), pk, v2 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"1", "abc", "def"}, {"2", "c-ax", "123"}, {"3", "__2", "456"}}, + {{"1", "def", "abc"}, {"2", "123", "c-ax"}, {"3", "456", "__2"}}, + {{"abcr", "1", "def"}, {"c-axr", "2", "123"}, {"__2r", "3", "456"}}, + }, + }, + { + Name: "MEDIUMTEXT", + SetUpScript: []string{ + `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 MEDIUMTEXT, v2 MEDIUMTEXT);`, + `INSERT INTO test VALUES (1, "abc", "def"), (2, "c-a", "123"), (3, "__2", 456), (4, "?hi?", "\\n");`, + `UPDATE test SET v1 = CONCAT(v1, "x") WHERE pk = 2;`, + `DELETE FROM test WHERE pk = 4;`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v2, v1 FROM test ORDER BY pk;`, + `SELECT CONCAT(v1, "r"), pk, v2 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"1", "abc", "def"}, {"2", "c-ax", "123"}, {"3", "__2", "456"}}, + {{"1", "def", "abc"}, {"2", "123", "c-ax"}, {"3", "456", "__2"}}, + {{"abcr", "1", "def"}, {"c-axr", "2", "123"}, {"__2r", "3", "456"}}, + }, + }, + { + Name: "LONGTEXT", + SetUpScript: []string{ + `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 LONGTEXT, v2 LONGTEXT);`, + `INSERT INTO test VALUES (1, "abc", "def"), (2, "c-a", "123"), (3, "__2", 456), (4, "?hi?", "\\n");`, + `UPDATE test SET v1 = CONCAT(v1, "x") WHERE pk = 2;`, + `DELETE FROM test WHERE pk = 4;`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v2, v1 FROM test ORDER BY pk;`, + `SELECT CONCAT(v1, "r"), pk, v2 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"1", "abc", "def"}, {"2", "c-ax", "123"}, {"3", "__2", "456"}}, + {{"1", "def", "abc"}, {"2", "123", "c-ax"}, {"3", "456", "__2"}}, + {{"abcr", "1", "def"}, {"c-axr", "2", "123"}, {"__2r", "3", "456"}}, + }, + }, + { + Name: "TINYBLOB", + SetUpScript: []string{ + `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 TINYBLOB, v2 TINYBLOB);`, + `INSERT INTO test VALUES (1, "abc", "def"), (2, "c-a", "123"), (3, "__2", 456), (4, "?hi?", "\\n");`, + `UPDATE test SET v1 = CONCAT(v1, "x") WHERE pk = 2;`, + `DELETE FROM test WHERE pk = 4;`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v2, v1 FROM test ORDER BY pk;`, + `SELECT CONCAT(v1, "r"), pk, v2 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"1", "abc", "def"}, {"2", "c-ax", "123"}, {"3", "__2", "456"}}, + {{"1", "def", "abc"}, {"2", "123", "c-ax"}, {"3", "456", "__2"}}, + {{"abcr", "1", "def"}, {"c-axr", "2", "123"}, {"__2r", "3", "456"}}, + }, + }, + { + Name: "BLOB", + SetUpScript: []string{ + `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 BLOB, v2 BLOB);`, + `INSERT INTO test VALUES (1, "abc", "def"), (2, "c-a", "123"), (3, "__2", 456), (4, "?hi?", "\\n");`, + `UPDATE test SET v1 = CONCAT(v1, "x") WHERE pk = 2;`, + `DELETE FROM test WHERE pk = 4;`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v2, v1 FROM test ORDER BY pk;`, + `SELECT CONCAT(v1, "r"), pk, v2 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"1", "abc", "def"}, {"2", "c-ax", "123"}, {"3", "__2", "456"}}, + {{"1", "def", "abc"}, {"2", "123", "c-ax"}, {"3", "456", "__2"}}, + {{"abcr", "1", "def"}, {"c-axr", "2", "123"}, {"__2r", "3", "456"}}, + }, + }, + { + Name: "MEDIUMBLOB", + SetUpScript: []string{ + `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 MEDIUMBLOB, v2 MEDIUMBLOB);`, + `INSERT INTO test VALUES (1, "abc", "def"), (2, "c-a", "123"), (3, "__2", 456), (4, "?hi?", "\\n");`, + `UPDATE test SET v1 = CONCAT(v1, "x") WHERE pk = 2;`, + `DELETE FROM test WHERE pk = 4;`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v2, v1 FROM test ORDER BY pk;`, + `SELECT CONCAT(v1, "r"), pk, v2 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"1", "abc", "def"}, {"2", "c-ax", "123"}, {"3", "__2", "456"}}, + {{"1", "def", "abc"}, {"2", "123", "c-ax"}, {"3", "456", "__2"}}, + {{"abcr", "1", "def"}, {"c-axr", "2", "123"}, {"__2r", "3", "456"}}, + }, + }, + { + Name: "LONGBLOB", + SetUpScript: []string{ + `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 LONGBLOB, v2 LONGBLOB);`, + `INSERT INTO test VALUES (1, "abc", "def"), (2, "c-a", "123"), (3, "__2", 456), (4, "?hi?", "\\n");`, + `UPDATE test SET v1 = CONCAT(v1, "x") WHERE pk = 2;`, + `DELETE FROM test WHERE pk = 4;`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v2, v1 FROM test ORDER BY pk;`, + `SELECT CONCAT(v1, "r"), pk, v2 FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"1", "abc", "def"}, {"2", "c-ax", "123"}, {"3", "__2", "456"}}, + {{"1", "def", "abc"}, {"2", "123", "c-ax"}, {"3", "456", "__2"}}, + {{"abcr", "1", "def"}, {"c-axr", "2", "123"}, {"__2r", "3", "456"}}, + }, + }, + { + Name: "ENUM", + SetUpScript: []string{ + `CREATE TABLE test (pk ENUM("a","b","c") PRIMARY KEY, v1 ENUM("x","y","z"));`, + `INSERT INTO test VALUES (1, 1), ("b", "y"), (3, "z");`, + `UPDATE test SET v1 = "x" WHERE pk = 2;`, + `DELETE FROM test WHERE pk > 2;`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v1 FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"a", "x"}, {"b", "x"}}, + {{"a", "x"}, {"b", "x"}}, + {{"x", "a"}, {"x", "b"}}, + }, + }, + { + Name: "SET", + SetUpScript: []string{ + `CREATE TABLE test (pk SET("a","b","c") PRIMARY KEY, v1 SET("w","x","y","z"));`, + `INSERT INTO test VALUES (0, 1), ("b", "y"), ("b,c", "z,z"), ("a,c,b", 10);`, + `UPDATE test SET v1 = "y,x,w" WHERE pk >= 4`, + `DELETE FROM test WHERE pk > "b,c";`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT pk, v1 FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"", "w"}, {"b", "y"}, {"b,c", "w,x,y"}}, + {{"", "w"}, {"b", "y"}, {"b,c", "w,x,y"}}, + {{"w", ""}, {"y", "b"}, {"w,x,y", "b,c"}}, + }, + }, + //TODO: fix GEOMETRY and friends, basic queries are broken + //{ + // Name: "GEOMETRY", + // SetUpScript: []string{ + // `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 GEOMETRY);`, + // `INSERT INTO test VALUES (1, POINT(1, 2)), (2, LINESTRING(POINT(1, 2), POINT(3, 4))), (3, ST_GeomFromText('POLYGON((0 0,0 3,3 0,0 0),(1 1,1 2,2 1,1 1))'));`, + // }, + // Queries: []string{ + // `SELECT * FROM test ORDER BY pk;`, + // `SELECT pk, v1 FROM test ORDER BY pk;`, + // `SELECT ST_ASWKT(v1), pk FROM test ORDER BY pk;`, + // }, + //}, + //{ + // Name: "POINT", + // SetUpScript: []string{ + // `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 POINT);`, + // `INSERT INTO test VALUES (1, POINT(1, 2)), (2, POINT(3.4, 5.6)), (3, POINT(10, -20)), (4, POINT(1000, -1000));`, + // `DELETE FROM test WHERE pk = 4;`, + // }, + // Queries: []string{ + // `SELECT * FROM test ORDER BY pk;`, + // `SELECT pk, v1 FROM test ORDER BY pk;`, + // `SELECT ST_ASWKT(v1), pk FROM test ORDER BY pk;`, + // }, + //}, + //{ + // Name: "LINESTRING", + // SetUpScript: []string{ + // `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 LINESTRING);`, + // `INSERT INTO test VALUES (1, LINESTRING(POINT(1, 2), POINT(3, 4))), (2, LINESTRING(POINT(5, 6), POINT(7, 8)));`, + // }, + // Queries: []string{ + // `SELECT * FROM test ORDER BY pk;`, + // `SELECT pk, v1 FROM test ORDER BY pk;`, + // `SELECT ST_ASWKT(v1), pk FROM test ORDER BY pk;`, + // }, + //}, + //{ + // Name: "POLYGON", + // SetUpScript: []string{ + // `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 POLYGON);`, + // `INSERT INTO test VALUES (1, ST_GeomFromText('POLYGON((0 0,0 3,3 0,0 0),(1 1,1 2,2 1,1 1))'));`, + // }, + // Queries: []string{ + // `SELECT * FROM test ORDER BY pk;`, + // `SELECT pk, v1 FROM test ORDER BY pk;`, + // `SELECT ST_ASWKT(v1), pk FROM test ORDER BY pk;`, + // }, + //}, + { + Name: "JSON", + SetUpScript: []string{ + `CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 JSON);`, + `INSERT INTO test VALUES (1, '{"key1": {"key": "value"}}'), (2, '{"key1": "value1", "key2": "value2"}'), (3, '{"key1": {"key": [2,3]}}');`, + `UPDATE test SET v1 = '["a", 1]' WHERE pk = 1;`, + `DELETE FROM test WHERE pk = 3;`, + }, + Queries: []string{ + `SELECT * FROM test ORDER BY pk;`, + `SELECT v1, pk FROM test ORDER BY pk;`, + `SELECT pk, JSON_ARRAYAGG(v1) FROM (SELECT * FROM test ORDER BY pk) as sub GROUP BY v1 ORDER BY pk;`, + }, + Results: [][]sql.Row{ + {{"1", "[\"a\",1]"}, {"2", "{\"key1\":\"value1\",\"key2\":\"value2\"}"}}, + {{"[\"a\",1]", "1"}, {"{\"key1\":\"value1\",\"key2\":\"value2\"}", "2"}}, + {{"1", "[[\"a\",1]]"}, {"2", "[{\"key1\":\"value1\",\"key2\":\"value2\"}]"}}, + }, + }, +} diff --git a/enginetest/queries/variable_queries.go b/enginetest/queries/variable_queries.go index f16a9dded5..6c005d122f 100644 --- a/enginetest/queries/variable_queries.go +++ b/enginetest/queries/variable_queries.go @@ -76,7 +76,7 @@ var VariableQueries = []ScriptTest{ }, Query: "SELECT @@autocommit, @@session.sql_mode", Expected: []sql.Row{ - {1, ""}, + {1, uint64(0)}, }, }, { @@ -86,7 +86,7 @@ var VariableQueries = []ScriptTest{ }, Query: "SELECT @@autocommit, @@session.sql_mode", Expected: []sql.Row{ - {1, ""}, + {1, uint64(0)}, }, }, { @@ -189,7 +189,7 @@ var VariableQueries = []ScriptTest{ }, Query: "SELECT @@sql_mode", Expected: []sql.Row{ - {"ALLOW_INVALID_DATES"}, + {uint64(1)}, }, }, { @@ -199,7 +199,7 @@ var VariableQueries = []ScriptTest{ }, Query: "SELECT @@sql_mode", Expected: []sql.Row{ - {"ALLOW_INVALID_DATES"}, + {uint64(1)}, }, }, { @@ -209,7 +209,7 @@ var VariableQueries = []ScriptTest{ }, Query: "SELECT @@sql_mode", Expected: []sql.Row{ - {"ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE,STRICT_ALL_TABLES,STRICT_TRANS_TABLES,TRADITIONAL"}, + {uint64(0b10110000110100000100)}, }, }, // User variables diff --git a/enginetest/testdata.go b/enginetest/testdata.go index a7c92f12c1..dabd191bb7 100644 --- a/enginetest/testdata.go +++ b/enginetest/testdata.go @@ -199,21 +199,21 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(1, "first_row"), - sql.NewRow(2, "second_row"), - sql.NewRow(3, "third_row"), - sql.NewRow(4, `%`), - sql.NewRow(5, `'`), - sql.NewRow(6, `"`), - sql.NewRow(7, "\t"), - sql.NewRow(8, "\n"), - sql.NewRow(9, "\v"), - sql.NewRow(10, `test%test`), - sql.NewRow(11, `test'test`), - sql.NewRow(12, `test"test`), - sql.NewRow(13, "test\ttest"), - sql.NewRow(14, "test\ntest"), - sql.NewRow(15, "test\vtest"), + sql.NewRow(int64(1), "first_row"), + sql.NewRow(int64(2), "second_row"), + sql.NewRow(int64(3), "third_row"), + sql.NewRow(int64(4), `%`), + sql.NewRow(int64(5), `'`), + sql.NewRow(int64(6), `"`), + sql.NewRow(int64(7), "\t"), + sql.NewRow(int64(8), "\n"), + sql.NewRow(int64(9), "\v"), + sql.NewRow(int64(10), `test%test`), + sql.NewRow(int64(11), `test'test`), + sql.NewRow(int64(12), `test"test`), + sql.NewRow(int64(13), "test\ttest"), + sql.NewRow(int64(14), "test\ntest"), + sql.NewRow(int64(15), "test\vtest"), ) } else { t.Logf("Warning: could not create table %s: %s", "specialtable", err) @@ -252,10 +252,10 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(0, 0, 1, 2, 3, 4), - sql.NewRow(1, 10, 11, 12, 13, 14), - sql.NewRow(2, 20, 21, 22, 23, 24), - sql.NewRow(3, 30, 31, 32, 33, 34)) + sql.NewRow(int8(0), int8(0), int8(1), int8(2), int8(3), int8(4)), + sql.NewRow(int8(1), int8(10), int8(11), int8(12), int8(13), int8(14)), + sql.NewRow(int8(2), int8(20), int8(21), int8(22), int8(23), int8(24)), + sql.NewRow(int8(3), int8(30), int8(31), int8(32), int8(33), int8(34))) } else { t.Logf("Warning: could not create table %s: %s", "one_pk", err) } @@ -273,10 +273,10 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(1, "row one", sql.JSONDocument{Val: []interface{}{1, 2}}, sql.JSONDocument{Val: map[string]interface{}{"a": 2}}), - sql.NewRow(2, "row two", sql.JSONDocument{Val: []interface{}{3, 4}}, sql.JSONDocument{Val: map[string]interface{}{"b": 2}}), - sql.NewRow(3, "row three", sql.JSONDocument{Val: []interface{}{5, 6}}, sql.JSONDocument{Val: map[string]interface{}{"c": 2}}), - sql.NewRow(4, "row four", sql.JSONDocument{Val: []interface{}{7, 8}}, sql.JSONDocument{Val: map[string]interface{}{"d": 2}})) + sql.NewRow(int8(1), "row one", sql.JSONDocument{Val: []interface{}{1, 2}}, sql.JSONDocument{Val: map[string]interface{}{"a": 2}}), + sql.NewRow(int8(2), "row two", sql.JSONDocument{Val: []interface{}{3, 4}}, sql.JSONDocument{Val: map[string]interface{}{"b": 2}}), + sql.NewRow(int8(3), "row three", sql.JSONDocument{Val: []interface{}{5, 6}}, sql.JSONDocument{Val: map[string]interface{}{"c": 2}}), + sql.NewRow(int8(4), "row four", sql.JSONDocument{Val: []interface{}{7, 8}}, sql.JSONDocument{Val: map[string]interface{}{"d": 2}})) } else { t.Logf("Warning: could not create table %s: %s", "jsontable", err) } @@ -297,10 +297,10 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(0, 0, 0, 1, 2, 3, 4), - sql.NewRow(0, 1, 10, 11, 12, 13, 14), - sql.NewRow(1, 0, 20, 21, 22, 23, 24), - sql.NewRow(1, 1, 30, 31, 32, 33, 34)) + sql.NewRow(int8(0), int8(0), int8(0), int8(1), int8(2), int8(3), int8(4)), + sql.NewRow(int8(0), int8(1), int8(10), int8(11), int8(12), int8(13), int8(14)), + sql.NewRow(int8(1), int8(0), int8(20), int8(21), int8(22), int8(23), int8(24)), + sql.NewRow(int8(1), int8(1), int8(30), int8(31), int8(32), int8(33), int8(34))) } else { t.Logf("Warning: could not create table %s: %s", "two_pk", err) } @@ -317,14 +317,14 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(0, 0, 0), - sql.NewRow(1, 1, 1), - sql.NewRow(2, 2, 2), - sql.NewRow(3, 3, 3), - sql.NewRow(4, 4, 4), - sql.NewRow(5, 5, 5), - sql.NewRow(6, 6, 6), - sql.NewRow(7, 7, 7)) + sql.NewRow(int64(0), int64(0), int64(0)), + sql.NewRow(int64(1), int64(1), int64(1)), + sql.NewRow(int64(2), int64(2), int64(2)), + sql.NewRow(int64(3), int64(3), int64(3)), + sql.NewRow(int64(4), int64(4), int64(4)), + sql.NewRow(int64(5), int64(5), int64(5)), + sql.NewRow(int64(6), int64(6), int64(6)), + sql.NewRow(int64(7), int64(7), int64(7))) } else { t.Logf("Warning: could not create table %s: %s", "one_pk_two_idx", err) } @@ -342,14 +342,14 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(0, 0, 0, 0), - sql.NewRow(1, 0, 0, 1), - sql.NewRow(2, 0, 1, 0), - sql.NewRow(3, 0, 2, 2), - sql.NewRow(4, 1, 0, 0), - sql.NewRow(5, 2, 0, 3), - sql.NewRow(6, 3, 3, 0), - sql.NewRow(7, 4, 4, 4)) + sql.NewRow(int64(0), int64(0), int64(0), int64(0)), + sql.NewRow(int64(1), int64(0), int64(0), int64(1)), + sql.NewRow(int64(2), int64(0), int64(1), int64(0)), + sql.NewRow(int64(3), int64(0), int64(2), int64(2)), + sql.NewRow(int64(4), int64(1), int64(0), int64(0)), + sql.NewRow(int64(5), int64(2), int64(0), int64(3)), + sql.NewRow(int64(6), int64(3), int64(3), int64(0)), + sql.NewRow(int64(7), int64(4), int64(4), int64(4))) } else { t.Logf("Warning: could not create table %s: %s", "one_pk_three_idx", err) } @@ -383,9 +383,9 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(int64(1), "first row"), - sql.NewRow(int64(2), "second row"), - sql.NewRow(int64(3), "third row")) + sql.NewRow(int32(1), "first row"), + sql.NewRow(int32(2), "second row"), + sql.NewRow(int32(3), "third row")) } else { t.Logf("Warning: could not create table %s: %s", "tabletest", err) } @@ -510,11 +510,11 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), sql.NewRow(int64(1), nil, nil, nil), - sql.NewRow(int64(2), int64(2), 1, nil), - sql.NewRow(int64(3), nil, 0, nil), + sql.NewRow(int64(2), int64(2), int8(1), nil), + sql.NewRow(int64(3), nil, int8(0), nil), sql.NewRow(int64(4), int64(4), nil, float64(4)), - sql.NewRow(int64(5), nil, 1, float64(5)), - sql.NewRow(int64(6), int64(6), 0, float64(6))) + sql.NewRow(int64(5), nil, int8(1), float64(5)), + sql.NewRow(int64(6), int64(6), int8(0), float64(6))) } else { t.Logf("Warning: could not create table %s: %s", "niltable", err) } @@ -585,7 +585,7 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string t1, t2, "fourteen", - 0, + int8(0), nil, nil, )) @@ -607,9 +607,9 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(1, mustParseDate("2019-12-31T12:00:00Z"), mustParseTime("2020-01-01T12:00:00Z"), mustParseTime("2020-01-02T12:00:00Z"), mustSQLTime(3*time.Hour+10*time.Minute)), - sql.NewRow(2, mustParseDate("2020-01-03T12:00:00Z"), mustParseTime("2020-01-04T12:00:00Z"), mustParseTime("2020-01-05T12:00:00Z"), mustSQLTime(4*time.Hour+44*time.Second)), - sql.NewRow(3, mustParseDate("2020-01-07T00:00:00Z"), mustParseTime("2020-01-07T12:00:00Z"), mustParseTime("2020-01-07T12:00:01Z"), mustSQLTime(15*time.Hour+5*time.Millisecond)), + sql.NewRow(int64(1), mustParseDate("2019-12-31T12:00:00Z"), mustParseTime("2020-01-01T12:00:00Z"), mustParseTime("2020-01-02T12:00:00Z"), mustSQLTime(3*time.Hour+10*time.Minute)), + sql.NewRow(int64(2), mustParseDate("2020-01-03T12:00:00Z"), mustParseTime("2020-01-04T12:00:00Z"), mustParseTime("2020-01-05T12:00:00Z"), mustSQLTime(4*time.Hour+44*time.Second)), + sql.NewRow(int64(3), mustParseDate("2020-01-07T00:00:00Z"), mustParseTime("2020-01-07T12:00:00Z"), mustParseTime("2020-01-07T12:00:01Z"), mustSQLTime(15*time.Hour+5*time.Millisecond)), ) } }) @@ -666,9 +666,9 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(1, 1, "first row"), - sql.NewRow(2, 2, "second row"), - sql.NewRow(3, 3, "third row"), + sql.NewRow(int64(1), int64(1), "first row"), + sql.NewRow(int64(2), int64(2), "second row"), + sql.NewRow(int64(3), int64(3), "third row"), ) } else { t.Logf("Warning: could not create table %s: %s", "fk_tbl", err) @@ -688,9 +688,9 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil && ok { ctx := NewContext(harness) InsertRows(t, ctx, mustInsertableTable(t, autoTbl), - sql.NewRow(1, 11), - sql.NewRow(2, 22), - sql.NewRow(3, 33), + sql.NewRow(int64(1), int64(11)), + sql.NewRow(int64(2), int64(22)), + sql.NewRow(int64(3), int64(33)), ) // InsertRows bypasses integrator auto increment methods // manually set the auto increment value here @@ -713,9 +713,9 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil && ok { InsertRows(t, NewContext(harness), mustInsertableTable(t, autoTbl), - sql.NewRow(0, 2, 2), - sql.NewRow(1, 1, 0), - sql.NewRow(2, 0, 1), + sql.NewRow(int64(0), int64(2), int64(2)), + sql.NewRow(int64(1), int64(1), int64(0)), + sql.NewRow(int64(2), int64(0), int64(1)), ) } else { t.Logf("Warning: could not create table %s: %s", "invert_pk", err) @@ -734,16 +734,16 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), []sql.Row{ - {"pie", "crust", 1}, - {"pie", "filling", 2}, - {"crust", "flour", 20}, - {"crust", "sugar", 2}, - {"crust", "butter", 15}, - {"crust", "salt", 15}, - {"filling", "sugar", 5}, - {"filling", "fruit", 9}, - {"filling", "salt", 3}, - {"filling", "butter", 3}, + {"pie", "crust", int64(1)}, + {"pie", "filling", int64(2)}, + {"crust", "flour", int64(20)}, + {"crust", "sugar", int64(2)}, + {"crust", "butter", int64(15)}, + {"crust", "salt", int64(15)}, + {"filling", "sugar", int64(5)}, + {"filling", "fruit", int64(9)}, + {"filling", "salt", int64(3)}, + {"filling", "butter", int64(3)}, }...) } else { t.Logf("Warning: could not create table %s: %s", "parts", err) diff --git a/memory/table.go b/memory/table.go index 44d0373aeb..9de1a13b12 100644 --- a/memory/table.go +++ b/memory/table.go @@ -19,6 +19,7 @@ import ( "encoding/gob" "fmt" "io" + "reflect" "sort" "strconv" "strings" @@ -1451,3 +1452,18 @@ func (t *Table) PartitionRows2(ctx *sql.Context, partition sql.Partition) (sql.R return iter.(*tableIter), nil } + +func (t *Table) verifyRowTypes(row sql.Row) { + //TODO: only run this when in testing mode + if len(row) == len(t.schema.Schema) { + for i := range t.schema.Schema { + col := t.schema.Schema[i] + rowVal := row[i] + valType := reflect.TypeOf(rowVal) + expectedType := col.Type.ValueType() + if valType != expectedType && rowVal != nil && !valType.AssignableTo(expectedType) { + panic(fmt.Errorf("Actual Value Type: %s, Expected Value Type: %s", valType.String(), expectedType.String())) + } + } + } +} diff --git a/memory/table_editor.go b/memory/table_editor.go index ed2d169202..227da2897b 100644 --- a/memory/table_editor.go +++ b/memory/table_editor.go @@ -70,6 +70,7 @@ func (t *tableEditor) Insert(ctx *sql.Context, row sql.Row) error { if err := checkRow(t.table.schema.Schema, row); err != nil { return err } + t.table.verifyRowTypes(row) partitionRow, added, err := t.ea.Get(row) if err != nil { @@ -119,6 +120,7 @@ func (t *tableEditor) Delete(ctx *sql.Context, row sql.Row) error { if err := checkRow(t.table.schema.Schema, row); err != nil { return err } + t.table.verifyRowTypes(row) err := t.ea.Delete(row) if err != nil { @@ -136,6 +138,8 @@ func (t *tableEditor) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) e if err := checkRow(t.table.schema.Schema, newRow); err != nil { return err } + t.table.verifyRowTypes(oldRow) + t.table.verifyRowTypes(newRow) err := t.ea.Delete(oldRow) if err != nil { diff --git a/sql/arraytype.go b/sql/arraytype.go index 6d85cf9e96..afb87ab423 100644 --- a/sql/arraytype.go +++ b/sql/arraytype.go @@ -18,11 +18,14 @@ import ( "encoding/json" "fmt" "io" + "reflect" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) +var arrayValueType = reflect.TypeOf((*[]interface{})(nil)).Elem() + type arrayType struct { underlying Type } @@ -153,7 +156,7 @@ func (t arrayType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { } } - val = appendAndSlice(dest, val) + val = appendAndSliceBytes(dest, val) return sqltypes.MakeTrusted(sqltypes.TypeJSON, val), nil } @@ -166,6 +169,10 @@ func (t arrayType) Type() query.Type { return sqltypes.TypeJSON } +func (t arrayType) ValueType() reflect.Type { + return arrayValueType +} + func (t arrayType) Zero() interface{} { return nil } diff --git a/sql/bit.go b/sql/bit.go index 19cfa155c4..ca5d771a5c 100644 --- a/sql/bit.go +++ b/sql/bit.go @@ -17,7 +17,7 @@ package sql import ( "encoding/binary" "fmt" - "strconv" + "reflect" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" @@ -35,10 +35,12 @@ const ( var ( promotedBitType = MustCreateBitType(BitTypeMaxBits) errBeyondMaxBit = errors.NewKind("%v is beyond the maximum value that can be held by %v bits") + bitValueType = reflect.TypeOf(uint64(0)) ) -// Represents the BIT type. +// BitType represents the BIT type. // https://dev.mysql.com/doc/refman/8.0/en/bit-type.html +// The type of the returned value is uint64. type BitType interface { Type NumberOfBits() uint8 @@ -133,6 +135,11 @@ func (t bitType) Convert(v interface{}) (interface{}, error) { return nil, fmt.Errorf(`negative floats cannot become bit values`) } value = uint64(val) + case decimal.NullDecimal: + if !val.Valid { + return nil, nil + } + return t.Convert(val.Decimal) case decimal.Decimal: val = val.Round(0) if val.GreaterThan(dec_uint64_max) { @@ -190,10 +197,16 @@ func (t bitType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { if err != nil { return sqltypes.Value{}, err } + bitVal := value.(uint64) - stop := len(dest) - dest = strconv.AppendUint(dest, value.(uint64), 10) - val := dest[stop:] + var data []byte + for i := uint64(0); i < uint64(t.numOfBits); i += 8 { + data = append(data, byte(bitVal>>i)) + } + for i, j := 0, len(data)-1; i < j; i, j = i+1, j-1 { + data[i], data[j] = data[j], data[i] + } + val := appendAndSliceBytes(dest, data) return sqltypes.MakeTrusted(sqltypes.Bit, val), nil } @@ -208,6 +221,11 @@ func (t bitType) Type() query.Type { return sqltypes.Bit } +// ValueType implements Type interface. +func (t bitType) ValueType() reflect.Type { + return bitValueType +} + // Zero implements Type interface. Returns a uint64 value. func (t bitType) Zero() interface{} { return uint64(0) diff --git a/sql/bit_test.go b/sql/bit_test.go index 48a390296d..8680e74fa6 100644 --- a/sql/bit_test.go +++ b/sql/bit_test.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "reflect" "testing" "time" @@ -128,6 +129,9 @@ func TestBitConvert(t *testing.T) { } else { require.NoError(t, err) assert.Equal(t, test.expectedVal, val) + if val != nil { + assert.Equal(t, test.typ.ValueType(), reflect.TypeOf(val)) + } } }) } diff --git a/sql/core.go b/sql/core.go index 5994de0b2e..67dab90608 100644 --- a/sql/core.go +++ b/sql/core.go @@ -1037,7 +1037,7 @@ type ExternalStoredProcedureDetails struct { Schema Schema // Function is the implementation of the external stored procedure. All functions should have the following definition: // `func(*Context, ) (RowIter, error)`. The may be any of the following types: `bool`, - // `string`, `[]byte`, `int8`-`int64`, `uint8`-`uint64`, `float32`, `float64`, `time.Time`, or `decimal` + // `string`, `[]byte`, `int8`-`int64`, `uint8`-`uint64`, `float32`, `float64`, `time.Time`, or `Decimal` // (shopspring/decimal). The architecture-dependent types `int` and `uint` (without a number) are also supported. // It is valid to return a nil RowIter if there are no rows to be returned. // @@ -1046,7 +1046,7 @@ type ExternalStoredProcedureDetails struct { // // Values are converted to their nearest type before being passed in, following the conversion rules of their // related SQL types. The exceptions are `time.Time` (treated as a `DATETIME`), string (treated as a `LONGTEXT` with - // the default collation) and decimal (treated with a larger precision and scale). Take extra care when using decimal + // the default collation) and Decimal (treated with a larger precision and scale). Take extra care when using decimal // for an INOUT parameter, to ensure that the returned value fits the original's precision and scale, else an error // will occur. // diff --git a/sql/datetimetype.go b/sql/datetimetype.go index 928b41f7ad..c0ba3483eb 100644 --- a/sql/datetimetype.go +++ b/sql/datetimetype.go @@ -16,8 +16,11 @@ package sql import ( "math" + "reflect" "time" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" @@ -87,10 +90,13 @@ var ( Datetime = MustCreateDatetimeType(sqltypes.Datetime) // Timestamp is an UNIX timestamp. Timestamp = MustCreateDatetimeType(sqltypes.Timestamp) + + datetimeValueType = reflect.TypeOf(time.Time{}) ) -// Represents DATE, DATETIME, and TIMESTAMP. +// DatetimeType represents DATE, DATETIME, and TIMESTAMP. // https://dev.mysql.com/doc/refman/8.0/en/datetime.html +// The type of the returned value is time.Time. type DatetimeType interface { Type ConvertWithoutRangeCheck(v interface{}) (time.Time, error) @@ -275,6 +281,16 @@ func (t datetimeType) ConvertWithoutRangeCheck(v interface{}) (time.Time, error) return zeroTime, nil } return zeroTime, ErrConvertingToTime.New(v) + case decimal.Decimal: + if value.IsZero() { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(v) + case decimal.NullDecimal: + if value.Valid && value.Decimal.IsZero() { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(v) default: return zeroTime, ErrConvertToSQL.New(t) } @@ -317,37 +333,37 @@ func (t datetimeType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { vt := v.(time.Time) var typ query.Type - var val []byte + var val string switch t.baseType { case sqltypes.Date: typ = sqltypes.Date if vt.Equal(zeroTime) { - val = []byte(vt.Format(zeroDateStr)) + val = vt.Format(zeroDateStr) } else { - val = []byte(vt.Format(DateLayout)) + val = vt.Format(DateLayout) } case sqltypes.Datetime: typ = sqltypes.Datetime if vt.Equal(zeroTime) { - val = []byte(vt.Format(zeroTimestampDatetimeStr)) + val = vt.Format(zeroTimestampDatetimeStr) } else { - val = []byte(vt.Format(TimestampDatetimeLayout)) + val = vt.Format(TimestampDatetimeLayout) } case sqltypes.Timestamp: typ = sqltypes.Timestamp if vt.Equal(zeroTime) { - val = []byte(vt.Format(zeroTimestampDatetimeStr)) + val = vt.Format(zeroTimestampDatetimeStr) } else { - val = []byte(vt.Format(TimestampDatetimeLayout)) + val = vt.Format(TimestampDatetimeLayout) } default: panic(ErrInvalidBaseType.New(t.baseType.String(), "datetime")) } - val = appendAndSlice(dest, val) + valBytes := appendAndSliceString(dest, val) - return sqltypes.MakeTrusted(typ, val), nil + return sqltypes.MakeTrusted(typ, valBytes), nil } func (t datetimeType) String() string { @@ -368,6 +384,11 @@ func (t datetimeType) Type() query.Type { return t.baseType } +// ValueType implements Type interface. +func (t datetimeType) ValueType() reflect.Type { + return datetimeValueType +} + func (t datetimeType) Zero() interface{} { return zeroTime } diff --git a/sql/datetimetype_test.go b/sql/datetimetype_test.go index 10716b52c3..266efbec1a 100644 --- a/sql/datetimetype_test.go +++ b/sql/datetimetype_test.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "reflect" "testing" "time" @@ -304,6 +305,9 @@ func TestDatetimeConvert(t *testing.T) { } else { require.NoError(t, err) assert.Equal(t, test.expectedVal, val) + if val != nil { + assert.Equal(t, test.typ.ValueType(), reflect.TypeOf(val)) + } } }) } diff --git a/sql/decimal.go b/sql/decimal.go index 6c0b4d16f0..146319c8e2 100644 --- a/sql/decimal.go +++ b/sql/decimal.go @@ -17,6 +17,7 @@ package sql import ( "fmt" "math/big" + "reflect" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" @@ -37,14 +38,29 @@ var ( ErrConvertingToDecimal = errors.NewKind("value %v is not a valid Decimal") ErrConvertToDecimalLimit = errors.NewKind("value of Decimal is too large for type") ErrMarshalNullDecimal = errors.NewKind("Decimal cannot marshal a null value") + + decimalValueType = reflect.TypeOf(decimal.Decimal{}) ) +// DecimalType represents the DECIMAL type. +// https://dev.mysql.com/doc/refman/8.0/en/fixed-point-types.html +// The type of the returned value is decimal.Decimal. type DecimalType interface { Type - ConvertToDecimal(v interface{}) (decimal.NullDecimal, error) + // ConvertToNullDecimal converts the given value to a decimal.NullDecimal if it has a compatible type. It is worth + // noting that Convert() returns a nil value for nil inputs, and also returns decimal.Decimal rather than + // decimal.NullDecimal. + ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, error) + // ExclusiveUpperBound returns the exclusive upper bound for this Decimal. + // For example, DECIMAL(5,2) would return 1000, as 999.99 is the max represented. ExclusiveUpperBound() decimal.Decimal + // MaximumScale returns the maximum scale allowed for the current precision. MaximumScale() uint8 + // Precision returns the base-10 precision of the type, which is the total number of digits. For example, a + // precision of 3 means that 999, 99.9, 9.99, and .999 are all valid maximums (depending on the scale). Precision() uint8 + // Scale returns the scale, or number of digits after the decimal, that may be held. + // This will always be less than or equal to the precision. Scale() uint8 } @@ -103,11 +119,11 @@ func (t decimalType) Compare(a interface{}, b interface{}) (int, error) { return res, nil } - af, err := t.ConvertToDecimal(a) + af, err := t.ConvertToNullDecimal(a) if err != nil { return 0, err } - bf, err := t.ConvertToDecimal(b) + bf, err := t.ConvertToNullDecimal(b) if err != nil { return 0, err } @@ -117,18 +133,18 @@ func (t decimalType) Compare(a interface{}, b interface{}) (int, error) { // Convert implements Type interface. func (t decimalType) Convert(v interface{}) (interface{}, error) { - dec, err := t.ConvertToDecimal(v) + dec, err := t.ConvertToNullDecimal(v) if err != nil { return nil, err } if !dec.Valid { return nil, nil } - return dec.Decimal.StringFixed(int32(t.scale)), nil + return dec.Decimal, nil } -// Precision returns the precision, or total number of digits, that may be held. -func (t decimalType) ConvertToDecimal(v interface{}) (decimal.NullDecimal, error) { +// ConvertToNullDecimal implements DecimalType interface. +func (t decimalType) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, error) { if v == nil { return decimal.NullDecimal{}, nil } @@ -137,21 +153,21 @@ func (t decimalType) ConvertToDecimal(v interface{}) (decimal.NullDecimal, error switch value := v.(type) { case int: - return t.ConvertToDecimal(int64(value)) + return t.ConvertToNullDecimal(int64(value)) case uint: - return t.ConvertToDecimal(uint64(value)) + return t.ConvertToNullDecimal(uint64(value)) case int8: - return t.ConvertToDecimal(int64(value)) + return t.ConvertToNullDecimal(int64(value)) case uint8: - return t.ConvertToDecimal(uint64(value)) + return t.ConvertToNullDecimal(uint64(value)) case int16: - return t.ConvertToDecimal(int64(value)) + return t.ConvertToNullDecimal(int64(value)) case uint16: - return t.ConvertToDecimal(uint64(value)) + return t.ConvertToNullDecimal(uint64(value)) case int32: res = decimal.NewFromInt32(value) case uint32: - return t.ConvertToDecimal(uint64(value)) + return t.ConvertToNullDecimal(uint64(value)) case int64: res = decimal.NewFromInt(value) case uint64: @@ -175,11 +191,11 @@ func (t decimalType) ConvertToDecimal(v interface{}) (decimal.NullDecimal, error } } case *big.Float: - return t.ConvertToDecimal(value.Text('f', -1)) + return t.ConvertToNullDecimal(value.Text('f', -1)) case *big.Int: - return t.ConvertToDecimal(value.Text(10)) + return t.ConvertToNullDecimal(value.Text(10)) case *big.Rat: - return t.ConvertToDecimal(new(big.Float).SetRat(value)) + return t.ConvertToNullDecimal(new(big.Float).SetRat(value)) case decimal.Decimal: res = value case decimal.NullDecimal: @@ -232,7 +248,7 @@ func (t decimalType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { return sqltypes.Value{}, err } - val := appendAndSlice(dest, []byte(value.(string))) + val := appendAndSliceString(dest, value.(decimal.Decimal).StringFixed(int32(t.scale))) return sqltypes.MakeTrusted(sqltypes.Decimal, val), nil } @@ -242,18 +258,22 @@ func (t decimalType) String() string { return fmt.Sprintf("DECIMAL(%v,%v)", t.precision, t.scale) } -// Zero implements Type interface. Returns a uint64 value. +// ValueType implements Type interface. +func (t decimalType) ValueType() reflect.Type { + return decimalValueType +} + +// Zero implements Type interface. func (t decimalType) Zero() interface{} { return decimal.NewFromInt(0).StringFixed(int32(t.scale)) } -// ExclusiveUpperBound returns the exclusive upper bound for this Decimal. -// For example, DECIMAL(5,2) would return 1000, as 999.99 is the max represented. +// ExclusiveUpperBound implements DecimalType interface. func (t decimalType) ExclusiveUpperBound() decimal.Decimal { return t.exclusiveUpperBound } -// MaximumScale returns the maximum scale allowed for the current precision. +// MaximumScale implements DecimalType interface. func (t decimalType) MaximumScale() uint8 { if t.precision >= DecimalTypeMaxScale { return DecimalTypeMaxScale @@ -261,14 +281,12 @@ func (t decimalType) MaximumScale() uint8 { return t.precision } -// Precision returns the base-10 precision of the type, which is the total number of digits. -// For example, a precision of 3 means that 999, 99.9, 9.99, and .999 are all valid maximums (depending on the scale). +// Precision implements DecimalType interface. func (t decimalType) Precision() uint8 { return t.precision } -// Scale returns the scale, or number of digits after the decimal, that may be held. -// This will always be less than or equal to the precision. +// Scale implements DecimalType interface. func (t decimalType) Scale() uint8 { return t.scale } diff --git a/sql/decimal_test.go b/sql/decimal_test.go index 3dcb0ff86f..fa4b6bff8d 100644 --- a/sql/decimal_test.go +++ b/sql/decimal_test.go @@ -17,6 +17,7 @@ package sql import ( "fmt" "math/big" + "reflect" "strings" "testing" "time" @@ -69,7 +70,7 @@ func TestDecimalAccuracy(t *testing.T) { for _, test := range tests { decimalType := MustCreateDecimalType(uint8(precision), uint8(test.scale)) - decimal := big.NewInt(0) + decimalInt := big.NewInt(0) bigIntervals := make([]*big.Int, len(test.intervals)) for i, interval := range test.intervals { bigInterval := new(big.Int) @@ -81,18 +82,18 @@ func TestDecimalAccuracy(t *testing.T) { upperBound := new(big.Int) _ = upperBound.UnmarshalText([]byte("1" + strings.Repeat("0", test.scale))) - for decimal.Cmp(upperBound) == -1 { - decimalStr := decimal.Text(10) + for decimalInt.Cmp(upperBound) == -1 { + decimalStr := decimalInt.Text(10) fullDecimalStr := strings.Repeat("0", test.scale-len(decimalStr)) + decimalStr fullStr := baseStr + fullDecimalStr t.Run(fmt.Sprintf("Scale:%v DecVal:%v", test.scale, fullDecimalStr), func(t *testing.T) { res, err := decimalType.Convert(fullStr) require.NoError(t, err) - require.Equal(t, fullStr, res) + require.Equal(t, fullStr, res.(decimal.Decimal).StringFixed(int32(decimalType.Scale()))) }) - decimal.Add(decimal, bigIntervals[intervalIndex]) + decimalInt.Add(decimalInt, bigIntervals[intervalIndex]) intervalIndex = (intervalIndex + 1) % len(bigIntervals) } } @@ -267,12 +268,20 @@ func TestDecimalConvert(t *testing.T) { for _, test := range tests { t.Run(fmt.Sprintf("%v %v %v", test.precision, test.scale, test.val), func(t *testing.T) { - val, err := MustCreateDecimalType(test.precision, test.scale).Convert(test.val) + typ := MustCreateDecimalType(test.precision, test.scale) + val, err := typ.Convert(test.val) if test.expectedErr { assert.Error(t, err) } else { require.NoError(t, err) - assert.Equal(t, test.expectedVal, val) + if test.expectedVal == nil { + assert.Nil(t, val) + } else { + expectedVal, err := decimal.NewFromString(test.expectedVal.(string)) + require.NoError(t, err) + assert.True(t, expectedVal.Equal(val.(decimal.Decimal))) + assert.Equal(t, typ.ValueType(), reflect.TypeOf(val)) + } } }) } diff --git a/sql/deferredtype.go b/sql/deferredtype.go index 36ae340398..b793d106a0 100644 --- a/sql/deferredtype.go +++ b/sql/deferredtype.go @@ -15,6 +15,8 @@ package sql import ( + "reflect" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) @@ -85,6 +87,11 @@ func (t deferredType) Type() query.Type { return sqltypes.Expression } +// ValueType implements Type interface. +func (t deferredType) ValueType() reflect.Type { + return nil +} + // Zero implements Type interface. func (t deferredType) Zero() interface{} { return nil diff --git a/sql/enumtype.go b/sql/enumtype.go index 53f5d8ce3d..c6b57ae316 100644 --- a/sql/enumtype.go +++ b/sql/enumtype.go @@ -16,9 +16,12 @@ package sql import ( "fmt" + "reflect" "strconv" "strings" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" @@ -35,23 +38,26 @@ const ( var ( ErrConvertingToEnum = errors.NewKind("value %v is not valid for this Enum") ErrUnmarshallingEnum = errors.NewKind("value %v is not a marshalled value for this Enum") + + enumValueType = reflect.TypeOf(uint16(0)) ) // Comments with three slashes were taken directly from the linked documentation. -// Represents the ENUM type. +// EnumType represents the ENUM type. // https://dev.mysql.com/doc/refman/8.0/en/enum.html +// The type of the returned value is uint16. type EnumType interface { Type + // At returns the string at the given index, as well if the string was found. At(index int) (string, bool) CharacterSet() CharacterSet Collation() Collation - ConvertToIndex(v interface{}) (int, error) + // IndexOf returns the index of the given string. If the string was not found, then this returns -1. IndexOf(v string) int - //TODO: move this out of go-mysql-server and into the Dolt layer - Marshal(v interface{}) (int64, error) + // NumberOfElements returns the number of enumerations. NumberOfElements() uint16 - Unmarshal(v int64) (string, error) + // Values returns the elements, in order, of every enumeration. Values() []string } @@ -104,18 +110,20 @@ func (t enumType) Compare(a interface{}, b interface{}) (int, error) { return res, nil } - ai, err := t.ConvertToIndex(a) + ai, err := t.Convert(a) if err != nil { return 0, err } - bi, err := t.ConvertToIndex(b) + bi, err := t.Convert(b) if err != nil { return 0, err } + au := ai.(uint16) + bu := bi.(uint16) - if ai < bi { + if au < bu { return -1, nil - } else if ai > bi { + } else if au > bu { return 1, nil } return 0, nil @@ -129,8 +137,8 @@ func (t enumType) Convert(v interface{}) (interface{}, error) { switch value := v.(type) { case int: - if str, ok := t.At(value); ok { - return str, nil + if _, ok := t.At(value); ok { + return uint16(value), nil } case uint: return t.Convert(int(value)) @@ -154,12 +162,17 @@ func (t enumType) Convert(v interface{}) (interface{}, error) { return t.Convert(int(value)) case float64: return t.Convert(int(value)) + case decimal.Decimal: + return t.Convert(value.IntPart()) + case decimal.NullDecimal: + if !value.Valid { + return nil, nil + } + return t.Convert(value.Decimal.IntPart()) case string: if index := t.IndexOf(value); index != -1 { - realStr, _ := t.At(index) - return realStr, nil + return uint16(index), nil } - return nil, ErrConvertingToEnum.New(`"` + value + `"`) case []byte: return t.Convert(string(value)) } @@ -189,47 +202,6 @@ func (t enumType) Equals(otherType Type) bool { return false } -// ConvertToIndex is similar to Convert, except that it converts to the index rather than the value. -// Returns an error on nil. -func (t enumType) ConvertToIndex(v interface{}) (int, error) { - switch value := v.(type) { - case int: - if _, ok := t.At(value); ok { - return value, nil - } - case uint: - return t.ConvertToIndex(int(value)) - case int8: - return t.ConvertToIndex(int(value)) - case uint8: - return t.ConvertToIndex(int(value)) - case int16: - return t.ConvertToIndex(int(value)) - case uint16: - return t.ConvertToIndex(int(value)) - case int32: - return t.ConvertToIndex(int(value)) - case uint32: - return t.ConvertToIndex(int(value)) - case int64: - return t.ConvertToIndex(int(value)) - case uint64: - return t.ConvertToIndex(int(value)) - case float32: - return t.ConvertToIndex(int(value)) - case float64: - return t.ConvertToIndex(int(value)) - case string: - if index := t.IndexOf(value); index != -1 { - return index, nil - } - case []byte: - return t.ConvertToIndex(string(value)) - } - - return -1, ErrConvertingToEnum.New(v) -} - // Promote implements the Type interface. func (t enumType) Promote() Type { return t @@ -240,12 +212,13 @@ func (t enumType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { if v == nil { return sqltypes.NULL, nil } - value, err := t.Convert(v) + convertedValue, err := t.Convert(v) if err != nil { return sqltypes.Value{}, err } + value, _ := t.At(int(convertedValue.(uint16))) - val := appendAndSlice(dest, []byte(value.(string))) + val := appendAndSliceString(dest, value) return sqltypes.MakeTrusted(sqltypes.Enum, val), nil } @@ -267,13 +240,18 @@ func (t enumType) Type() query.Type { return sqltypes.Enum } +// ValueType implements Type interface. +func (t enumType) ValueType() reflect.Type { + return enumValueType +} + // Zero implements Type interface. func (t enumType) Zero() interface{} { /// If an ENUM column is declared NOT NULL, its default value is the first element of the list of permitted values. return t.indexToVal[0] } -// At returns the string at the given index, as well if the string was found. +// At implements EnumType interface. func (t enumType) At(index int) (string, bool) { /// The elements listed in the column specification are assigned index numbers, beginning with 1. index -= 1 @@ -283,15 +261,17 @@ func (t enumType) At(index int) (string, bool) { return t.indexToVal[index], true } +// CharacterSet implements EnumType interface. func (t enumType) CharacterSet() CharacterSet { return t.collation.CharacterSet() } +// Collation implements EnumType interface. func (t enumType) Collation() Collation { return t.collation } -// IndexOf returns the index of the given string. If the string was not found, then this returns -1. +// IndexOf implements EnumType interface. func (t enumType) IndexOf(v string) int { if index, ok := t.valToIndex[v]; ok { return index @@ -308,27 +288,12 @@ func (t enumType) IndexOf(v string) int { return -1 } -// Marshal takes a valid Enum value and returns it as an int64. -func (t enumType) Marshal(v interface{}) (int64, error) { - i, err := t.ConvertToIndex(v) - return int64(i), err -} - -// NumberOfElements returns the number of enumerations. +// NumberOfElements implements EnumType interface. func (t enumType) NumberOfElements() uint16 { return uint16(len(t.indexToVal)) } -// Unmarshal takes a previously-marshalled value and returns it as a string. -func (t enumType) Unmarshal(v int64) (string, error) { - str, found := t.At(int(v)) - if !found { - return "", ErrUnmarshallingEnum.New(v) - } - return str, nil -} - -// Values returns the elements, in order, of every enumeration. +// Values implements EnumType interface. func (t enumType) Values() []string { vals := make([]string, len(t.indexToVal)) copy(vals, t.indexToVal) diff --git a/sql/enumtype_test.go b/sql/enumtype_test.go index e44880d53c..d5d640f3bf 100644 --- a/sql/enumtype_test.go +++ b/sql/enumtype_test.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "reflect" "strconv" "testing" "time" @@ -144,15 +145,15 @@ func TestEnumConvert(t *testing.T) { assert.Error(t, err) } else { require.NoError(t, err) - assert.Equal(t, test.expectedVal, val) if test.val != nil { - mar, err := typ.Marshal(test.val) - require.NoError(t, err) - umar, err := typ.Unmarshal(mar) - require.NoError(t, err) + umar, ok := typ.At(int(val.(uint16))) + require.True(t, ok) cmp, err := typ.Compare(test.val, umar) require.NoError(t, err) assert.Equal(t, 0, cmp) + assert.Equal(t, typ.ValueType(), reflect.TypeOf(val)) + } else { + assert.Equal(t, test.expectedVal, val) } } }) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 17b6cd0fe1..f00f097a35 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -62,11 +62,12 @@ func (c *comparison) Compare(ctx *sql.Context, row sql.Row) (int, error) { return c.Left().Type().Compare(left, right) } - // ENUM and SET must be considered when doing comparisons, as they can match arbitrary strings to numbers based on - // their elements. For other types it seems there are other considerations, therefore we only take the type for - // ENUM and SET, and default to direct literal comparisons for all other types. Eventually we will need to make our - // comparisons context-sensitive, as all comparisons should probably be based on the column/variable if present. - // Until then, this is a workaround specifically for ENUM and SET. + // ENUM, SET, and TIME must be excluded when doing comparisons, as they're too restrictive to use as a comparison + // base. + // + // The best overall method would be to assign type priority. For example, INT would have a higher priority than + // TINYINT. This could then be combined with the origin of the value (table column, procedure param, etc.) to + // determine the best type for any comparison (tie-breakers can be simple rules such as the current left preference). var compareType sql.Type switch c.Left().(type) { case *GetField, *UserVar, *SystemVar, *ProcedureParam: @@ -80,7 +81,8 @@ func (c *comparison) Compare(ctx *sql.Context, row sql.Row) (int, error) { if compareType != nil { _, isEnum := compareType.(sql.EnumType) _, isSet := compareType.(sql.SetType) - if !isEnum && !isSet { + _, isTime := compareType.(sql.TimeType) + if !isEnum && !isSet && !isTime { compareType = nil } } @@ -115,11 +117,12 @@ func (c *comparison) NullSafeCompare(ctx *sql.Context, row sql.Row) (int, error) return c.Left().Type().Compare(left, right) } - // ENUM and SET must be considered when doing comparisons, as they can match arbitrary strings to numbers based on - // their elements. For other types it seems there are other considerations, therefore we only take the type for - // ENUM and SET, and default to direct literal comparisons for all other types. Eventually we will need to make our - // comparisons context-sensitive, as all comparisons should probably be based on the column/variable if present. - // Until then, this is a workaround specifically for ENUM and SET. + // ENUM, SET, and TIME must be excluded when doing comparisons, as they're too restrictive to use as a comparison + // base. + // + // The best overall method would be to assign type priority. For example, INT would have a higher priority than + // TINYINT. This could then be combined with the origin of the value (table column, procedure param, etc.) to + // determine the best type for any comparison (tie-breakers can be simple rules such as the current left preference). var compareType sql.Type switch c.Left().(type) { case *GetField, *UserVar, *SystemVar, *ProcedureParam: @@ -133,7 +136,8 @@ func (c *comparison) NullSafeCompare(ctx *sql.Context, row sql.Row) (int, error) if compareType != nil { _, isEnum := compareType.(sql.EnumType) _, isSet := compareType.(sql.SetType) - if !isEnum && !isSet { + _, isTime := compareType.(sql.TimeType) + if !isEnum && !isSet && !isTime { compareType = nil } } diff --git a/sql/expression/function/aggregation/window_partition_test.go b/sql/expression/function/aggregation/window_partition_test.go index 5adda99411..ecc4fcabea 100644 --- a/sql/expression/function/aggregation/window_partition_test.go +++ b/sql/expression/function/aggregation/window_partition_test.go @@ -96,15 +96,15 @@ func mustNewRowIter(t *testing.T, ctx *sql.Context) sql.RowIter { table := memory.NewTable("test", childSchema, nil) rows := []sql.Row{ - {int64(1), "forest", "leaf", 4}, - {int64(2), "forest", "bark", 4}, - {int64(3), "forest", "canopy", 6}, - {int64(4), "forest", "bug", 3}, - {int64(5), "forest", "wildflower", 10}, - {int64(6), "desert", "sand", 4}, - {int64(7), "desert", "cactus", 6}, - {int64(8), "desert", "scorpion", 8}, - {int64(9), "desert", "mummy", 5}, + {int64(1), "forest", "leaf", int32(4)}, + {int64(2), "forest", "bark", int32(4)}, + {int64(3), "forest", "canopy", int32(6)}, + {int64(4), "forest", "bug", int32(3)}, + {int64(5), "forest", "wildflower", int32(10)}, + {int64(6), "desert", "sand", int32(4)}, + {int64(7), "desert", "cactus", int32(6)}, + {int64(8), "desert", "scorpion", int32(8)}, + {int64(9), "desert", "mummy", int32(5)}, } for _, r := range rows { @@ -135,15 +135,15 @@ func TestWindowPartition_MaterializeInput(t *testing.T) { buf, ordering, err := i.materializeInput(ctx) require.NoError(t, err) expBuf := []sql.Row{ - {int64(1), "forest", "leaf", 4}, - {int64(2), "forest", "bark", 4}, - {int64(3), "forest", "canopy", 6}, - {int64(4), "forest", "bug", 3}, - {int64(5), "forest", "wildflower", 10}, - {int64(6), "desert", "sand", 4}, - {int64(7), "desert", "cactus", 6}, - {int64(8), "desert", "scorpion", 8}, - {int64(9), "desert", "mummy", 5}, + {int64(1), "forest", "leaf", int32(4)}, + {int64(2), "forest", "bark", int32(4)}, + {int64(3), "forest", "canopy", int32(6)}, + {int64(4), "forest", "bug", int32(3)}, + {int64(5), "forest", "wildflower", int32(10)}, + {int64(6), "desert", "sand", int32(4)}, + {int64(7), "desert", "cactus", int32(6)}, + {int64(8), "desert", "scorpion", int32(8)}, + {int64(9), "desert", "mummy", int32(5)}, } require.ElementsMatch(t, expBuf, buf) expOrd := []int{0, 1, 2, 3, 4, 5, 6, 7, 8} @@ -157,15 +157,15 @@ func TestWindowPartition_InitializePartitions(t *testing.T) { PartitionBy: partitionByX, }) i.input = []sql.Row{ - {int64(1), "forest", "leaf", 4}, - {int64(2), "forest", "bark", 4}, - {int64(3), "forest", "canopy", 6}, - {int64(4), "forest", "bug", 3}, - {int64(5), "forest", "wildflower", 10}, - {int64(6), "desert", "sand", 4}, - {int64(7), "desert", "cactus", 6}, - {int64(8), "desert", "scorpion", 8}, - {int64(9), "desert", "mummy", 5}, + {int64(1), "forest", "leaf", int32(4)}, + {int64(2), "forest", "bark", int32(4)}, + {int64(3), "forest", "canopy", int32(6)}, + {int64(4), "forest", "bug", int32(3)}, + {int64(5), "forest", "wildflower", int32(10)}, + {int64(6), "desert", "sand", int32(4)}, + {int64(7), "desert", "cactus", int32(6)}, + {int64(8), "desert", "scorpion", int32(8)}, + {int64(9), "desert", "mummy", int32(5)}, } partitions, err := i.initializePartitions(ctx) require.NoError(t, err) diff --git a/sql/expression/function/timediff.go b/sql/expression/function/timediff.go index ac4eb6dd0c..cd7cca3989 100644 --- a/sql/expression/function/timediff.go +++ b/sql/expression/function/timediff.go @@ -98,13 +98,12 @@ func (td *TimeDiff) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { rightDatetime = rightDatetime.In(leftDatetime.Location()) } return sql.Time.Convert(leftDatetime.Sub(rightDatetime)) - } else if leftTime, err := sql.Time.ConvertToTimeDuration(left); err == nil { - rightTime, err := sql.Time.ConvertToTimeDuration(right) + } else if leftTime, err := sql.Time.ConvertToTimespan(left); err == nil { + rightTime, err := sql.Time.ConvertToTimespan(right) if err != nil { return nil, err } - resTime := leftTime - rightTime - return sql.Time.Convert(resTime) + return leftTime.Subtract(rightTime), nil } else { return nil, ErrInvalidArgumentType.New("timediff") } diff --git a/sql/expression/function/timediff_test.go b/sql/expression/function/timediff_test.go index de8830b3db..00d8763374 100644 --- a/sql/expression/function/timediff_test.go +++ b/sql/expression/function/timediff_test.go @@ -26,19 +26,27 @@ import ( ) func TestTimeDiff(t *testing.T) { + toTimespan := func(str string) sql.Timespan { + res, err := sql.Time.ConvertToTimespan(str) + if err != nil { + t.Fatal(err) + } + return res + } + ctx := sql.NewEmptyContext() testCases := []struct { name string from sql.Expression to sql.Expression - expected string + expected sql.Timespan err bool }{ { "invalid type text", expression.NewLiteral("hello there", sql.Text), expression.NewConvert(expression.NewLiteral("01:00:00", sql.Text), expression.ConvertToTime), - "", + toTimespan(""), true, }, //TODO: handle Date properly @@ -53,70 +61,70 @@ func TestTimeDiff(t *testing.T) { "type mismatch 1", expression.NewLiteral(time.Date(2008, time.December, 29, 1, 1, 1, 2, time.Local), sql.Timestamp), expression.NewConvert(expression.NewLiteral("01:00:00", sql.Text), expression.ConvertToTime), - "", + toTimespan(""), true, }, { "type mismatch 2", expression.NewLiteral("00:00:00.2", sql.Text), expression.NewLiteral("2020-10-10 10:10:10", sql.Text), - "", + toTimespan(""), true, }, { "valid mismatch", expression.NewLiteral(time.Date(2008, time.December, 29, 1, 1, 1, 2, time.Local), sql.Timestamp), expression.NewLiteral(time.Date(2008, time.December, 30, 1, 1, 1, 2, time.Local), sql.Datetime), - "-24:00:00", + toTimespan("-24:00:00"), false, }, { "timestamp types 1", expression.NewLiteral(time.Date(2018, time.May, 2, 0, 0, 0, 0, time.Local), sql.Timestamp), expression.NewLiteral(time.Date(2018, time.May, 2, 0, 0, 1, 0, time.Local), sql.Timestamp), - "-00:00:01", + toTimespan("-00:00:01"), false, }, { "timestamp types 2", expression.NewLiteral(time.Date(2008, time.December, 31, 23, 59, 59, 1, time.Local), sql.Timestamp), expression.NewLiteral(time.Date(2008, time.December, 30, 1, 1, 1, 2, time.Local), sql.Timestamp), - "46:58:57.999999", + toTimespan("46:58:57.999999"), false, }, { "time types 1", expression.NewConvert(expression.NewLiteral("00:00:00.1", sql.Text), expression.ConvertToTime), expression.NewConvert(expression.NewLiteral("00:00:00.2", sql.Text), expression.ConvertToTime), - "-00:00:00.100000", + toTimespan("-00:00:00.100000"), false, }, { "time types 2", expression.NewLiteral("00:00:00.2", sql.Text), expression.NewLiteral("00:00:00.4", sql.Text), - "-00:00:00.200000", + toTimespan("-00:00:00.200000"), false, }, { "datetime types", expression.NewLiteral(time.Date(2008, time.December, 29, 0, 0, 0, 0, time.Local), sql.Datetime), expression.NewLiteral(time.Date(2008, time.December, 30, 0, 0, 0, 0, time.Local), sql.Datetime), - "-24:00:00", + toTimespan("-24:00:00"), false, }, { "datetime string types", expression.NewLiteral("2008-12-29 00:00:00", sql.Text), expression.NewLiteral("2008-12-30 00:00:00", sql.Text), - "-24:00:00", + toTimespan("-24:00:00"), false, }, { "datetime string mix types", expression.NewLiteral(time.Date(2008, time.December, 29, 0, 0, 0, 0, time.UTC), sql.Datetime), expression.NewLiteral("2008-12-30 00:00:00", sql.Text), - "-24:00:00", + toTimespan("-24:00:00"), false, }, } diff --git a/sql/geometry.go b/sql/geometry.go index 68c6316721..b05b818f86 100644 --- a/sql/geometry.go +++ b/sql/geometry.go @@ -17,21 +17,35 @@ package sql import ( "encoding/binary" "math" + "reflect" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" ) +// GeometryType represents the GEOMETRY type. +// https://dev.mysql.com/doc/refman/8.0/en/gis-class-geometry.html +// The type of the returned value is one of the following (each implements GeometryValue): Point, Polygon, LineString. type GeometryType struct { SRID uint32 DefinedSRID bool } +// GeometryValue is the value type returned from GeometryType, which is an interface over the following types: +// Point, Polygon, LineString. +type GeometryValue interface { + implementsGeometryValue() +} + var _ Type = GeometryType{} var _ SpatialColumnType = GeometryType{} -var ErrNotGeometry = errors.NewKind("Value of type %T is not a geometry") +var ( + ErrNotGeometry = errors.NewKind("Value of type %T is not a geometry") + + geometryValueType = reflect.TypeOf((*GeometryValue)(nil)).Elem() +) const ( CartesianSRID = uint32(0) @@ -262,7 +276,8 @@ func (t GeometryType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { return sqltypes.Value{}, nil } - val := appendAndSlice(dest, []byte(pv.(string))) + //TODO: pretty sure this is wrong, pv is not a string type + val := appendAndSliceString(dest, pv.(string)) return sqltypes.MakeTrusted(sqltypes.Geometry, val), nil } @@ -277,6 +292,11 @@ func (t GeometryType) Type() query.Type { return sqltypes.Geometry } +// ValueType implements Type interface. +func (t GeometryType) ValueType() reflect.Type { + return geometryValueType +} + // Zero implements Type interface. func (t GeometryType) Zero() interface{} { // TODO: it doesn't make sense for geometry to have a zero type diff --git a/sql/json.go b/sql/json.go index d663aeaa6a..f29742de1b 100644 --- a/sql/json.go +++ b/sql/json.go @@ -16,16 +16,24 @@ package sql import ( "encoding/json" + "reflect" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" ) -var ErrConvertingToJSON = errors.NewKind("value %v is not valid JSON") +var ( + ErrConvertingToJSON = errors.NewKind("value %v is not valid JSON") + + jsonValueType = reflect.TypeOf((*JSONValue)(nil)).Elem() +) var JSON JsonType = jsonType{} +// JsonType represents the JSON type. +// https://dev.mysql.com/doc/refman/8.0/en/json.html +// The type of the returned value is JSONValue. type JsonType interface { Type } @@ -97,7 +105,7 @@ func (t jsonType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { return sqltypes.NULL, err } - val := appendAndSlice(dest, []byte(s)) + val := appendAndSliceString(dest, s) return sqltypes.MakeTrusted(sqltypes.TypeJSON, val), nil } @@ -112,6 +120,11 @@ func (t jsonType) Type() query.Type { return sqltypes.TypeJSON } +// ValueType implements Type interface. +func (t jsonType) ValueType() reflect.Type { + return jsonValueType +} + // Zero implements Type interface. func (t jsonType) Zero() interface{} { // JSON Null diff --git a/sql/json_test.go b/sql/json_test.go index 6399107655..25deca792a 100644 --- a/sql/json_test.go +++ b/sql/json_test.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "reflect" "testing" querypb "github.com/dolthub/vitess/go/vt/proto/query" @@ -122,6 +123,9 @@ func TestJsonConvert(t *testing.T) { } else { require.NoError(t, err) assert.Equal(t, test.expectedVal, val) + if val != nil { + assert.True(t, reflect.TypeOf(val).Implements(JSON.ValueType())) + } } }) } diff --git a/sql/linestring.go b/sql/linestring.go index faef48dc8b..9c70597ab9 100644 --- a/sql/linestring.go +++ b/sql/linestring.go @@ -15,28 +15,37 @@ package sql import ( + "reflect" + "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) -// Represents the LineString type. +// LineStringType represents the LINESTRING type. // https://dev.mysql.com/doc/refman/8.0/en/gis-class-linestring.html -type LineString struct { - SRID uint32 - Points []Point -} - +// The type of the returned value is LineString. type LineStringType struct { SRID uint32 DefinedSRID bool } +// LineString is the value type returned from LineStringType. Implements GeometryValue. +type LineString struct { + SRID uint32 + Points []Point +} + var _ Type = LineStringType{} var _ SpatialColumnType = LineStringType{} +var _ GeometryValue = LineString{} + +var ( + ErrNotLineString = errors.NewKind("value of type %T is not a linestring") -var ErrNotLineString = errors.NewKind("value of type %T is not a linestring") + lineStringValueType = reflect.TypeOf(LineString{}) +) // Compare implements Type interface. func (t LineStringType) Compare(a interface{}, b interface{}) (int, error) { @@ -146,7 +155,8 @@ func (t LineStringType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) return sqltypes.Value{}, nil } - val := appendAndSlice(dest, []byte(pv.(string))) + //TODO: pretty sure this is wrong, pv is not a string type + val := appendAndSliceString(dest, pv.(string)) return sqltypes.MakeTrusted(sqltypes.Geometry, val), nil } @@ -161,6 +171,11 @@ func (t LineStringType) Type() query.Type { return sqltypes.Geometry } +// ValueType implements Type interface. +func (t LineStringType) ValueType() reflect.Type { + return lineStringValueType +} + // Zero implements Type interface. func (t LineStringType) Zero() interface{} { return LineString{Points: []Point{{}, {}}} @@ -191,3 +206,6 @@ func (t LineStringType) MatchSRID(v interface{}) error { } return ErrNotMatchingSRID.New(val.SRID, t.SRID) } + +// implementsGeometryValue implements GeometryValue interface. +func (p LineString) implementsGeometryValue() {} diff --git a/sql/mysql_db/db_table.go b/sql/mysql_db/db_table.go index 764f5b5c3d..13dfa669d6 100644 --- a/sql/mysql_db/db_table.go +++ b/sql/mysql_db/db_table.go @@ -74,7 +74,7 @@ func (conv DbConverter) AddRowToEntry(ctx *sql.Context, row sql.Row, entry in_me } var privs []sql.PrivilegeType for i, val := range row { - if strVal, ok := val.(string); ok && strVal == "Y" { + if uintVal, ok := val.(uint16); ok && uintVal == 2 { switch i { case dbTblColIndex_Select_priv: privs = append(privs, sql.PrivilegeType_Select) @@ -167,43 +167,43 @@ func (conv DbConverter) EntryToRows(ctx *sql.Context, entry in_mem_table.Entry) for _, priv := range dbSet.ToSlice() { switch priv { case sql.PrivilegeType_Select: - row[dbTblColIndex_Select_priv] = "Y" + row[dbTblColIndex_Select_priv] = uint16(2) case sql.PrivilegeType_Insert: - row[dbTblColIndex_Insert_priv] = "Y" + row[dbTblColIndex_Insert_priv] = uint16(2) case sql.PrivilegeType_Update: - row[dbTblColIndex_Update_priv] = "Y" + row[dbTblColIndex_Update_priv] = uint16(2) case sql.PrivilegeType_Delete: - row[dbTblColIndex_Delete_priv] = "Y" + row[dbTblColIndex_Delete_priv] = uint16(2) case sql.PrivilegeType_Create: - row[dbTblColIndex_Create_priv] = "Y" + row[dbTblColIndex_Create_priv] = uint16(2) case sql.PrivilegeType_Drop: - row[dbTblColIndex_Drop_priv] = "Y" + row[dbTblColIndex_Drop_priv] = uint16(2) case sql.PrivilegeType_Grant: - row[dbTblColIndex_Grant_priv] = "Y" + row[dbTblColIndex_Grant_priv] = uint16(2) case sql.PrivilegeType_References: - row[dbTblColIndex_References_priv] = "Y" + row[dbTblColIndex_References_priv] = uint16(2) case sql.PrivilegeType_Index: - row[dbTblColIndex_Index_priv] = "Y" + row[dbTblColIndex_Index_priv] = uint16(2) case sql.PrivilegeType_Alter: - row[dbTblColIndex_Alter_priv] = "Y" + row[dbTblColIndex_Alter_priv] = uint16(2) case sql.PrivilegeType_CreateTempTable: - row[dbTblColIndex_Create_tmp_table_priv] = "Y" + row[dbTblColIndex_Create_tmp_table_priv] = uint16(2) case sql.PrivilegeType_LockTables: - row[dbTblColIndex_Lock_tables_priv] = "Y" + row[dbTblColIndex_Lock_tables_priv] = uint16(2) case sql.PrivilegeType_CreateView: - row[dbTblColIndex_Create_view_priv] = "Y" + row[dbTblColIndex_Create_view_priv] = uint16(2) case sql.PrivilegeType_ShowView: - row[dbTblColIndex_Show_view_priv] = "Y" + row[dbTblColIndex_Show_view_priv] = uint16(2) case sql.PrivilegeType_CreateRoutine: - row[dbTblColIndex_Create_routine_priv] = "Y" + row[dbTblColIndex_Create_routine_priv] = uint16(2) case sql.PrivilegeType_AlterRoutine: - row[dbTblColIndex_Alter_routine_priv] = "Y" + row[dbTblColIndex_Alter_routine_priv] = uint16(2) case sql.PrivilegeType_Execute: - row[dbTblColIndex_Execute_priv] = "Y" + row[dbTblColIndex_Execute_priv] = uint16(2) case sql.PrivilegeType_Event: - row[dbTblColIndex_Event_priv] = "Y" + row[dbTblColIndex_Event_priv] = uint16(2) case sql.PrivilegeType_Trigger: - row[dbTblColIndex_Trigger_priv] = "Y" + row[dbTblColIndex_Trigger_priv] = uint16(2) } } rows = append(rows, row) diff --git a/sql/mysql_db/role_edge.go b/sql/mysql_db/role_edge.go index 363c4c6584..02361a35d0 100644 --- a/sql/mysql_db/role_edge.go +++ b/sql/mysql_db/role_edge.go @@ -44,7 +44,7 @@ func (r *RoleEdge) NewFromRow(ctx *sql.Context, row sql.Row) (in_mem_table.Entry FromUser: row[roleEdgesTblColIndex_FROM_USER].(string), ToHost: row[roleEdgesTblColIndex_TO_HOST].(string), ToUser: row[roleEdgesTblColIndex_TO_USER].(string), - WithAdminOption: row[roleEdgesTblColIndex_WITH_ADMIN_OPTION].(string) == "Y", + WithAdminOption: row[roleEdgesTblColIndex_WITH_ADMIN_OPTION].(uint16) == 2, }, nil } @@ -61,9 +61,9 @@ func (r *RoleEdge) ToRow(ctx *sql.Context) sql.Row { row[roleEdgesTblColIndex_TO_HOST] = r.ToHost row[roleEdgesTblColIndex_TO_USER] = r.ToUser if r.WithAdminOption { - row[roleEdgesTblColIndex_WITH_ADMIN_OPTION] = "Y" + row[roleEdgesTblColIndex_WITH_ADMIN_OPTION] = uint16(2) } else { - row[roleEdgesTblColIndex_WITH_ADMIN_OPTION] = "N" + row[roleEdgesTblColIndex_WITH_ADMIN_OPTION] = uint16(1) } return row } diff --git a/sql/mysql_db/tables_priv.go b/sql/mysql_db/tables_priv.go index 235b7ae6f8..783f25a1d9 100644 --- a/sql/mysql_db/tables_priv.go +++ b/sql/mysql_db/tables_priv.go @@ -78,12 +78,16 @@ func (conv TablesPrivConverter) AddRowToEntry(ctx *sql.Context, row sql.Row, ent if !ok { return nil, errTablesPrivRow } - tablePrivs, ok := row[tablesPrivTblColIndex_Table_priv].(string) + tablePrivs, ok := row[tablesPrivTblColIndex_Table_priv].(uint64) if !ok { return nil, errTablesPrivRow } + tablePrivStrs, err := tablesPrivTblSchema[tablesPrivTblColIndex_Table_priv].Type.(sql.SetType).BitsToString(tablePrivs) + if err != nil { + return nil, err + } var privs []sql.PrivilegeType - for _, val := range strings.Split(tablePrivs, ",") { + for _, val := range strings.Split(tablePrivStrs, ",") { switch val { case "Select": privs = append(privs, sql.PrivilegeType_Select) @@ -204,7 +208,7 @@ func (conv TablesPrivConverter) EntryToRows(ctx *sql.Context, entry in_mem_table if err != nil { return nil, err } - row[tablesPrivTblColIndex_Table_priv] = formattedSet.(string) + row[tablesPrivTblColIndex_Table_priv] = formattedSet.(uint64) rows = append(rows, row) } } diff --git a/sql/mysql_db/user.go b/sql/mysql_db/user.go index 334f18c3cd..fc5f121a9a 100644 --- a/sql/mysql_db/user.go +++ b/sql/mysql_db/user.go @@ -64,7 +64,7 @@ func (u *User) NewFromRow(ctx *sql.Context, row sql.Row) (in_mem_table.Entry, er Plugin: row[userTblColIndex_plugin].(string), Password: row[userTblColIndex_authentication_string].(string), PasswordLastChanged: passwordLastChanged, - Locked: row[userTblColIndex_account_locked].(string) == "Y", + Locked: row[userTblColIndex_account_locked].(uint16) == 2, Attributes: attributes, IsRole: false, }, nil @@ -97,7 +97,7 @@ func (u *User) ToRow(ctx *sql.Context) sql.Row { row[userTblColIndex_authentication_string] = u.Password row[userTblColIndex_password_last_changed] = u.PasswordLastChanged if u.Locked { - row[userTblColIndex_account_locked] = "Y" + row[userTblColIndex_account_locked] = uint16(2) } if u.Attributes != nil { row[userTblColIndex_User_attributes] = *u.Attributes @@ -172,127 +172,127 @@ func (u *User) rowToPrivSet(ctx *sql.Context, row sql.Row) PrivilegeSet { for i, val := range row { switch i { case userTblColIndex_Select_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Select) } case userTblColIndex_Insert_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Insert) } case userTblColIndex_Update_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Update) } case userTblColIndex_Delete_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Delete) } case userTblColIndex_Create_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Create) } case userTblColIndex_Drop_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Drop) } case userTblColIndex_Reload_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Reload) } case userTblColIndex_Shutdown_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Shutdown) } case userTblColIndex_Process_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Process) } case userTblColIndex_File_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_File) } case userTblColIndex_Grant_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Grant) } case userTblColIndex_References_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_References) } case userTblColIndex_Index_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Index) } case userTblColIndex_Alter_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Alter) } case userTblColIndex_Show_db_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_ShowDB) } case userTblColIndex_Super_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Super) } case userTblColIndex_Create_tmp_table_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_CreateTempTable) } case userTblColIndex_Lock_tables_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_LockTables) } case userTblColIndex_Execute_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Execute) } case userTblColIndex_Repl_slave_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_ReplicationSlave) } case userTblColIndex_Repl_client_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_ReplicationClient) } case userTblColIndex_Create_view_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_CreateView) } case userTblColIndex_Show_view_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_ShowView) } case userTblColIndex_Create_routine_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_CreateRoutine) } case userTblColIndex_Alter_routine_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_AlterRoutine) } case userTblColIndex_Create_user_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_CreateUser) } case userTblColIndex_Event_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Event) } case userTblColIndex_Trigger_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Trigger) } case userTblColIndex_Create_tablespace_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_CreateTablespace) } case userTblColIndex_Create_role_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_CreateRole) } case userTblColIndex_Drop_role_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_DropRole) } } @@ -306,67 +306,67 @@ func (u *User) privSetToRow(ctx *sql.Context, row sql.Row) { for _, priv := range u.PrivilegeSet.ToSlice() { switch priv { case sql.PrivilegeType_Select: - row[userTblColIndex_Select_priv] = "Y" + row[userTblColIndex_Select_priv] = uint16(2) case sql.PrivilegeType_Insert: - row[userTblColIndex_Insert_priv] = "Y" + row[userTblColIndex_Insert_priv] = uint16(2) case sql.PrivilegeType_Update: - row[userTblColIndex_Update_priv] = "Y" + row[userTblColIndex_Update_priv] = uint16(2) case sql.PrivilegeType_Delete: - row[userTblColIndex_Delete_priv] = "Y" + row[userTblColIndex_Delete_priv] = uint16(2) case sql.PrivilegeType_Create: - row[userTblColIndex_Create_priv] = "Y" + row[userTblColIndex_Create_priv] = uint16(2) case sql.PrivilegeType_Drop: - row[userTblColIndex_Drop_priv] = "Y" + row[userTblColIndex_Drop_priv] = uint16(2) case sql.PrivilegeType_Reload: - row[userTblColIndex_Reload_priv] = "Y" + row[userTblColIndex_Reload_priv] = uint16(2) case sql.PrivilegeType_Shutdown: - row[userTblColIndex_Shutdown_priv] = "Y" + row[userTblColIndex_Shutdown_priv] = uint16(2) case sql.PrivilegeType_Process: - row[userTblColIndex_Process_priv] = "Y" + row[userTblColIndex_Process_priv] = uint16(2) case sql.PrivilegeType_File: - row[userTblColIndex_File_priv] = "Y" + row[userTblColIndex_File_priv] = uint16(2) case sql.PrivilegeType_Grant: - row[userTblColIndex_Grant_priv] = "Y" + row[userTblColIndex_Grant_priv] = uint16(2) case sql.PrivilegeType_References: - row[userTblColIndex_References_priv] = "Y" + row[userTblColIndex_References_priv] = uint16(2) case sql.PrivilegeType_Index: - row[userTblColIndex_Index_priv] = "Y" + row[userTblColIndex_Index_priv] = uint16(2) case sql.PrivilegeType_Alter: - row[userTblColIndex_Alter_priv] = "Y" + row[userTblColIndex_Alter_priv] = uint16(2) case sql.PrivilegeType_ShowDB: - row[userTblColIndex_Show_db_priv] = "Y" + row[userTblColIndex_Show_db_priv] = uint16(2) case sql.PrivilegeType_Super: - row[userTblColIndex_Super_priv] = "Y" + row[userTblColIndex_Super_priv] = uint16(2) case sql.PrivilegeType_CreateTempTable: - row[userTblColIndex_Create_tmp_table_priv] = "Y" + row[userTblColIndex_Create_tmp_table_priv] = uint16(2) case sql.PrivilegeType_LockTables: - row[userTblColIndex_Lock_tables_priv] = "Y" + row[userTblColIndex_Lock_tables_priv] = uint16(2) case sql.PrivilegeType_Execute: - row[userTblColIndex_Execute_priv] = "Y" + row[userTblColIndex_Execute_priv] = uint16(2) case sql.PrivilegeType_ReplicationSlave: - row[userTblColIndex_Repl_slave_priv] = "Y" + row[userTblColIndex_Repl_slave_priv] = uint16(2) case sql.PrivilegeType_ReplicationClient: - row[userTblColIndex_Repl_client_priv] = "Y" + row[userTblColIndex_Repl_client_priv] = uint16(2) case sql.PrivilegeType_CreateView: - row[userTblColIndex_Create_view_priv] = "Y" + row[userTblColIndex_Create_view_priv] = uint16(2) case sql.PrivilegeType_ShowView: - row[userTblColIndex_Show_view_priv] = "Y" + row[userTblColIndex_Show_view_priv] = uint16(2) case sql.PrivilegeType_CreateRoutine: - row[userTblColIndex_Create_routine_priv] = "Y" + row[userTblColIndex_Create_routine_priv] = uint16(2) case sql.PrivilegeType_AlterRoutine: - row[userTblColIndex_Alter_routine_priv] = "Y" + row[userTblColIndex_Alter_routine_priv] = uint16(2) case sql.PrivilegeType_CreateUser: - row[userTblColIndex_Create_user_priv] = "Y" + row[userTblColIndex_Create_user_priv] = uint16(2) case sql.PrivilegeType_Event: - row[userTblColIndex_Event_priv] = "Y" + row[userTblColIndex_Event_priv] = uint16(2) case sql.PrivilegeType_Trigger: - row[userTblColIndex_Trigger_priv] = "Y" + row[userTblColIndex_Trigger_priv] = uint16(2) case sql.PrivilegeType_CreateTablespace: - row[userTblColIndex_Create_tablespace_priv] = "Y" + row[userTblColIndex_Create_tablespace_priv] = uint16(2) case sql.PrivilegeType_CreateRole: - row[userTblColIndex_Create_role_priv] = "Y" + row[userTblColIndex_Create_role_priv] = uint16(2) case sql.PrivilegeType_DropRole: - row[userTblColIndex_Drop_role_priv] = "Y" + row[userTblColIndex_Drop_role_priv] = uint16(2) } } } diff --git a/sql/nulltype.go b/sql/nulltype.go index 0af99c9e38..a37226eb42 100644 --- a/sql/nulltype.go +++ b/sql/nulltype.go @@ -15,6 +15,8 @@ package sql import ( + "reflect" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" @@ -83,6 +85,11 @@ func (t nullType) Type() query.Type { return sqltypes.Null } +// ValueType implements Type interface. +func (t nullType) ValueType() reflect.Type { + return nil +} + // Zero implements Type interface. func (t nullType) Zero() interface{} { return nil diff --git a/sql/numbertype.go b/sql/numbertype.go index 64de6b91e5..f162cd3f19 100644 --- a/sql/numbertype.go +++ b/sql/numbertype.go @@ -18,6 +18,7 @@ import ( "encoding/hex" "fmt" "math" + "reflect" "strconv" "time" @@ -74,11 +75,23 @@ var ( dec_int64_min = decimal.NewFromInt(math.MinInt64) // decimal that represents the zero value dec_zero = decimal.NewFromInt(0) + + numberInt8ValueType = reflect.TypeOf(int8(0)) + numberInt16ValueType = reflect.TypeOf(int16(0)) + numberInt32ValueType = reflect.TypeOf(int32(0)) + numberInt64ValueType = reflect.TypeOf(int64(0)) + numberUint8ValueType = reflect.TypeOf(uint8(0)) + numberUint16ValueType = reflect.TypeOf(uint16(0)) + numberUint32ValueType = reflect.TypeOf(uint32(0)) + numberUint64ValueType = reflect.TypeOf(uint64(0)) + numberFloat32ValueType = reflect.TypeOf(float32(0)) + numberFloat64ValueType = reflect.TypeOf(float64(0)) ) -// Represents all integer and floating point types. +// NumberType represents all integer and floating point types. // https://dev.mysql.com/doc/refman/8.0/en/integer-types.html // https://dev.mysql.com/doc/refman/8.0/en/floating-point-types.html +// The type of the returned value is one of the following: int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64. type NumberType interface { Type IsSigned() bool @@ -358,9 +371,9 @@ func (t numberTypeImpl) SQL(dest []byte, v interface{}) (sqltypes.Value, error) case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint24, sqltypes.Uint32, sqltypes.Uint64: dest = strconv.AppendUint(dest, mustUint64(v), 10) case sqltypes.Float32: - dest = strconv.AppendFloat(dest, float64(v.(float32)), 'f', -1, 32) + dest = strconv.AppendFloat(dest, float64(v.(float32)), 'g', -1, 32) case sqltypes.Float64: - dest = strconv.AppendFloat(dest, v.(float64), 'f', -1, 64) + dest = strconv.AppendFloat(dest, v.(float64), 'g', -1, 64) default: panic(ErrInvalidBaseType.New(t.baseType.String(), "number")) } @@ -597,6 +610,38 @@ func (t numberTypeImpl) Type() query.Type { return t.baseType } +// ValueType implements Type interface. +func (t numberTypeImpl) ValueType() reflect.Type { + switch t.baseType { + case sqltypes.Int8: + return numberInt8ValueType + case sqltypes.Uint8: + return numberUint8ValueType + case sqltypes.Int16: + return numberInt16ValueType + case sqltypes.Uint16: + return numberUint16ValueType + case sqltypes.Int24: + return numberInt32ValueType + case sqltypes.Uint24: + return numberUint32ValueType + case sqltypes.Int32: + return numberInt32ValueType + case sqltypes.Uint32: + return numberUint32ValueType + case sqltypes.Int64: + return numberInt64ValueType + case sqltypes.Uint64: + return numberUint64ValueType + case sqltypes.Float32: + return numberFloat32ValueType + case sqltypes.Float64: + return numberFloat64ValueType + default: + panic(fmt.Sprintf("%v is not a valid number base type", t.baseType.String())) + } +} + // Zero implements Type interface. func (t numberTypeImpl) Zero() interface{} { switch t.baseType { diff --git a/sql/numbertype_test.go b/sql/numbertype_test.go index 925b1bf9fd..297ff1b912 100644 --- a/sql/numbertype_test.go +++ b/sql/numbertype_test.go @@ -17,6 +17,7 @@ package sql import ( "fmt" "math" + "reflect" "strconv" "testing" "time" @@ -227,6 +228,9 @@ func TestNumberConvert(t *testing.T) { } else { require.NoError(t, err) assert.Equal(t, test.expectedVal, val) + if val != nil { + assert.Equal(t, test.typ.ValueType(), reflect.TypeOf(val)) + } } }) } diff --git a/sql/plan/common_test.go b/sql/plan/common_test.go index 1c2327bdc0..6eb470adf1 100644 --- a/sql/plan/common_test.go +++ b/sql/plan/common_test.go @@ -39,15 +39,19 @@ var benchtable = func() *memory.Table { for i := 0; i < 100; i++ { n := fmt.Sprint(i) + boolVal := int8(0) + if i%2 == 0 { + boolVal = 1 + } err := t.Insert( sql.NewEmptyContext(), sql.NewRow( repeatStr(n, i%10+1), float64(i), - i%2 == 0, + boolVal, int32(i), int64(i), - []byte(repeatStr(n, 100+(i%100))), + repeatStr(n, 100+(i%100)), ), ) if err != nil { @@ -60,10 +64,10 @@ var benchtable = func() *memory.Table { sql.NewRow( repeatStr(n, i%10+1), float64(i), - i%2 == 0, + boolVal, int32(i), int64(i), - []byte(repeatStr(n, 100+(i%100))), + repeatStr(n, 100+(i%100)), ), ) if err != nil { diff --git a/sql/plan/external_procedure.go b/sql/plan/external_procedure.go index c13a66768e..750d281b77 100644 --- a/sql/plan/external_procedure.go +++ b/sql/plan/external_procedure.go @@ -183,11 +183,7 @@ func (n *ExternalProcedure) processParam(ctx *sql.Context, funcParamType reflect exprParamVal = int(exprParamVal.(uint64)) } case decimalType: - var err error - exprParamVal, err = decimal.NewFromString(exprParamVal.(string)) - if err != nil { - return reflect.Value{}, err - } + exprParamVal = exprParamVal.(decimal.Decimal) } if funcParamType.Kind() == reflect.Ptr { // Coincides with INOUT diff --git a/sql/plan/sort_test.go b/sql/plan/sort_test.go index 15e03980bf..1fcf94481a 100644 --- a/sql/plan/sort_test.go +++ b/sql/plan/sort_test.go @@ -42,9 +42,9 @@ func TestSort(t *testing.T) { { rows: []sql.Row{ sql.NewRow("c", nil, nil), - sql.NewRow("a", int32(3), 3.0), - sql.NewRow("b", int32(3), 3.0), - sql.NewRow("c", int32(1), 1.0), + sql.NewRow("a", int32(3), float64(3.0)), + sql.NewRow("b", int32(3), float64(3.0)), + sql.NewRow("c", int32(1), float64(1.0)), sql.NewRow(nil, int32(1), nil), }, sortFields: []sql.SortField{ @@ -55,14 +55,14 @@ func TestSort(t *testing.T) { expected: []sql.Row{ sql.NewRow("c", nil, nil), sql.NewRow(nil, int32(1), nil), - sql.NewRow("c", int32(1), 1.0), - sql.NewRow("b", int32(3), 3.0), - sql.NewRow("a", int32(3), 3.0), + sql.NewRow("c", int32(1), float64(1.0)), + sql.NewRow("b", int32(3), float64(3.0)), + sql.NewRow("a", int32(3), float64(3.0)), }, }, { rows: []sql.Row{ - sql.NewRow("c", int32(3), 3.0), + sql.NewRow("c", int32(3), float64(3.0)), sql.NewRow("c", int32(3), nil), }, sortFields: []sql.SortField{ @@ -72,15 +72,15 @@ func TestSort(t *testing.T) { }, expected: []sql.Row{ sql.NewRow("c", int32(3), nil), - sql.NewRow("c", int32(3), 3.0), + sql.NewRow("c", int32(3), float64(3.0)), }, }, { rows: []sql.Row{ sql.NewRow("c", nil, nil), - sql.NewRow("a", int32(3), 3.0), - sql.NewRow("b", int32(3), 3.0), - sql.NewRow("c", int32(1), 1.0), + sql.NewRow("a", int32(3), float64(3.0)), + sql.NewRow("b", int32(3), float64(3.0)), + sql.NewRow("c", int32(1), float64(1.0)), sql.NewRow(nil, int32(1), nil), }, sortFields: []sql.SortField{ @@ -91,15 +91,15 @@ func TestSort(t *testing.T) { expected: []sql.Row{ sql.NewRow("c", nil, nil), sql.NewRow(nil, int32(1), nil), - sql.NewRow("c", int32(1), 1.0), - sql.NewRow("a", int32(3), 3.0), - sql.NewRow("b", int32(3), 3.0), + sql.NewRow("c", int32(1), float64(1.0)), + sql.NewRow("a", int32(3), float64(3.0)), + sql.NewRow("b", int32(3), float64(3.0)), }, }, { rows: []sql.Row{ - sql.NewRow("a", int32(1), 2), - sql.NewRow("a", int32(1), 1), + sql.NewRow("a", int32(1), float64(2)), + sql.NewRow("a", int32(1), float64(1)), }, sortFields: []sql.SortField{ {Column: expression.NewGetField(0, sql.Text, "col1", true), Order: sql.Ascending, NullOrdering: sql.NullsFirst}, @@ -107,18 +107,18 @@ func TestSort(t *testing.T) { {Column: expression.NewGetField(2, sql.Float64, "col3", true), Order: sql.Ascending, NullOrdering: sql.NullsFirst}, }, expected: []sql.Row{ - sql.NewRow("a", int32(1), 1), - sql.NewRow("a", int32(1), 2), + sql.NewRow("a", int32(1), float64(1)), + sql.NewRow("a", int32(1), float64(2)), }, }, { rows: []sql.Row{ - sql.NewRow("a", int32(1), 2), - sql.NewRow("a", int32(1), 1), - sql.NewRow("a", int32(2), 2), - sql.NewRow("a", int32(3), 1), - sql.NewRow("b", int32(2), 2), - sql.NewRow("c", int32(3), 1), + sql.NewRow("a", int32(1), float64(2)), + sql.NewRow("a", int32(1), float64(1)), + sql.NewRow("a", int32(2), float64(2)), + sql.NewRow("a", int32(3), float64(1)), + sql.NewRow("b", int32(2), float64(2)), + sql.NewRow("c", int32(3), float64(1)), }, sortFields: []sql.SortField{ {Column: expression.NewGetField(0, sql.Text, "col1", true), Order: sql.Ascending, NullOrdering: sql.NullsFirst}, @@ -126,18 +126,18 @@ func TestSort(t *testing.T) { {Column: expression.NewGetField(2, sql.Float64, "col3", true), Order: sql.Ascending, NullOrdering: sql.NullsFirst}, }, expected: []sql.Row{ - sql.NewRow("a", int32(3), 1), - sql.NewRow("a", int32(2), 2), - sql.NewRow("a", int32(1), 1), - sql.NewRow("a", int32(1), 2), - sql.NewRow("b", int32(2), 2), - sql.NewRow("c", int32(3), 1), + sql.NewRow("a", int32(3), float64(1)), + sql.NewRow("a", int32(2), float64(2)), + sql.NewRow("a", int32(1), float64(1)), + sql.NewRow("a", int32(1), float64(2)), + sql.NewRow("b", int32(2), float64(2)), + sql.NewRow("c", int32(3), float64(1)), }, }, { rows: []sql.Row{ - sql.NewRow(nil, nil, 2), - sql.NewRow(nil, nil, 1), + sql.NewRow(nil, nil, float64(2)), + sql.NewRow(nil, nil, float64(1)), }, sortFields: []sql.SortField{ {Column: expression.NewGetField(0, sql.Text, "col1", true), Order: sql.Ascending, NullOrdering: sql.NullsFirst}, @@ -145,14 +145,14 @@ func TestSort(t *testing.T) { {Column: expression.NewGetField(2, sql.Float64, "col3", true), Order: sql.Ascending, NullOrdering: sql.NullsFirst}, }, expected: []sql.Row{ - sql.NewRow(nil, nil, 1), - sql.NewRow(nil, nil, 2), + sql.NewRow(nil, nil, float64(1)), + sql.NewRow(nil, nil, float64(2)), }, }, { rows: []sql.Row{ - sql.NewRow(nil, nil, 1), - sql.NewRow(nil, nil, 2), + sql.NewRow(nil, nil, float64(1)), + sql.NewRow(nil, nil, float64(2)), }, sortFields: []sql.SortField{ {Column: expression.NewGetField(0, sql.Text, "col1", true), Order: sql.Descending, NullOrdering: sql.NullsFirst}, @@ -160,13 +160,13 @@ func TestSort(t *testing.T) { {Column: expression.NewGetField(2, sql.Float64, "col3", true), Order: sql.Descending, NullOrdering: sql.NullsFirst}, }, expected: []sql.Row{ - sql.NewRow(nil, nil, 2), - sql.NewRow(nil, nil, 1), + sql.NewRow(nil, nil, float64(2)), + sql.NewRow(nil, nil, float64(1)), }, }, { rows: []sql.Row{ - sql.NewRow(nil, nil, 1), + sql.NewRow(nil, nil, float64(1)), sql.NewRow(nil, nil, nil), }, sortFields: []sql.SortField{ @@ -176,13 +176,13 @@ func TestSort(t *testing.T) { }, expected: []sql.Row{ sql.NewRow(nil, nil, nil), - sql.NewRow(nil, nil, 1), + sql.NewRow(nil, nil, float64(1)), }, }, { rows: []sql.Row{ sql.NewRow(nil, nil, nil), - sql.NewRow(nil, nil, 1), + sql.NewRow(nil, nil, float64(1)), }, sortFields: []sql.SortField{ {Column: expression.NewGetField(0, sql.Text, "col1", true), Order: sql.Ascending, NullOrdering: sql.NullsLast}, @@ -190,7 +190,7 @@ func TestSort(t *testing.T) { {Column: expression.NewGetField(2, sql.Float64, "col3", true), Order: sql.Ascending, NullOrdering: sql.NullsLast}, }, expected: []sql.Row{ - sql.NewRow(nil, nil, 1), + sql.NewRow(nil, nil, float64(1)), sql.NewRow(nil, nil, nil), }, }, diff --git a/sql/point.go b/sql/point.go index b5b832bed4..b6f19e125f 100644 --- a/sql/point.go +++ b/sql/point.go @@ -15,28 +15,37 @@ package sql import ( + "reflect" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" ) -// Represents the Point type. +// PointType represents the POINT type. // https://dev.mysql.com/doc/refman/8.0/en/gis-class-point.html +// The type of the returned value is Point. +type PointType struct { + SRID uint32 + DefinedSRID bool +} + +// Point is the value type returned from PointType. Implements GeometryValue. type Point struct { SRID uint32 X float64 Y float64 } -type PointType struct { - SRID uint32 - DefinedSRID bool -} - var _ Type = PointType{} var _ SpatialColumnType = PointType{} +var _ GeometryValue = Point{} -var ErrNotPoint = errors.NewKind("value of type %T is not a point") +var ( + ErrNotPoint = errors.NewKind("value of type %T is not a point") + + pointValueType = reflect.TypeOf(Point{}) +) // Compare implements Type interface. func (t PointType) Compare(a interface{}, b interface{}) (int, error) { @@ -133,7 +142,10 @@ func (t PointType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { return sqltypes.Value{}, nil } - return sqltypes.MakeTrusted(sqltypes.Geometry, []byte(pv.(string))), nil + //TODO: pretty sure this is wrong, pv is not a string type + val := appendAndSliceString(dest, pv.(string)) + + return sqltypes.MakeTrusted(sqltypes.Geometry, val), nil } // String implements Type interface. @@ -151,6 +163,11 @@ func (t PointType) Zero() interface{} { return Point{X: 0.0, Y: 0.0} } +// ValueType implements Type interface. +func (t PointType) ValueType() reflect.Type { + return pointValueType +} + // GetSpatialTypeSRID implements SpatialColumnType interface. func (t PointType) GetSpatialTypeSRID() (uint32, bool) { return t.SRID, t.DefinedSRID @@ -176,3 +193,6 @@ func (t PointType) MatchSRID(v interface{}) error { } return ErrNotMatchingSRID.New(val.SRID, t.SRID) } + +// implementsGeometryValue implements GeometryValue interface. +func (p Point) implementsGeometryValue() {} diff --git a/sql/polygon.go b/sql/polygon.go index 01b32259b8..b4055b7cf2 100644 --- a/sql/polygon.go +++ b/sql/polygon.go @@ -15,28 +15,37 @@ package sql import ( + "reflect" + "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) -// Represents the Polygon type. +// PolygonType represents the POLYGON type. // https://dev.mysql.com/doc/refman/8.0/en/gis-class-polygon.html -type Polygon struct { - SRID uint32 - Lines []LineString -} - +// The type of the returned value is Polygon. type PolygonType struct { SRID uint32 DefinedSRID bool } +// Polygon is the value type returned from PolygonType. Implements GeometryValue. +type Polygon struct { + SRID uint32 + Lines []LineString +} + var _ Type = PolygonType{} var _ SpatialColumnType = PolygonType{} +var _ GeometryValue = Polygon{} -var ErrNotPolygon = errors.NewKind("value of type %T is not a polygon") +var ( + ErrNotPolygon = errors.NewKind("value of type %T is not a polygon") + + polygonValueType = reflect.TypeOf(Polygon{}) +) // Compare implements Type interface. func (t PolygonType) Compare(a interface{}, b interface{}) (int, error) { @@ -146,7 +155,10 @@ func (t PolygonType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { return sqltypes.Value{}, nil } - return sqltypes.MakeTrusted(sqltypes.Geometry, []byte(lv.(string))), nil + //TODO: pretty sure this is wrong, lv is not a string type + val := appendAndSliceString(dest, lv.(string)) + + return sqltypes.MakeTrusted(sqltypes.Geometry, val), nil } // String implements Type interface. @@ -159,6 +171,11 @@ func (t PolygonType) Type() query.Type { return sqltypes.Geometry } +// ValueType implements Type interface. +func (t PolygonType) ValueType() reflect.Type { + return polygonValueType +} + // Zero implements Type interface. func (t PolygonType) Zero() interface{} { return Polygon{Lines: []LineString{{Points: []Point{{}, {}, {}, {}}}}} @@ -189,3 +206,6 @@ func (t PolygonType) MatchSRID(v interface{}) error { } return ErrNotMatchingSRID.New(val.SRID, t.SRID) } + +// implementsGeometryValue implements GeometryValue interface. +func (p Polygon) implementsGeometryValue() {} diff --git a/sql/settype.go b/sql/settype.go index 4def914722..64fb02777f 100644 --- a/sql/settype.go +++ b/sql/settype.go @@ -18,9 +18,12 @@ import ( "fmt" "math" "math/bits" + "reflect" "strconv" "strings" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" @@ -36,20 +39,24 @@ var ( ErrDuplicateEntrySet = errors.NewKind("duplicate entry: %v") ErrInvalidSetValue = errors.NewKind("value %v was not found in the set") ErrTooLargeForSet = errors.NewKind(`value "%v" is too large for this set`) + + setValueType = reflect.TypeOf(uint64(0)) ) // Comments with three slashes were taken directly from the linked documentation. -// Represents the SET type. +// SetType represents the SET type. // https://dev.mysql.com/doc/refman/8.0/en/set.html +// The type of the returned value is uint64. type SetType interface { Type CharacterSet() CharacterSet Collation() Collation - //TODO: move this out of go-mysql-server and into the Dolt layer - Marshal(v interface{}) (uint64, error) + // NumberOfElements returns the number of elements in this set. NumberOfElements() uint16 - Unmarshal(bits uint64) (string, error) + // BitsToString takes a previously-converted value and returns it as a string. + BitsToString(bits uint64) (string, error) + // Values returns all of the set's values in ascending order according to their corresponding bit value. Values() []string } @@ -121,18 +128,20 @@ func (t setType) Compare(a interface{}, b interface{}) (int, error) { return res, nil } - ai, err := t.Marshal(a) + ai, err := t.Convert(a) if err != nil { return 0, err } - bi, err := t.Marshal(b) + bi, err := t.Convert(b) if err != nil { return 0, err } + au := ai.(uint64) + bu := bi.(uint64) - if ai < bi { + if au < bu { return -1, nil - } else if ai > bi { + } else if au > bu { return 1, nil } return 0, nil @@ -166,26 +175,26 @@ func (t setType) Convert(v interface{}) (interface{}, error) { return t.Convert(uint64(value)) case uint64: if value <= t.allValuesBitField() { - return t.convertBitFieldToString(value) + return value, nil } - return nil, ErrConvertingToSet.New(v) case float32: return t.Convert(uint64(value)) case float64: return t.Convert(uint64(value)) - case string: - // For SET('a','b') and given a string 'b,a,a', we would return 'a,b', so we can't return the input. - bitField, err := t.convertStringToBitField(value) - if err != nil { - return nil, err + case decimal.Decimal: + return t.Convert(value.BigInt().Uint64()) + case decimal.NullDecimal: + if !value.Valid { + return nil, nil } - setStr, _ := t.convertBitFieldToString(bitField) - return setStr, nil + return t.Convert(value.Decimal.BigInt().Uint64()) + case string: + return t.convertStringToBitField(value) case []byte: return t.Convert(string(value)) } - return nil, ErrConvertingToSet.New(v) + return uint64(0), ErrConvertingToSet.New(v) } // MustConvert implements the Type interface. @@ -220,12 +229,16 @@ func (t setType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { if v == nil { return sqltypes.NULL, nil } - value, err := t.Convert(v) + convertedValue, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + value, err := t.BitsToString(convertedValue.(uint64)) if err != nil { return sqltypes.Value{}, err } - val := appendAndSlice(dest, []byte(value.(string))) + val := appendAndSliceString(dest, value) return sqltypes.MakeTrusted(sqltypes.Set, val), nil } @@ -247,68 +260,37 @@ func (t setType) Type() query.Type { return sqltypes.Set } +// ValueType implements Type interface. +func (t setType) ValueType() reflect.Type { + return setValueType +} + // Zero implements Type interface. func (t setType) Zero() interface{} { return "" } +// CharacterSet implements EnumType interface. func (t setType) CharacterSet() CharacterSet { return t.collation.CharacterSet() } +// Collation implements EnumType interface. func (t setType) Collation() Collation { return t.collation } -// Marshal takes a valid Set value and returns it as an uint64. -func (t setType) Marshal(v interface{}) (uint64, error) { - switch value := v.(type) { - case int: - return t.Marshal(uint64(value)) - case uint: - return t.Marshal(uint64(value)) - case int8: - return t.Marshal(uint64(value)) - case uint8: - return t.Marshal(uint64(value)) - case int16: - return t.Marshal(uint64(value)) - case uint16: - return t.Marshal(uint64(value)) - case int32: - return t.Marshal(uint64(value)) - case uint32: - return t.Marshal(uint64(value)) - case int64: - return t.Marshal(uint64(value)) - case uint64: - if value <= t.allValuesBitField() { - return value, nil - } - case float32: - return t.Marshal(uint64(value)) - case float64: - return t.Marshal(uint64(value)) - case string: - return t.convertStringToBitField(value) - case []byte: - return t.Marshal(string(value)) - } - - return uint64(0), ErrConvertingToSet.New(v) -} - -// NumberOfElements returns the number of elements in this set. +// NumberOfElements implements EnumType interface. func (t setType) NumberOfElements() uint16 { return uint16(len(t.valToBit)) } -// Unmarshal takes a previously-marshalled value and returns it as a string. -func (t setType) Unmarshal(v uint64) (string, error) { +// BitsToString implements EnumType interface. +func (t setType) BitsToString(v uint64) (string, error) { return t.convertBitFieldToString(v) } -// Values returns all of the set's values in ascending order according to their corresponding bit value. +// Values implements EnumType interface. func (t setType) Values() []string { bitEdge := 64 - bits.LeadingZeros64(t.allValuesBitField()) valArray := make([]string, bitEdge) diff --git a/sql/settype_test.go b/sql/settype_test.go index a2200ba2de..00d6f0c69c 100644 --- a/sql/settype_test.go +++ b/sql/settype_test.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "reflect" "strconv" "testing" "time" @@ -185,7 +186,12 @@ func TestSetConvert(t *testing.T) { assert.Error(t, err) } else { require.NoError(t, err) - assert.Equal(t, test.expectedVal, val) + res, err := typ.Compare(test.expectedVal, val) + require.NoError(t, err) + assert.Equal(t, 0, res) + if val != nil { + assert.Equal(t, typ.ValueType(), reflect.TypeOf(val)) + } } }) } @@ -208,12 +214,14 @@ func TestSetMarshalMax(t *testing.T) { for _, test := range tests { t.Run(fmt.Sprintf("%v", test), func(t *testing.T) { - bits, err := typ.Marshal(test) + bits, err := typ.Convert(test) require.NoError(t, err) - res1, err := typ.Unmarshal(bits) + res1, err := typ.BitsToString(bits.(uint64)) require.NoError(t, err) require.Equal(t, test, res1) - res2, err := typ.Convert(bits) + bits2, err := typ.Convert(bits) + require.NoError(t, err) + res2, err := typ.BitsToString(bits2.(uint64)) require.NoError(t, err) require.Equal(t, test, res2) }) diff --git a/sql/stringtype.go b/sql/stringtype.go index f3b5cff5fd..2a07056530 100644 --- a/sql/stringtype.go +++ b/sql/stringtype.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "reflect" "strconv" "strings" "time" @@ -53,12 +54,15 @@ var ( Blob = MustCreateBinary(sqltypes.Blob, textBlobMax) MediumBlob = MustCreateBinary(sqltypes.Blob, mediumTextBlobMax) LongBlob = MustCreateBinary(sqltypes.Blob, longTextBlobMax) + + stringValueType = reflect.TypeOf(string("")) ) // StringType represents all string types, including VARCHAR and BLOB. // https://dev.mysql.com/doc/refman/8.0/en/char.html // https://dev.mysql.com/doc/refman/8.0/en/binary-varbinary.html // https://dev.mysql.com/doc/refman/8.0/en/blob.html +// The type of the returned value is string. type StringType interface { Type CharacterSet() CharacterSet @@ -349,7 +353,7 @@ func (t stringType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { return sqltypes.Value{}, err } - val := appendAndSlice(dest, []byte(v.(string))) + val := appendAndSliceString(dest, v.(string)) return sqltypes.MakeTrusted(t.baseType, val), nil } @@ -407,6 +411,11 @@ func (t stringType) Type() query.Type { return t.baseType } +// ValueType implements Type interface. +func (t stringType) ValueType() reflect.Type { + return stringValueType +} + // Zero implements Type interface. func (t stringType) Zero() interface{} { return "" @@ -442,7 +451,14 @@ func (t stringType) CreateMatcher(likeStr string) (regex.DisposableMatcher, erro } } -func appendAndSlice(buffer, addition []byte) (slice []byte) { +func appendAndSliceString(buffer []byte, addition string) (slice []byte) { + stop := len(buffer) + buffer = append(buffer, addition...) + slice = buffer[stop:] + return +} + +func appendAndSliceBytes(buffer, addition []byte) (slice []byte) { stop := len(buffer) buffer = append(buffer, addition...) slice = buffer[stop:] diff --git a/sql/stringtype_test.go b/sql/stringtype_test.go index 9b1b7a03ca..b4b35214a2 100644 --- a/sql/stringtype_test.go +++ b/sql/stringtype_test.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "reflect" "strings" "testing" "time" @@ -327,6 +328,9 @@ func TestStringConvert(t *testing.T) { } else { require.NoError(t, err) assert.Equal(t, test.expectedVal, val) + if val != nil { + assert.Equal(t, test.typ.ValueType(), reflect.TypeOf(val)) + } } }) } diff --git a/sql/system_booltype.go b/sql/system_booltype.go index 96e5f872a9..3f6f5a43f4 100644 --- a/sql/system_booltype.go +++ b/sql/system_booltype.go @@ -15,13 +15,18 @@ package sql import ( + "reflect" "strconv" "strings" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) +var systemBoolValueType = reflect.TypeOf(int8(0)) + // systemBoolType is an internal boolean type ONLY for system variables. type systemBoolType struct { varName string @@ -95,6 +100,14 @@ func (t systemBoolType) Convert(v interface{}) (interface{}, error) { if value == float64(int64(value)) { return t.Convert(int64(value)) } + case decimal.Decimal: + f, _ := value.Float64() + return t.Convert(f) + case decimal.NullDecimal: + if value.Valid { + f, _ := value.Decimal.Float64() + return t.Convert(f) + } case string: switch strings.ToLower(value) { case "on", "true": @@ -157,6 +170,11 @@ func (t systemBoolType) Type() query.Type { return sqltypes.Int8 } +// ValueType implements Type interface. +func (t systemBoolType) ValueType() reflect.Type { + return systemBoolValueType +} + // Zero implements Type interface. func (t systemBoolType) Zero() interface{} { return int8(0) diff --git a/sql/system_doubletype.go b/sql/system_doubletype.go index d038d64090..5a7849f188 100644 --- a/sql/system_doubletype.go +++ b/sql/system_doubletype.go @@ -15,12 +15,17 @@ package sql import ( + "reflect" "strconv" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) +var systemDoubleValueType = reflect.TypeOf(float64(0)) + // systemDoubleType is an internal double type ONLY for system variables. type systemDoubleType struct { varName string @@ -87,6 +92,14 @@ func (t systemDoubleType) Convert(v interface{}) (interface{}, error) { if value >= t.lowerbound && value <= t.upperbound { return value, nil } + case decimal.Decimal: + f, _ := value.Float64() + return t.Convert(f) + case decimal.NullDecimal: + if value.Valid { + f, _ := value.Decimal.Float64() + return t.Convert(f) + } } return nil, ErrInvalidSystemVariableValue.New(t.varName, v) @@ -142,6 +155,11 @@ func (t systemDoubleType) Type() query.Type { return sqltypes.Float64 } +// ValueType implements Type interface. +func (t systemDoubleType) ValueType() reflect.Type { + return systemDoubleValueType +} + // Zero implements Type interface. func (t systemDoubleType) Zero() interface{} { return float64(0) diff --git a/sql/system_enumtype.go b/sql/system_enumtype.go index 90c8c2646f..d219b861c1 100644 --- a/sql/system_enumtype.go +++ b/sql/system_enumtype.go @@ -15,12 +15,17 @@ package sql import ( + "reflect" "strings" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) +var systemEnumValueType = reflect.TypeOf(string("")) + // systemEnumType is an internal enum type ONLY for system variables. type systemEnumType struct { varName string @@ -98,6 +103,14 @@ func (t systemEnumType) Convert(v interface{}) (interface{}, error) { if value == float64(int(value)) { return t.Convert(int(value)) } + case decimal.Decimal: + f, _ := value.Float64() + return t.Convert(f) + case decimal.NullDecimal: + if value.Valid { + f, _ := value.Decimal.Float64() + return t.Convert(f) + } case string: if idx, ok := t.valToIndex[strings.ToLower(value)]; ok { return t.indexToVal[idx], nil @@ -145,7 +158,7 @@ func (t systemEnumType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) return sqltypes.Value{}, err } - val := appendAndSlice(dest, []byte(v.(string))) + val := appendAndSliceString(dest, v.(string)) return sqltypes.MakeTrusted(t.Type(), val), nil } @@ -160,6 +173,11 @@ func (t systemEnumType) Type() query.Type { return sqltypes.VarChar } +// ValueType implements Type interface. +func (t systemEnumType) ValueType() reflect.Type { + return systemEnumValueType +} + // Zero implements Type interface. func (t systemEnumType) Zero() interface{} { return "" diff --git a/sql/system_inttype.go b/sql/system_inttype.go index 7696012f3b..0b2ccb3556 100644 --- a/sql/system_inttype.go +++ b/sql/system_inttype.go @@ -15,12 +15,17 @@ package sql import ( + "reflect" "strconv" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) +var systemIntValueType = reflect.TypeOf(int64(0)) + // systemIntType is an internal integer type ONLY for system variables. type systemIntType struct { varName string @@ -95,6 +100,14 @@ func (t systemIntType) Convert(v interface{}) (interface{}, error) { if value == float64(int64(value)) { return t.Convert(int64(value)) } + case decimal.Decimal: + f, _ := value.Float64() + return t.Convert(f) + case decimal.NullDecimal: + if value.Valid { + f, _ := value.Decimal.Float64() + return t.Convert(f) + } } return nil, ErrInvalidSystemVariableValue.New(t.varName, v) @@ -150,6 +163,11 @@ func (t systemIntType) Type() query.Type { return sqltypes.Int64 } +// ValueType implements Type interface. +func (t systemIntType) ValueType() reflect.Type { + return systemIntValueType +} + // Zero implements Type interface. func (t systemIntType) Zero() interface{} { return int64(0) diff --git a/sql/system_settype.go b/sql/system_settype.go index 749dfd10d2..30150f2e1e 100644 --- a/sql/system_settype.go +++ b/sql/system_settype.go @@ -15,8 +15,11 @@ package sql import ( + "reflect" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/shopspring/decimal" ) // systemSetType is an internal set type ONLY for system variables. @@ -37,19 +40,21 @@ func (t systemSetType) Compare(a interface{}, b interface{}) (int, error) { if a == nil || b == nil { return 0, ErrInvalidSystemVariableValue.New(t.varName, nil) } - ai, err := t.Marshal(a) + ai, err := t.Convert(a) if err != nil { return 0, err } - bi, err := t.Marshal(b) + bi, err := t.Convert(b) if err != nil { return 0, err } + au := ai.(uint64) + bu := bi.(uint64) - if ai == bi { + if au == bu { return 0, nil } - if ai < bi { + if au < bu { return -1, nil } return 1, nil @@ -87,6 +92,14 @@ func (t systemSetType) Convert(v interface{}) (interface{}, error) { if value == float64(int64(value)) { return t.SetType.Convert(int64(value)) } + case decimal.Decimal: + f, _ := value.Float64() + return t.Convert(f) + case decimal.NullDecimal: + if value.Valid { + f, _ := value.Decimal.Float64() + return t.Convert(f) + } case string: return t.SetType.Convert(value) } @@ -121,13 +134,16 @@ func (t systemSetType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { if v == nil { return sqltypes.NULL, nil } - - v, err := t.Convert(v) + convertedValue, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + value, err := t.BitsToString(convertedValue.(uint64)) if err != nil { return sqltypes.Value{}, err } - val := appendAndSlice(dest, []byte(v.(string))) + val := appendAndSliceString(dest, value) return sqltypes.MakeTrusted(t.Type(), val), nil } @@ -142,6 +158,11 @@ func (t systemSetType) Type() query.Type { return sqltypes.VarChar } +// ValueType implements Type interface. +func (t systemSetType) ValueType() reflect.Type { + return t.SetType.ValueType() +} + // Zero implements Type interface. func (t systemSetType) Zero() interface{} { return "" @@ -149,11 +170,11 @@ func (t systemSetType) Zero() interface{} { // EncodeValue implements SystemVariableType interface. func (t systemSetType) EncodeValue(val interface{}) (string, error) { - expectedVal, ok := val.(string) + expectedVal, ok := val.(uint64) if !ok { return "", ErrSystemVariableCodeFail.New(val, t.String()) } - return expectedVal, nil + return t.BitsToString(expectedVal) } // DecodeValue implements SystemVariableType interface. diff --git a/sql/system_stringtype.go b/sql/system_stringtype.go index 5337da08ec..ddcdd13e65 100644 --- a/sql/system_stringtype.go +++ b/sql/system_stringtype.go @@ -15,10 +15,14 @@ package sql import ( + "reflect" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) +var systemStringValueType = reflect.TypeOf(string("")) + // systemStringType is an internal string type ONLY for system variables. type systemStringType struct { varName string @@ -98,7 +102,7 @@ func (t systemStringType) SQL(dest []byte, v interface{}) (sqltypes.Value, error return sqltypes.Value{}, err } - val := appendAndSlice(dest, []byte(v.(string))) + val := appendAndSliceString(dest, v.(string)) return sqltypes.MakeTrusted(t.Type(), val), nil } @@ -113,6 +117,11 @@ func (t systemStringType) Type() query.Type { return sqltypes.VarChar } +// ValueType implements Type interface. +func (t systemStringType) ValueType() reflect.Type { + return systemStringValueType +} + // Zero implements Type interface. func (t systemStringType) Zero() interface{} { return "" diff --git a/sql/system_uinttype.go b/sql/system_uinttype.go index 2ec670fb6f..5ff304db9f 100644 --- a/sql/system_uinttype.go +++ b/sql/system_uinttype.go @@ -15,12 +15,17 @@ package sql import ( + "reflect" "strconv" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) +var systemUintValueType = reflect.TypeOf(uint64(0)) + // systemUintType is an internal unsigned integer type ONLY for system variables. type systemUintType struct { varName string @@ -91,6 +96,14 @@ func (t systemUintType) Convert(v interface{}) (interface{}, error) { if value == float64(uint64(value)) { return t.Convert(uint64(value)) } + case decimal.Decimal: + f, _ := value.Float64() + return t.Convert(f) + case decimal.NullDecimal: + if value.Valid { + f, _ := value.Decimal.Float64() + return t.Convert(f) + } } return nil, ErrInvalidSystemVariableValue.New(t.varName, v) @@ -146,6 +159,11 @@ func (t systemUintType) Type() query.Type { return sqltypes.Uint64 } +// ValueType implements Type interface. +func (t systemUintType) ValueType() reflect.Type { + return systemUintValueType +} + // Zero implements Type interface. func (t systemUintType) Zero() interface{} { return uint64(0) diff --git a/sql/timetype.go b/sql/timetype.go index eff9a83433..f0bed3adad 100644 --- a/sql/timetype.go +++ b/sql/timetype.go @@ -17,10 +17,13 @@ package sql import ( "fmt" "math" + "reflect" "strconv" "strings" "time" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" @@ -37,28 +40,35 @@ var ( microsecondsPerMinute int64 = 60000000 microsecondsPerHour int64 = 3600000000 nanosecondsPerMicrosecond int64 = 1000 + + timeValueType = reflect.TypeOf(Timespan(0)) ) -// Represents the TIME type. +// TimeType represents the TIME type. // https://dev.mysql.com/doc/refman/8.0/en/time.html -// TIME is implemented as TIME(6) +// TIME is implemented as TIME(6). +// The type of the returned value is Timespan. // TODO: implement parameters on the TIME type type TimeType interface { Type + // ConvertToTimespan returns a Timespan from the given interface. Follows the same conversion rules as + // Convert(), in that this will process the value based on its base-10 visual representation (for example, Convert() + // will interpret the value `1234` as 12 minutes and 34 seconds). Returns an error for nil values. + ConvertToTimespan(v interface{}) (Timespan, error) + // ConvertToTimeDuration returns a time.Duration from the given interface. Follows the same conversion rules as + // Convert(), in that this will process the value based on its base-10 visual representation (for example, Convert() + // will interpret the value `1234` as 12 minutes and 34 seconds). Returns an error for nil values. ConvertToTimeDuration(v interface{}) (time.Duration, error) - //TODO: move this out of go-mysql-server and into the Dolt layer - Marshal(v interface{}) (int64, error) - Unmarshal(v int64) string + // MicrosecondsToTimespan returns a Timespan from the given number of microseconds. This differs from Convert(), as + // that will process the value based on its base-10 visual representation (for example, Convert() will interpret + // the value `1234` as 12 minutes and 34 seconds). This clamps the given microseconds to the allowed range. + MicrosecondsToTimespan(v int64) Timespan } type timespanType struct{} -type timespanImpl struct { - negative bool - hours int16 - minutes int8 - seconds int8 - microseconds int32 -} + +// Timespan is the value type returned by TimeType.Convert(). +type Timespan int64 // Compare implements Type interface. func (t timespanType) Compare(a interface{}, b interface{}) (int, error) { @@ -66,24 +76,16 @@ func (t timespanType) Compare(a interface{}, b interface{}) (int, error) { return res, nil } - as, err := t.ConvertToTimespanImpl(a) + as, err := t.ConvertToTimespan(a) if err != nil { return 0, err } - bs, err := t.ConvertToTimespanImpl(b) + bs, err := t.ConvertToTimespan(b) if err != nil { return 0, err } - ai := as.AsMicroseconds() - bi := bs.AsMicroseconds() - - if ai < bi { - return -1, nil - } else if ai > bi { - return 1, nil - } - return 0, nil + return as.Compare(bs), nil } func (t timespanType) Convert(v interface{}) (interface{}, error) { @@ -91,11 +93,7 @@ func (t timespanType) Convert(v interface{}) (interface{}, error) { return nil, nil } - if ti, err := t.ConvertToTimespanImpl(v); err != nil { - return nil, err - } else { - return ti.String(), nil - } + return t.ConvertToTimespan(v) } // MustConvert implements the Type interface. @@ -107,38 +105,44 @@ func (t timespanType) MustConvert(v interface{}) interface{} { return value } -// Convert implements Type interface. -func (t timespanType) ConvertToTimespanImpl(v interface{}) (timespanImpl, error) { +// ConvertToTimespan converts the given interface value to a Timespan. This follows the conversion rules of MySQL, which +// are based on the base-10 visual representation of numbers (for example, Time.Convert() will interpret the value +// `1234` as 12 minutes and 34 seconds). Returns an error on a nil value. +func (t timespanType) ConvertToTimespan(v interface{}) (Timespan, error) { switch value := v.(type) { + case Timespan: + // We only create a Timespan if it's valid, so we can skip this check if we receive a Timespan. + // Timespan values are not intended to be modified by an integrator, therefore it is on the integrator if they corrupt a Timespan. + return value, nil case int: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case uint: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case int8: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case uint8: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case int16: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case uint16: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case int32: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case uint32: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case int64: absValue := int64Abs(value) if absValue >= -59 && absValue <= 59 { - return microsecondsToTimespan(value * microsecondsPerSecond), nil + return t.MicrosecondsToTimespan(value * microsecondsPerSecond), nil } else if absValue >= 100 && absValue <= 9999 { minutes := absValue / 100 seconds := absValue % 100 if minutes <= 59 && seconds <= 59 { microseconds := (seconds * microsecondsPerSecond) + (minutes * microsecondsPerMinute) if value < 0 { - return microsecondsToTimespan(-1 * microseconds), nil + return t.MicrosecondsToTimespan(-1 * microseconds), nil } - return microsecondsToTimespan(microseconds), nil + return t.MicrosecondsToTimespan(microseconds), nil } } else if absValue >= 10000 && absValue <= 9999999 { hours := absValue / 10000 @@ -147,15 +151,15 @@ func (t timespanType) ConvertToTimespanImpl(v interface{}) (timespanImpl, error) if minutes <= 59 && seconds <= 59 { microseconds := (seconds * microsecondsPerSecond) + (minutes * microsecondsPerMinute) + (hours * microsecondsPerHour) if value < 0 { - return microsecondsToTimespan(-1 * microseconds), nil + return t.MicrosecondsToTimespan(-1 * microseconds), nil } - return microsecondsToTimespan(microseconds), nil + return t.MicrosecondsToTimespan(microseconds), nil } } case uint64: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case float32: - return t.ConvertToTimespanImpl(float64(value)) + return t.ConvertToTimespan(float64(value)) case float64: intValue := int64(value) microseconds := int64Abs(int64(math.Round((value - float64(intValue)) * float64(microsecondsPerSecond)))) @@ -163,18 +167,18 @@ func (t timespanType) ConvertToTimespanImpl(v interface{}) (timespanImpl, error) if absValue >= -59 && absValue <= 59 { totalMicroseconds := (absValue * microsecondsPerSecond) + microseconds if value < 0 { - return microsecondsToTimespan(-1 * totalMicroseconds), nil + return t.MicrosecondsToTimespan(-1 * totalMicroseconds), nil } - return microsecondsToTimespan(totalMicroseconds), nil + return t.MicrosecondsToTimespan(totalMicroseconds), nil } else if absValue >= 100 && absValue <= 9999 { minutes := absValue / 100 seconds := absValue % 100 if minutes <= 59 && seconds <= 59 { totalMicroseconds := (seconds * microsecondsPerSecond) + (minutes * microsecondsPerMinute) + microseconds if value < 0 { - return microsecondsToTimespan(-1 * totalMicroseconds), nil + return t.MicrosecondsToTimespan(-1 * totalMicroseconds), nil } - return microsecondsToTimespan(totalMicroseconds), nil + return t.MicrosecondsToTimespan(totalMicroseconds), nil } } else if absValue >= 10000 && absValue <= 9999999 { hours := absValue / 10000 @@ -183,11 +187,17 @@ func (t timespanType) ConvertToTimespanImpl(v interface{}) (timespanImpl, error) if minutes <= 59 && seconds <= 59 { totalMicroseconds := (seconds * microsecondsPerSecond) + (minutes * microsecondsPerMinute) + (hours * microsecondsPerHour) + microseconds if value < 0 { - return microsecondsToTimespan(-1 * totalMicroseconds), nil + return t.MicrosecondsToTimespan(-1 * totalMicroseconds), nil } - return microsecondsToTimespan(totalMicroseconds), nil + return t.MicrosecondsToTimespan(totalMicroseconds), nil } } + case decimal.Decimal: + return t.ConvertToTimespan(value.IntPart()) + case decimal.NullDecimal: + if value.Valid { + return t.ConvertToTimespan(value.Decimal.IntPart()) + } case string: impl, err := stringToTimespan(value) if err == nil { @@ -196,26 +206,27 @@ func (t timespanType) ConvertToTimespanImpl(v interface{}) (timespanImpl, error) if strings.Contains(value, ".") { strAsDouble, err := strconv.ParseFloat(value, 64) if err != nil { - return timespanImpl{}, ErrConvertingToTimeType.New(v) + return Timespan(0), ErrConvertingToTimeType.New(v) } - return t.ConvertToTimespanImpl(strAsDouble) + return t.ConvertToTimespan(strAsDouble) } else { strAsInt, err := strconv.ParseInt(value, 10, 64) if err != nil { - return timespanImpl{}, ErrConvertingToTimeType.New(v) + return Timespan(0), ErrConvertingToTimeType.New(v) } - return t.ConvertToTimespanImpl(strAsInt) + return t.ConvertToTimespan(strAsInt) } case time.Duration: - microseconds := value.Nanoseconds() / 1000 - return microsecondsToTimespan(microseconds), nil + microseconds := value.Nanoseconds() / nanosecondsPerMicrosecond + return t.MicrosecondsToTimespan(microseconds), nil } - return timespanImpl{}, ErrConvertingToTimeType.New(v) + return Timespan(0), ErrConvertingToTimeType.New(v) } +// ConvertToTimeDuration implements the TimeType interface. func (t timespanType) ConvertToTimeDuration(v interface{}) (time.Duration, error) { - val, err := t.ConvertToTimespanImpl(v) + val, err := t.ConvertToTimespan(v) if err != nil { return time.Duration(0), err } @@ -235,12 +246,15 @@ func (t timespanType) Promote() Type { // SQL implements Type interface. func (t timespanType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { - ti, err := t.ConvertToTimespanImpl(v) + if v == nil { + return sqltypes.NULL, nil + } + ti, err := t.ConvertToTimespan(v) if err != nil { return sqltypes.Value{}, err } - val := appendAndSlice(dest, []byte(ti.String())) + val := appendAndSliceString(dest, ti.String()) return sqltypes.MakeTrusted(sqltypes.Time, val), nil } @@ -255,23 +269,14 @@ func (t timespanType) Type() query.Type { return sqltypes.Time } -// Zero implements Type interface. -func (t timespanType) Zero() interface{} { - return "00:00:00" -} - -// Marshal takes a valid Time value and returns it as an int64. -func (t timespanType) Marshal(v interface{}) (int64, error) { - if ti, err := t.ConvertToTimespanImpl(v); err != nil { - return 0, err - } else { - return ti.AsMicroseconds(), nil - } +// ValueType implements Type interface. +func (t timespanType) ValueType() reflect.Type { + return timeValueType } -// Unmarshal takes a previously-marshalled value and returns it as a string. -func (t timespanType) Unmarshal(v int64) string { - return microsecondsToTimespan(v).String() +// Zero implements Type interface. +func (t timespanType) Zero() interface{} { + return Timespan(0) } // No built in for absolute values on int64 @@ -280,10 +285,15 @@ func int64Abs(v int64) int64 { return (v ^ shift) - shift } -func stringToTimespan(s string) (timespanImpl, error) { - impl := timespanImpl{} +func stringToTimespan(s string) (Timespan, error) { + var negative bool + var hours int16 + var minutes int8 + var seconds int8 + var microseconds int32 + if len(s) > 0 && s[0] == '-' { - impl.negative = true + negative = true s = s[1:] } @@ -296,15 +306,15 @@ func stringToTimespan(s string) (timespanImpl, error) { microStr += strings.Repeat("0", 6-len(comps[1])) } microStr, remainStr := microStr[0:6], microStr[6:] - microseconds, err := strconv.Atoi(microStr) + convertedMicroseconds, err := strconv.Atoi(microStr) if err != nil { - return timespanImpl{}, ErrConvertingToTimeType.New(s) + return Timespan(0), ErrConvertingToTimeType.New(s) } // MySQL just uses the last digit to round up. This is weird, but matches their implementation. if len(remainStr) > 0 && remainStr[len(remainStr)-1:] >= "5" { - microseconds++ + convertedMicroseconds++ } - impl.microseconds = int32(microseconds) + microseconds = int32(convertedMicroseconds) } // Parse H-M-S time @@ -312,16 +322,16 @@ func stringToTimespan(s string) (timespanImpl, error) { hms := make([]string, 3) if len(hmsComps) >= 2 { if len(hmsComps[0]) > 3 { - return timespanImpl{}, ErrConvertingToTimeType.New(s) + return Timespan(0), ErrConvertingToTimeType.New(s) } hms[0] = hmsComps[0] if len(hmsComps[1]) > 2 { - return timespanImpl{}, ErrConvertingToTimeType.New(s) + return Timespan(0), ErrConvertingToTimeType.New(s) } hms[1] = hmsComps[1] if len(hmsComps) == 3 { if len(hmsComps[2]) > 2 { - return timespanImpl{}, ErrConvertingToTimeType.New(s) + return Timespan(0), ErrConvertingToTimeType.New(s) } hms[2] = hmsComps[2] } @@ -332,52 +342,52 @@ func stringToTimespan(s string) (timespanImpl, error) { hms[0] = safeSubstr(hmsComps[0], l-7, l-4) } - hours, err := strconv.Atoi(hms[0]) + hmsHours, err := strconv.Atoi(hms[0]) if len(hms[0]) > 0 && err != nil { - return timespanImpl{}, ErrConvertingToTimeType.New(s) + return Timespan(0), ErrConvertingToTimeType.New(s) } - impl.hours = int16(hours) + hours = int16(hmsHours) - minutes, err := strconv.Atoi(hms[1]) + hmsMinutes, err := strconv.Atoi(hms[1]) if len(hms[1]) > 0 && err != nil { - return timespanImpl{}, ErrConvertingToTimeType.New(s) - } else if minutes >= 60 { - return timespanImpl{}, ErrConvertingToTimeType.New(s) + return Timespan(0), ErrConvertingToTimeType.New(s) + } else if hmsMinutes >= 60 { + return Timespan(0), ErrConvertingToTimeType.New(s) } - impl.minutes = int8(minutes) + minutes = int8(hmsMinutes) - seconds, err := strconv.Atoi(hms[2]) + hmsSeconds, err := strconv.Atoi(hms[2]) if len(hms[2]) > 0 && err != nil { - return timespanImpl{}, ErrConvertingToTimeType.New(s) - } else if seconds >= 60 { - return timespanImpl{}, ErrConvertingToTimeType.New(s) + return Timespan(0), ErrConvertingToTimeType.New(s) + } else if hmsSeconds >= 60 { + return Timespan(0), ErrConvertingToTimeType.New(s) } - impl.seconds = int8(seconds) + seconds = int8(hmsSeconds) - if impl.microseconds == int32(microsecondsPerSecond) { - impl.microseconds = 0 - impl.seconds++ + if microseconds == int32(microsecondsPerSecond) { + microseconds = 0 + seconds++ } - if impl.seconds == 60 { - impl.seconds = 0 - impl.minutes++ + if seconds == 60 { + seconds = 0 + minutes++ } - if impl.minutes == 60 { - impl.minutes = 0 - impl.hours++ + if minutes == 60 { + minutes = 0 + hours++ } - if impl.hours > 838 { - impl.hours = 838 - impl.minutes = 59 - impl.seconds = 59 + if hours > 838 { + hours = 838 + minutes = 59 + seconds = 59 } - if impl.hours == 838 && impl.minutes == 59 && impl.seconds == 59 { - impl.microseconds = 0 + if hours == 838 && minutes == 59 && seconds == 59 { + microseconds = 0 } - return impl, nil + return unitsToTimespan(negative, hours, minutes, seconds, microseconds), nil } func safeSubstr(s string, start int, end int) string { @@ -396,46 +406,103 @@ func safeSubstr(s string, start int, end int) string { return s[start:end] } -func microsecondsToTimespan(v int64) timespanImpl { +// MicrosecondsToTimespan implements the TimeType interface. +func (_ timespanType) MicrosecondsToTimespan(v int64) Timespan { if v < timespanMinimum { v = timespanMinimum } else if v > timespanMaximum { v = timespanMaximum } + return Timespan(v) +} - absV := int64Abs(v) - - return timespanImpl{ - negative: v < 0, - hours: int16(absV / microsecondsPerHour), - minutes: int8((absV / microsecondsPerMinute) % 60), - seconds: int8((absV / microsecondsPerSecond) % 60), - microseconds: int32(absV % microsecondsPerSecond), +func unitsToTimespan(isNegative bool, hours int16, minutes int8, seconds int8, microseconds int32) Timespan { + negative := int64(1) + if isNegative { + negative = -1 } + return Timespan(negative * + (int64(microseconds) + + (int64(seconds) * microsecondsPerSecond) + + (int64(minutes) * microsecondsPerMinute) + + (int64(hours) * microsecondsPerHour))) } -func (t timespanImpl) String() string { +func (t Timespan) timespanToUnits() (isNegative bool, hours int16, minutes int8, seconds int8, microseconds int32) { + isNegative = t < 0 + absV := int64Abs(int64(t)) + hours = int16(absV / microsecondsPerHour) + minutes = int8((absV / microsecondsPerMinute) % 60) + seconds = int8((absV / microsecondsPerSecond) % 60) + microseconds = int32(absV % microsecondsPerSecond) + return +} + +// String returns the Timespan formatted as a string (such as for display purposes). +func (t Timespan) String() string { + isNegative, hours, minutes, seconds, microseconds := t.timespanToUnits() sign := "" - if t.negative { + if isNegative { sign = "-" } - if t.microseconds == 0 { - return fmt.Sprintf("%v%02d:%02d:%02d", sign, t.hours, t.minutes, t.seconds) + if microseconds == 0 { + return fmt.Sprintf("%v%02d:%02d:%02d", sign, hours, minutes, seconds) } - return fmt.Sprintf("%v%02d:%02d:%02d.%06d", sign, t.hours, t.minutes, t.seconds, t.microseconds) + return fmt.Sprintf("%v%02d:%02d:%02d.%06d", sign, hours, minutes, seconds, microseconds) } -func (t timespanImpl) AsMicroseconds() int64 { - negative := int64(1) - if t.negative { - negative = -1 - } - return negative * (int64(t.microseconds) + - (int64(t.seconds) * microsecondsPerSecond) + - (int64(t.minutes) * microsecondsPerMinute) + - (int64(t.hours) * microsecondsPerHour)) +// AsMicroseconds returns the Timespan in microseconds. +func (t Timespan) AsMicroseconds() int64 { + // Timespan already being implemented in microseconds is an implementation detail that integrators do not need to + // know about. This is also the reason for the comparison functions. + return int64(t) } -func (t timespanImpl) AsTimeDuration() time.Duration { +// AsTimeDuration returns the Timespan as a time.Duration. +func (t Timespan) AsTimeDuration() time.Duration { return time.Duration(t.AsMicroseconds() * nanosecondsPerMicrosecond) } + +// Equals returns whether the calling Timespan and given Timespan are equivalent. +func (t Timespan) Equals(other Timespan) bool { + return t == other +} + +// Compare returns an integer comparing two values. The result will be 0 if t==other, -1 if t < other, and +1 if t > other. +func (t Timespan) Compare(other Timespan) int { + if t < other { + return -1 + } else if t > other { + return 1 + } + return 0 +} + +// Negate returns a new Timespan that has been negated. +func (t Timespan) Negate() Timespan { + return -1 * t +} + +// Add returns a new Timespan that is the sum of the calling Timespan and given Timespan. The resulting Timespan is +// clamped to the allowed range. +func (t Timespan) Add(other Timespan) Timespan { + v := int64(t + other) + if v < timespanMinimum { + v = timespanMinimum + } else if v > timespanMaximum { + v = timespanMaximum + } + return Timespan(v) +} + +// Subtract returns a new Timespan that is the difference of the calling Timespan and given Timespan. The resulting +// Timespan is clamped to the allowed range. +func (t Timespan) Subtract(other Timespan) Timespan { + v := int64(t - other) + if v < timespanMinimum { + v = timespanMinimum + } else if v > timespanMaximum { + v = timespanMaximum + } + return Timespan(v) +} diff --git a/sql/timetype_test.go b/sql/timetype_test.go index 577d72c547..68fce4bb41 100644 --- a/sql/timetype_test.go +++ b/sql/timetype_test.go @@ -166,12 +166,16 @@ func TestTimeConvert(t *testing.T) { assert.Error(t, err) } else { require.NoError(t, err) - assert.Equal(t, test.expectedVal, val) - if test.val != nil { - mar, err := Time.Marshal(test.val) + if test.val == nil { + assert.Equal(t, test.expectedVal, val) + } else { + assert.Equal(t, test.expectedVal, val.(Timespan).String()) + timespan, err := Time.ConvertToTimespan(test.val) require.NoError(t, err) - umar := Time.Unmarshal(mar) - cmp, err := Time.Compare(test.val, umar) + require.True(t, timespan.Equals(val.(Timespan))) + ms := timespan.AsMicroseconds() + ums := Time.MicrosecondsToTimespan(ms) + cmp, err := Time.Compare(test.val, ums) require.NoError(t, err) assert.Equal(t, 0, cmp) } diff --git a/sql/tupletype.go b/sql/tupletype.go index 095a759f2b..aa31a67b1b 100644 --- a/sql/tupletype.go +++ b/sql/tupletype.go @@ -16,14 +16,19 @@ package sql import ( "fmt" + "reflect" "strings" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) +var tupleValueType = reflect.TypeOf((*[]interface{})(nil)).Elem() + type TupleType []Type +var _ Type = TupleType{nil} + // CreateTuple returns a new tuple type with the given element types. func CreateTuple(types ...Type) Type { return TupleType(types) @@ -124,6 +129,11 @@ func (t TupleType) Type() query.Type { return sqltypes.Expression } +// ValueType implements Type interface. +func (t TupleType) ValueType() reflect.Type { + return tupleValueType +} + func (t TupleType) Zero() interface{} { zeroes := make([]interface{}, len(t)) for i, tt := range t { diff --git a/sql/type.go b/sql/type.go index aeb6c9a05b..3134309ac2 100644 --- a/sql/type.go +++ b/sql/type.go @@ -17,6 +17,7 @@ package sql import ( "fmt" "io" + "reflect" "strconv" "strings" "time" @@ -66,6 +67,8 @@ type Type interface { SQL(dest []byte, v interface{}) (sqltypes.Value, error) // Type returns the query.Type for the given Type. Type() query.Type + // ValueType returns the Go type of the value returned by Convert(). + ValueType() reflect.Type // Zero returns the golang zero value for this type Zero() interface{} fmt.Stringer @@ -140,7 +143,7 @@ func ApproximateTypeFromValue(val interface{}) Type { return Uint16 case uint8: return Uint8 - case time.Duration: + case Timespan, time.Duration: return Time case time.Time: return Datetime diff --git a/sql/yeartype.go b/sql/yeartype.go index 3b1623e864..13fd3177d7 100644 --- a/sql/yeartype.go +++ b/sql/yeartype.go @@ -15,9 +15,12 @@ package sql import ( + "reflect" "strconv" "time" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" @@ -27,10 +30,13 @@ var ( Year YearType = yearType{} ErrConvertingToYear = errors.NewKind("value %v is not a valid Year") + + yearValueType = reflect.TypeOf(int16(0)) ) -// Represents the YEAR type. +// YearType represents the YEAR type. // https://dev.mysql.com/doc/refman/8.0/en/year.html +// The type of the returned value is int16. type YearType interface { Type } @@ -105,6 +111,13 @@ func (t yearType) Convert(v interface{}) (interface{}, error) { return t.Convert(int64(value)) case float64: return t.Convert(int64(value)) + case decimal.Decimal: + return t.Convert(value.IntPart()) + case decimal.NullDecimal: + if !value.Valid { + return nil, nil + } + return t.Convert(value.Decimal.IntPart()) case string: valueLength := len(value) if valueLength == 1 || valueLength == 2 || valueLength == 4 { @@ -175,6 +188,11 @@ func (t yearType) Type() query.Type { return sqltypes.Year } +// ValueType implements Type interface. +func (t yearType) ValueType() reflect.Type { + return yearValueType +} + // Zero implements Type interface. func (t yearType) Zero() interface{} { return int16(0) diff --git a/sql/yeartype_test.go b/sql/yeartype_test.go index 6fe6372b30..339878a874 100644 --- a/sql/yeartype_test.go +++ b/sql/yeartype_test.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "reflect" "testing" "time" @@ -95,6 +96,9 @@ func TestYearConvert(t *testing.T) { } else { require.NoError(t, err) assert.Equal(t, test.expectedVal, val) + if val != nil { + assert.Equal(t, Year.ValueType(), reflect.TypeOf(val)) + } } }) }