Skip to content

Commit

Permalink
Send connection attributes (go-sql-driver#1389)
Browse files Browse the repository at this point in the history
Co-authored-by: Inada Naoki <songofacandy@gmail.com>
  • Loading branch information
2 people authored and oblitorum committed Oct 30, 2023
1 parent f20b286 commit e8a6f76
Show file tree
Hide file tree
Showing 12 changed files with 173 additions and 27 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Expand Up @@ -84,6 +84,7 @@ jobs:
; TestConcurrent fails if max_connections is too large
max_connections=50
local_infile=1
performance_schema=on
- name: setup database
run: |
mysql --user 'root' --host '127.0.0.1' -e 'create database gotest;'
Expand Down
9 changes: 9 additions & 0 deletions README.md
Expand Up @@ -393,6 +393,15 @@ Default: 0

I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.

##### `connectionAttributes`

```
Type: comma-delimited string of user-defined "key:value" pairs
Valid Values: (<name1>:<value1>,<name2>:<value2>,...)
Default: none
```

[Connection attributes](https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html) are key-value pairs that application programs can pass to the server at connect time.

##### System Variables

Expand Down
1 change: 1 addition & 0 deletions connection.go
Expand Up @@ -27,6 +27,7 @@ type mysqlConn struct {
affectedRows uint64
insertId uint64
cfg *Config
connector *connector
maxAllowedPacket int
maxWriteSize int
writeTimeout time.Duration
Expand Down
46 changes: 45 additions & 1 deletion connector.go
Expand Up @@ -11,11 +11,54 @@ package mysql
import (
"context"
"database/sql/driver"
"fmt"
"net"
"os"
"strconv"
"strings"
)

type connector struct {
cfg *Config // immutable private copy.
cfg *Config // immutable private copy.
encodedAttributes string // Encoded connection attributes.
}

func encodeConnectionAttributes(textAttributes string) string {
connAttrsBuf := make([]byte, 0, 251)

// default connection attributes
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid()))

// user-defined connection attributes
for _, connAttr := range strings.Split(textAttributes, ",") {
attr := strings.SplitN(connAttr, ":", 2)
if len(attr) != 2 {
continue
}
for _, v := range attr {
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v)
}
}

return string(connAttrsBuf)
}

func newConnector(cfg *Config) (*connector, error) {
encodedAttributes := encodeConnectionAttributes(cfg.ConnectionAttributes)
if len(encodedAttributes) > 250 {
return nil, fmt.Errorf("connection attributes are longer than 250 bytes: %dbytes (%q)", len(encodedAttributes), cfg.ConnectionAttributes)
}
return &connector{
cfg: cfg,
encodedAttributes: encodedAttributes,
}, nil
}

// Connect implements driver.Connector interface.
Expand All @@ -29,6 +72,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}),
cfg: c.cfg,
connector: c,
}
mc.parseTime = mc.cfg.ParseTime

Expand Down
9 changes: 6 additions & 3 deletions connector_test.go
Expand Up @@ -8,13 +8,16 @@ import (
)

func TestConnectorReturnsTimeout(t *testing.T) {
connector := &connector{&Config{
connector, err := newConnector(&Config{
Net: "tcp",
Addr: "1.1.1.1:1234",
Timeout: 10 * time.Millisecond,
}}
})
if err != nil {
t.Fatal(err)
}

_, err := connector.Connect(context.Background())
_, err = connector.Connect(context.Background())
if err == nil {
t.Fatal("error expected")
}
Expand Down
12 changes: 12 additions & 0 deletions const.go
Expand Up @@ -8,12 +8,24 @@

package mysql

import "runtime"

const (
defaultAuthPlugin = "mysql_native_password"
defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355
minProtocolVersion = 10
maxPacketSize = 1<<24 - 1
timeFormat = "2006-01-02 15:04:05.999999"

// Connection attributes
// See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available
connAttrClientName = "_client_name"
connAttrClientNameValue = "Go-MySQL-Driver"
connAttrOS = "_os"
connAttrOSValue = runtime.GOOS
connAttrPlatform = "_platform"
connAttrPlatformValue = runtime.GOARCH
connAttrPid = "_pid"
)

// MySQL constants documentation:
Expand Down
11 changes: 5 additions & 6 deletions driver.go
Expand Up @@ -74,8 +74,9 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
if err != nil {
return nil, err
}
c := &connector{
cfg: cfg,
c, err := newConnector(cfg)
if err != nil {
return nil, err
}
return c.Connect(context.Background())
}
Expand All @@ -92,7 +93,7 @@ func NewConnector(cfg *Config) (driver.Connector, error) {
if err := cfg.normalize(); err != nil {
return nil, err
}
return &connector{cfg: cfg}, nil
return newConnector(cfg)
}

// OpenConnector implements driver.DriverContext.
Expand All @@ -101,7 +102,5 @@ func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
if err != nil {
return nil, err
}
return &connector{
cfg: cfg,
}, nil
return newConnector(cfg)
}
47 changes: 47 additions & 0 deletions driver_test.go
Expand Up @@ -3209,3 +3209,50 @@ func TestConnectorTimeoutsWatchCancel(t *testing.T) {
t.Errorf("connection not closed")
}
}

func TestConnectionAttributes(t *testing.T) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
}

attr1 := "attr1"
value1 := "value1"
attr2 := "foo"
value2 := "boo"
dsn += fmt.Sprintf("&connectionAttributes=%s:%s,%s:%s", attr1, value1, attr2, value2)

var db *sql.DB
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
db, err = sql.Open("mysql", dsn)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
defer db.Close()
}

dbt := &DBTest{t, db}

var attrValue string
queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?"
rows := dbt.mustQuery(queryString, connAttrClientName)
if rows.Next() {
rows.Scan(&attrValue)
if attrValue != connAttrClientNameValue {
dbt.Errorf("expected %q, got %q", connAttrClientNameValue, attrValue)
}
} else {
dbt.Errorf("no data")
}
rows.Close()

rows = dbt.mustQuery(queryString, attr2)
if rows.Next() {
rows.Scan(&attrValue)
if attrValue != value2 {
dbt.Errorf("expected %q, got %q", value2, attrValue)
}
} else {
dbt.Errorf("no data")
}
rows.Close()
}
39 changes: 23 additions & 16 deletions dsn.go
Expand Up @@ -34,22 +34,24 @@ var (
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
type Config struct {
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Logger Logger // Logger

AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
Expand Down Expand Up @@ -554,6 +556,11 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if err != nil {
return
}

// Connection attributes
case "connectionAttributes":
cfg.ConnectionAttributes = value

default:
// lazy init
if cfg.Params == nil {
Expand Down
13 changes: 13 additions & 0 deletions packets.go
Expand Up @@ -285,6 +285,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
clientLocalFiles |
clientPluginAuth |
clientMultiResults |
clientConnectAttrs |
mc.flags&clientLongFlag

if mc.cfg.ClientFoundRows {
Expand Down Expand Up @@ -318,6 +319,13 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
pktLen += n + 1
}

// 1 byte to store length of all key-values
// NOTE: Actually, this is length encoded integer.
// But we support only len(connAttrBuf) < 251 for now because takeSmallBuffer
// doesn't support buffer size more than 4096 bytes.
// TODO(methane): Rewrite buffer management.
pktLen += 1 + len(mc.connector.encodedAttributes)

// Calculate packet length and get buffer with that size
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
if err != nil {
Expand Down Expand Up @@ -394,6 +402,11 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
data[pos] = 0x00
pos++

// Connection Attributes
data[pos] = byte(len(mc.connector.encodedAttributes))
pos++
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))

// Send Auth packet
return mc.writePacket(data[:pos])
}
Expand Down
7 changes: 6 additions & 1 deletion packets_test.go
Expand Up @@ -96,9 +96,14 @@ var _ net.Conn = new(mockConn)

func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
conn := new(mockConn)
connector, err := newConnector(NewConfig())
if err != nil {
panic(err)
}
mc := &mysqlConn{
buf: newBuffer(conn),
cfg: NewConfig(),
cfg: connector.cfg,
connector: connector,
netConn: conn,
closech: make(chan struct{}),
maxAllowedPacket: defaultMaxAllowedPacket,
Expand Down
5 changes: 5 additions & 0 deletions utils.go
Expand Up @@ -616,6 +616,11 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte {
byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
}

func appendLengthEncodedString(b []byte, s string) []byte {
b = appendLengthEncodedInteger(b, uint64(len(s)))
return append(b, s...)
}

// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.
// If cap(buf) is not enough, reallocate new buffer.
func reserveBuffer(buf []byte, appendSize int) []byte {
Expand Down

0 comments on commit e8a6f76

Please sign in to comment.