Skip to content

WIP: Add support for OK packets representing EOF #962

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,16 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {

// Read Result
columnCount, err := stmt.readPrepareResultPacket()
if err == nil {
if stmt.paramCount > 0 {
if err = mc.readUntilEOF(); err != nil {
return nil, err
}
}
if err != nil {
return stmt, err
}

if columnCount > 0 {
err = mc.readUntilEOF()
}
if err := mc.readPackets(stmt.paramCount); err != nil {
return nil, err
}

if err := mc.readPackets(int(columnCount)); err != nil {
return nil, err
}

return stmt, err
Expand Down Expand Up @@ -424,11 +424,8 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
rows.mc = mc
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}

if resLen > 0 {
// Columns
if err := mc.readUntilEOF(); err != nil {
return nil, err
}
if err := mc.readPackets(resLen); err != nil {
return nil, err
}

dest := make([]driver.Value, resLen)
Expand Down
95 changes: 73 additions & 22 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
if len(data) > pos {
// character set [1 byte]
// status flags [2 bytes]
pos += 1 + 2

// capability flags (upper 2 bytes) [2 bytes]
mc.flags += clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
pos += 2

// length of auth-plugin-data [1 byte]
// reserved (all [00]) [10 bytes]
pos += 1 + 2 + 2 + 1 + 10
pos += +1 + 10

// second part of the password cipher [mininum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8)
Expand Down Expand Up @@ -275,6 +280,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
clientLocalFiles |
clientPluginAuth |
clientMultiResults |
mc.flags&clientDeprecateEOF |
mc.flags&clientLongFlag

if mc.cfg.ClientFoundRows {
Expand Down Expand Up @@ -599,18 +605,19 @@ func readStatus(b []byte) statusFlag {
// Ok Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
func (mc *mysqlConn) handleOkPacket(data []byte) error {
var n, m int

// 0x00 [1 byte]

// 0x00 or 0xFE [1 byte]
n := 1
var l int
// Affected rows [Length Coded Binary]
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
mc.affectedRows, _, l = readLengthEncodedInteger(data[n:])
n += l

// Insert id [Length Coded Binary]
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
mc.insertId, _, l = readLengthEncodedInteger(data[n:])
n += l

// server_status [2 bytes]
mc.status = readStatus(data[1+n+m : 1+n+m+2])
mc.status = readStatus(data[n : n+2])
if mc.status&statusMoreResultsExists != 0 {
return nil
}
Expand All @@ -620,19 +627,24 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
return nil
}

// isEOFPacket will return true if the data is either a EOF-Packet or OK-Packet
// acting as an EOF.
func isEOFPacket(data []byte) bool {
return data[0] == iEOF && len(data) < 9
}

// Read Packets as Field Packets until EOF-Packet or an Error appears
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
columns := make([]mysqlField, count)

for i := 0; ; i++ {
for i := 0; i < count; i++ {
data, err := mc.readPacket()
if err != nil {
return nil, err
}

// EOF Packet
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
if mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data) {
if i == count {
return columns, nil
}
Expand Down Expand Up @@ -718,9 +730,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
// defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
//}
}
return columns, nil
}

// Read Packets as Field Packets until EOF-Packet or an Error appears
// Read Packets as Field Packets until EOF/OK-Packet or an Error appears
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
func (rows *textRows) readRow(dest []driver.Value) error {
mc := rows.mc
Expand All @@ -735,9 +748,15 @@ func (rows *textRows) readRow(dest []driver.Value) error {
}

// EOF Packet
if data[0] == iEOF && len(data) == 5 {
// server_status [2 bytes]
rows.mc.status = readStatus(data[3:])
if isEOFPacket(data) {
if mc.flags&clientDeprecateEOF == 0 {
// server_status [2 bytes]
rows.mc.status = readStatus(data[3:])
} else {
if err := mc.handleOkPacket(data); err != nil {
return err
}
}
rows.rs.done = true
if !rows.HasNextResultSet() {
rows.mc = nil
Expand Down Expand Up @@ -797,18 +816,44 @@ func (mc *mysqlConn) readUntilEOF() error {
return err
}

switch data[0] {
case iERR:
switch {
case data[0] == iERR:
return mc.handleErrorPacket(data)
case iEOF:
if len(data) == 5 {
case isEOFPacket(data):
if mc.flags&clientDeprecateEOF == 0 {
mc.status = readStatus(data[3:])
} else {
return mc.handleOkPacket(data)
}
return nil
}
}
}

func (mc *mysqlConn) readPackets(num int) error {

// we need to read EOF as well
if mc.flags&clientDeprecateEOF == 0 {
num++
}

for i := 0; i < num; i++ {
data, err := mc.readPacket()
if err != nil {
return err
}

switch {
case data[0] == iERR:
return mc.handleErrorPacket(data)
case mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data):
mc.status = readStatus(data[3:])
return nil
}
}
return nil
}

/******************************************************************************
* Prepared Statements *
******************************************************************************/
Expand Down Expand Up @@ -1161,15 +1206,21 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {

// packet indicator [1 byte]
if data[0] != iOK {
// EOF Packet
if data[0] == iEOF && len(data) == 5 {
rows.mc.status = readStatus(data[3:])
if isEOFPacket(data) {
if rows.mc.flags&clientDeprecateEOF == 0 {
rows.mc.status = readStatus(data[3:])
} else {
if err := rows.mc.handleOkPacket(data); err != nil {
return err
}
}
rows.rs.done = true
if !rows.HasNextResultSet() {
rows.mc = nil
}
return io.EOF
}

mc := rows.mc
rows.mc = nil

Expand Down