diff --git a/benchmark_test.go b/benchmark_test.go index 912e5414..7a6661ae 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -113,6 +113,47 @@ func benchmarkQueryHelper(b *testing.B, compr bool) { } } +func BenchmarkSelect10000rows(b *testing.B) { + db := initDB(b, false) + defer db.Close() + + // Check if we're using MariaDB + var version string + err := db.QueryRow("SELECT @@version").Scan(&version) + if err != nil { + b.Fatalf("Failed to get server version: %v", err) + } + + if !strings.Contains(strings.ToLower(version), "mariadb") { + b.Skip("Skipping benchmark as it requires MariaDB sequence table") + return + } + + b.StartTimer() + stmt, err := db.Prepare("SELECT * FROM seq_1_to_10000") + if err != nil { + b.Fatalf("Failed to prepare statement: %v", err) + } + defer stmt.Close() + for n := 0; n < b.N; n++ { + rows, err := stmt.Query() + if err != nil { + b.Fatalf("Failed to query 10000rows: %v", err) + } + + var id int64 + for rows.Next() { + err = rows.Scan(&id) + if err != nil { + rows.Close() + b.Fatalf("Failed to scan row: %v", err) + } + } + rows.Close() + } + b.StopTimer() +} + func BenchmarkExec(b *testing.B) { tb := (*TB)(b) b.StopTimer() diff --git a/compress_test.go b/compress_test.go index 030deaef..72696a4d 100644 --- a/compress_test.go +++ b/compress_test.go @@ -40,6 +40,8 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte) []by conn := new(mockConn) conn.data = compressedPacket mc.netConn = conn + mc.readNextFunc = mc.compIO.readNext + mc.readFunc = conn.Read uncompressedPacket, err := mc.readPacket() if err != nil { diff --git a/connection.go b/connection.go index 3e455a3f..64b5a502 100644 --- a/connection.go +++ b/connection.go @@ -39,6 +39,8 @@ type mysqlConn struct { compressSequence uint8 parseTime bool compress bool + readFunc func([]byte) (int, error) + readNextFunc func(int, readerFunc) ([]byte, error) // for context support (Go 1.8+) watching bool @@ -64,16 +66,6 @@ func (mc *mysqlConn) log(v ...any) { mc.cfg.Logger.Print(v...) } -func (mc *mysqlConn) readWithTimeout(b []byte) (int, error) { - to := mc.cfg.ReadTimeout - if to > 0 { - if err := mc.netConn.SetReadDeadline(time.Now().Add(to)); err != nil { - return 0, err - } - } - return mc.netConn.Read(b) -} - func (mc *mysqlConn) writeWithTimeout(b []byte) (int, error) { to := mc.cfg.WriteTimeout if to > 0 { @@ -247,7 +239,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin // can not take the buffer. Something must be wrong with the connection mc.cleanup() // interpolateParams would be called before sending any query. - // So its safe to retry. + // So it's safe to retry. return "", driver.ErrBadConn } buf = buf[:0] diff --git a/connection_test.go b/connection_test.go index 440ecbff..e1d3c79c 100644 --- a/connection_test.go +++ b/connection_test.go @@ -18,12 +18,17 @@ import ( ) func TestInterpolateParams(t *testing.T) { + buf := newBuffer() + nc := &net.TCPConn{} mc := &mysqlConn{ - buf: newBuffer(), + buf: buf, + netConn: nc, maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, }, + readNextFunc: buf.readNext, + readFunc: nc.Read, } q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"}) diff --git a/connector.go b/connector.go index bc1d46af..ea121923 100644 --- a/connector.go +++ b/connector.go @@ -16,6 +16,7 @@ import ( "os" "strconv" "strings" + "time" ) type connector struct { @@ -130,6 +131,22 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.buf = newBuffer() + // setting readNext/read functions + mc.readNextFunc = mc.buf.readNext + + // Initialize read function based on configuration + if mc.cfg.ReadTimeout > 0 { + mc.readFunc = func(b []byte) (int, error) { + deadline := time.Now().Add(mc.cfg.ReadTimeout) + if err := mc.netConn.SetReadDeadline(deadline); err != nil { + return 0, err + } + return mc.netConn.Read(b) + } + } else { + mc.readFunc = mc.netConn.Read + } + // Reading Handshake Initialization Packet authData, plugin, err := mc.readHandshakePacket() if err != nil { @@ -170,6 +187,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { if mc.cfg.compress && mc.flags&clientCompress == clientCompress { mc.compress = true mc.compIO = newCompIO(mc) + mc.readNextFunc = mc.compIO.readNext } if mc.cfg.MaxAllowedPacket > 0 { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket diff --git a/packets.go b/packets.go index a497a50a..105d1319 100644 --- a/packets.go +++ b/packets.go @@ -30,14 +30,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte invalidSequence := false - readNext := mc.buf.readNext - if mc.compress { - readNext = mc.compIO.readNext - } - for { // read packet header - data, err := readNext(4, mc.readWithTimeout) + data, err := mc.readNextFunc(4, mc.readFunc) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { @@ -85,7 +80,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // read packet body [pktLen bytes] - data, err = readNext(pktLen, mc.readWithTimeout) + data, err = mc.readNextFunc(pktLen, mc.readFunc) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { @@ -369,6 +364,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string return err } mc.netConn = tlsConn + mc.readFunc = mc.netConn.Read } // User [null terminated string] diff --git a/packets_test.go b/packets_test.go index 694b0564..71b071a8 100644 --- a/packets_test.go +++ b/packets_test.go @@ -97,24 +97,30 @@ var _ net.Conn = new(mockConn) func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { conn := new(mockConn) connector := newConnector(NewConfig()) + buf := newBuffer() mc := &mysqlConn{ - buf: newBuffer(), + buf: buf, cfg: connector.cfg, connector: connector, netConn: conn, closech: make(chan struct{}), maxAllowedPacket: defaultMaxAllowedPacket, sequence: sequence, + readNextFunc: buf.readNext, + readFunc: conn.Read, } return conn, mc } func TestReadPacketSingleByte(t *testing.T) { conn := new(mockConn) + buf := newBuffer() mc := &mysqlConn{ - netConn: conn, - buf: newBuffer(), - cfg: NewConfig(), + netConn: conn, + buf: buf, + cfg: NewConfig(), + readNextFunc: buf.readNext, + readFunc: conn.Read, } conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} @@ -165,10 +171,13 @@ func TestReadPacketWrongSequenceID(t *testing.T) { func TestReadPacketSplit(t *testing.T) { conn := new(mockConn) + buf := newBuffer() mc := &mysqlConn{ - netConn: conn, - buf: newBuffer(), - cfg: NewConfig(), + netConn: conn, + buf: buf, + cfg: NewConfig(), + readNextFunc: buf.readNext, + readFunc: conn.Read, } data := make([]byte, maxPacketSize*2+4*3) @@ -272,11 +281,14 @@ func TestReadPacketSplit(t *testing.T) { func TestReadPacketFail(t *testing.T) { conn := new(mockConn) + buf := newBuffer() mc := &mysqlConn{ - netConn: conn, - buf: newBuffer(), - closech: make(chan struct{}), - cfg: NewConfig(), + netConn: conn, + buf: buf, + closech: make(chan struct{}), + cfg: NewConfig(), + readNextFunc: buf.readNext, + readFunc: conn.Read, } // illegal empty (stand-alone) packet @@ -317,12 +329,15 @@ func TestReadPacketFail(t *testing.T) { // not-NUL terminated plugin_name in init packet func TestRegression801(t *testing.T) { conn := new(mockConn) + buf := newBuffer() mc := &mysqlConn{ - netConn: conn, - buf: newBuffer(), - cfg: new(Config), - sequence: 42, - closech: make(chan struct{}), + netConn: conn, + buf: buf, + cfg: new(Config), + sequence: 42, + closech: make(chan struct{}), + readNextFunc: buf.readNext, + readFunc: conn.Read, } conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0,