Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add connection attributes #1389

Merged
merged 9 commits into from May 23, 2023
9 changes: 9 additions & 0 deletions README.md
Expand Up @@ -393,6 +393,15 @@ Default: 0

I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.

##### `connectionAttributes`

```
Type: comma-delimited string of user-defined "key:value" pairs
Valid Values: (<name1>:<value1>,<name2>:<value2>,...)
Default: none
```

[Connection attributes](https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html) are key-value pairs that application programs can pass to the server at connect time.

##### System Variables

Expand Down
12 changes: 12 additions & 0 deletions const.go
Expand Up @@ -8,12 +8,24 @@

package mysql

import "runtime"

const (
defaultAuthPlugin = "mysql_native_password"
defaultMaxAllowedPacket = 4 << 20 // 4 MiB
minProtocolVersion = 10
maxPacketSize = 1<<24 - 1
timeFormat = "2006-01-02 15:04:05.999999"

// Connection attributes
// See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available
connAttrClientName = "_client_name"
methane marked this conversation as resolved.
Show resolved Hide resolved
connAttrClientNameValue = "GO-MySQL-Driver"
methane marked this conversation as resolved.
Show resolved Hide resolved
connAttrOS = "_os"
connAttrOSValue = runtime.GOOS
connAttrPlatform = "_platform"
connAttrPlatformValue = runtime.GOARCH
connAttrPid = "_pid"
)

// MySQL constants documentation:
Expand Down
55 changes: 55 additions & 0 deletions driver_test.go
Expand Up @@ -3209,3 +3209,58 @@ func TestConnectorTimeoutsWatchCancel(t *testing.T) {
t.Errorf("connection not closed")
}
}

func TestConnectionAttributes(t *testing.T) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
}

attr1 := "attr1"
value1 := "value1"
attr2 := "foo"
value2 := "boo"
dsn += fmt.Sprintf("&connectionAttributes=%s:%s,%s:%s", attr1, value1, attr2, value2)

var db *sql.DB
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
db, err = sql.Open("mysql", dsn)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
defer db.Close()
}

dbt := &DBTest{t, db}

var version string
if err := dbt.db.QueryRow("SELECT @@version").Scan(&version); err != nil {
dbt.Fatalf("%s", err.Error())
}
if strings.Contains(strings.ToLower(version), "mariadb") {
t.Skip(`TODO: Support adding connection attributes in MariaDB`)
}

var attrValue string
methane marked this conversation as resolved.
Show resolved Hide resolved
queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?"
rows := dbt.mustQuery(queryString, connAttrClientName)
if rows.Next() {
rows.Scan(&attrValue)
if attrValue != connAttrClientNameValue {
dbt.Errorf("expected %q, got %q", connAttrClientNameValue, attrValue)
}
} else {
dbt.Errorf("no data")
}
rows.Close()

rows = dbt.mustQuery(queryString, attr2)
if rows.Next() {
rows.Scan(&attrValue)
if attrValue != value2 {
dbt.Errorf("expected %q, got %q", value2, attrValue)
}
} else {
dbt.Errorf("no data")
}
rows.Close()
}
38 changes: 22 additions & 16 deletions dsn.go
Expand Up @@ -34,22 +34,23 @@ var (
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
type Config struct {
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs

AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
Expand Down Expand Up @@ -554,6 +555,11 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if err != nil {
return
}

// Connection attributes
case "connectionAttributes":
cfg.ConnectionAttributes = value

default:
// lazy init
if cfg.Params == nil {
Expand Down
35 changes: 35 additions & 0 deletions packets.go
Expand Up @@ -18,6 +18,9 @@ import (
"fmt"
"io"
"math"
"os"
"strconv"
"strings"
"time"
)

Expand Down Expand Up @@ -285,6 +288,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
clientLocalFiles |
clientPluginAuth |
clientMultiResults |
clientConnectAttrs |
mc.flags&clientLongFlag

if mc.cfg.ClientFoundRows {
Expand Down Expand Up @@ -318,6 +322,32 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
pktLen += n + 1
}

connAttrsBuf := make([]byte, 0, 100)

// default connection attributes
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid()))

// user-defined connection attributes
for _, connAttr := range strings.Split(mc.cfg.ConnectionAttributes, ",") {
attr := strings.Split(connAttr, ":")
if len(attr) != 2 {
continue
}
for _, v := range attr {
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v)
}
}

// 1 byte to store length of all key-values
pktLen += len(connAttrsBuf) + 1

// Calculate packet length and get buffer with that size
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
if err != nil {
Expand Down Expand Up @@ -394,6 +424,11 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
data[pos] = 0x00
pos++

// Connection Attributes
data[pos] = byte(len(connAttrsBuf))
pos++
pos += copy(data[pos:], connAttrsBuf)

// Send Auth packet
return mc.writePacket(data[:pos])
}
Expand Down
5 changes: 5 additions & 0 deletions utils.go
Expand Up @@ -616,6 +616,11 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte {
byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
}

func appendLengthEncodedString(b []byte, s string) []byte {
b = appendLengthEncodedInteger(b, uint64(len(s)))
return append(b, s...)
}

// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.
// If cap(buf) is not enough, reallocate new buffer.
func reserveBuffer(buf []byte, appendSize int) []byte {
Expand Down