diff --git a/connection.go b/connection.go index 565a5480a..10cdb5d34 100644 --- a/connection.go +++ b/connection.go @@ -163,16 +163,16 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { // Read Result columnCount, err := stmt.readPrepareResultPacket() - if err == nil { - if stmt.paramCount > 0 { - if err = mc.readUntilEOF(); err != nil { - return nil, err - } - } + if err != nil { + return stmt, err + } - if columnCount > 0 { - err = mc.readUntilEOF() - } + if err := mc.readPackets(stmt.paramCount); err != nil { + return nil, err + } + + if err := mc.readPackets(int(columnCount)); err != nil { + return nil, err } return stmt, err @@ -424,11 +424,8 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { rows.mc = mc rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} - if resLen > 0 { - // Columns - if err := mc.readUntilEOF(); err != nil { - return nil, err - } + if err := mc.readPackets(resLen); err != nil { + return nil, err } dest := make([]driver.Value, resLen) diff --git a/packets.go b/packets.go index cbed325f4..7d15e221c 100644 --- a/packets.go +++ b/packets.go @@ -224,10 +224,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro if len(data) > pos { // character set [1 byte] // status flags [2 bytes] + pos += 1 + 2 + // capability flags (upper 2 bytes) [2 bytes] + mc.flags += clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 + pos += 2 + // length of auth-plugin-data [1 byte] // reserved (all [00]) [10 bytes] - pos += 1 + 2 + 2 + 1 + 10 + pos += +1 + 10 // second part of the password cipher [mininum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) @@ -275,6 +280,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string clientLocalFiles | clientPluginAuth | clientMultiResults | + mc.flags&clientDeprecateEOF | mc.flags&clientLongFlag if mc.cfg.ClientFoundRows { @@ -599,18 +605,19 @@ func readStatus(b []byte) statusFlag { // Ok Packet // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet func (mc *mysqlConn) handleOkPacket(data []byte) error { - var n, m int - - // 0x00 [1 byte] - + // 0x00 or 0xFE [1 byte] + n := 1 + var l int // Affected rows [Length Coded Binary] - mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) + mc.affectedRows, _, l = readLengthEncodedInteger(data[n:]) + n += l // Insert id [Length Coded Binary] - mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) + mc.insertId, _, l = readLengthEncodedInteger(data[n:]) + n += l // server_status [2 bytes] - mc.status = readStatus(data[1+n+m : 1+n+m+2]) + mc.status = readStatus(data[n : n+2]) if mc.status&statusMoreResultsExists != 0 { return nil } @@ -620,19 +627,24 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { return nil } +// isEOFPacket will return true if the data is either a EOF-Packet or OK-Packet +// acting as an EOF. +func isEOFPacket(data []byte) bool { + return data[0] == iEOF && len(data) < 9 +} + // Read Packets as Field Packets until EOF-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { columns := make([]mysqlField, count) - for i := 0; ; i++ { + for i := 0; i < count; i++ { data, err := mc.readPacket() if err != nil { return nil, err } - // EOF Packet - if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { + if mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data) { if i == count { return columns, nil } @@ -718,9 +730,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) //} } + return columns, nil } -// Read Packets as Field Packets until EOF-Packet or an Error appears +// Read Packets as Field Packets until EOF/OK-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow func (rows *textRows) readRow(dest []driver.Value) error { mc := rows.mc @@ -735,9 +748,15 @@ func (rows *textRows) readRow(dest []driver.Value) error { } // EOF Packet - if data[0] == iEOF && len(data) == 5 { - // server_status [2 bytes] - rows.mc.status = readStatus(data[3:]) + if isEOFPacket(data) { + if mc.flags&clientDeprecateEOF == 0 { + // server_status [2 bytes] + rows.mc.status = readStatus(data[3:]) + } else { + if err := mc.handleOkPacket(data); err != nil { + return err + } + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil @@ -797,18 +816,44 @@ func (mc *mysqlConn) readUntilEOF() error { return err } - switch data[0] { - case iERR: + switch { + case data[0] == iERR: return mc.handleErrorPacket(data) - case iEOF: - if len(data) == 5 { + case isEOFPacket(data): + if mc.flags&clientDeprecateEOF == 0 { mc.status = readStatus(data[3:]) + } else { + return mc.handleOkPacket(data) } return nil } } } +func (mc *mysqlConn) readPackets(num int) error { + + // we need to read EOF as well + if mc.flags&clientDeprecateEOF == 0 { + num++ + } + + for i := 0; i < num; i++ { + data, err := mc.readPacket() + if err != nil { + return err + } + + switch { + case data[0] == iERR: + return mc.handleErrorPacket(data) + case mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data): + mc.status = readStatus(data[3:]) + return nil + } + } + return nil +} + /****************************************************************************** * Prepared Statements * ******************************************************************************/ @@ -1161,15 +1206,21 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // packet indicator [1 byte] if data[0] != iOK { - // EOF Packet - if data[0] == iEOF && len(data) == 5 { - rows.mc.status = readStatus(data[3:]) + if isEOFPacket(data) { + if rows.mc.flags&clientDeprecateEOF == 0 { + rows.mc.status = readStatus(data[3:]) + } else { + if err := rows.mc.handleOkPacket(data); err != nil { + return err + } + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil } return io.EOF } + mc := rows.mc rows.mc = nil