Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,15 @@ Default: 0

I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.

##### `connectionAttributes`

```
Type: comma-delimited string of user-defined "key:value" pairs
Valid Values: (<name1>:<value1>,<name2>:<value2>,...)
Default: none
```

[Connection attributes](https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html) are key-value pairs that application programs can pass to the server at connect time.

##### System Variables

Expand Down
11 changes: 11 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,23 @@

package mysql

import "runtime"

const (
defaultAuthPlugin = "mysql_native_password"
defaultMaxAllowedPacket = 4 << 20 // 4 MiB
minProtocolVersion = 10
maxPacketSize = 1<<24 - 1
timeFormat = "2006-01-02 15:04:05.999999"

// Connection attributes
// See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available
connAttrClientName = "_client_name"
connAttrClientNameValue = "GO-MySQL-Driver"
connAttrOS = "_os"
connAttrOSValue = runtime.GOOS
connAttrPlatform = "_platform"
connAttrPlatformValue = runtime.GOARCH
)

// MySQL constants documentation:
Expand Down
55 changes: 55 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3209,3 +3209,58 @@ func TestConnectorTimeoutsWatchCancel(t *testing.T) {
t.Errorf("connection not closed")
}
}

func TestConnectionAttributes(t *testing.T) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
}

attr1 := "attr1"
value1 := "value1"
attr2 := "foo"
value2 := "boo"
dsn += fmt.Sprintf("&connectionAttributes=%s:%s,%s:%s", attr1, value1, attr2, value2)

var db *sql.DB
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
db, err = sql.Open("mysql", dsn)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
defer db.Close()
}

dbt := &DBTest{t, db}

var version string
if err := dbt.db.QueryRow("SELECT @@version").Scan(&version); err != nil {
dbt.Fatalf("%s", err.Error())
}
if strings.Contains(strings.ToLower(version), "mariadb") {
t.Skip(`TODO: Support adding connection attributes in MariaDB`)
}

var attrValue string
queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?"
rows := dbt.mustQuery(queryString, connAttrClientName)
if rows.Next() {
rows.Scan(&attrValue)
if attrValue != connAttrClientNameValue {
dbt.Errorf("expected %q, got %q", connAttrClientNameValue, attrValue)
}
} else {
dbt.Errorf("no data")
}
rows.Close()

rows = dbt.mustQuery(queryString, attr2)
if rows.Next() {
rows.Scan(&attrValue)
if attrValue != value2 {
dbt.Errorf("expected %q, got %q", value2, attrValue)
}
} else {
dbt.Errorf("no data")
}
rows.Close()
}
38 changes: 22 additions & 16 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,23 @@ var (
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
type Config struct {
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs

AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
Expand Down Expand Up @@ -554,6 +555,11 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if err != nil {
return
}

// Connection attributes
case "connectionAttributes":
cfg.ConnectionAttributes = value

default:
// lazy init
if cfg.Params == nil {
Expand Down
31 changes: 31 additions & 0 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"fmt"
"io"
"math"
"strings"
"time"
)

Expand Down Expand Up @@ -285,6 +286,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
clientLocalFiles |
clientPluginAuth |
clientMultiResults |
clientConnectAttrs |
mc.flags&clientLongFlag

if mc.cfg.ClientFoundRows {
Expand Down Expand Up @@ -318,6 +320,30 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
pktLen += n + 1
}

connAttrsBuf := make([]byte, 0, 100)

// default connection attributes
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue)

// user-defined connection attributes
for _, connAttr := range strings.Split(mc.cfg.ConnectionAttributes, ",") {
attr := strings.Split(connAttr, ":")
if len(attr) != 2 {
continue
}
for _, v := range attr {
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v)
}
}

// 1 byte to store length of all key-values
pktLen += len(connAttrsBuf) + 1

// Calculate packet length and get buffer with that size
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
if err != nil {
Expand Down Expand Up @@ -394,6 +420,11 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
data[pos] = 0x00
pos++

// Connection Attributes
data[pos] = byte(len(connAttrsBuf))
pos++
pos += copy(data[pos:], connAttrsBuf)

// Send Auth packet
return mc.writePacket(data[:pos])
}
Expand Down
5 changes: 5 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,11 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte {
byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
}

func appendLengthEncodedString(b []byte, s string) []byte {
b = appendLengthEncodedInteger(b, uint64(len(s)))
return append(b, s...)
}

// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.
// If cap(buf) is not enough, reallocate new buffer.
func reserveBuffer(buf []byte, appendSize int) []byte {
Expand Down