Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BeforeConnect callback to configuration object #1469

Merged
merged 6 commits into from
Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ GitHub Inc.
Google Inc.
InfoSum Ltd.
Keybase Inc.
Microsoft Corp.
Multiplay Ltd.
Percona LLC
Pivotal Inc.
Expand Down
12 changes: 11 additions & 1 deletion connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,22 @@ func newConnector(cfg *Config) (*connector, error) {
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
var err error

// Invoke BeforeConnect if present, with a copy of the configuration
cfg := c.cfg
if c.cfg.BeforeConnect != nil {
cfg = c.cfg.Clone()
err = c.cfg.BeforeConnect(ctx, cfg)
if err != nil {
return nil, err
}
}

// New mysqlConn
mc := &mysqlConn{
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}),
cfg: c.cfg,
cfg: cfg,
connector: c,
}
mc.parseTime = mc.cfg.ParseTime
Expand Down
34 changes: 34 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1965,6 +1965,40 @@ func TestCustomDial(t *testing.T) {
}
}

func TestBeforeConnect(t *testing.T) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
}

// dbname is set in the BeforeConnect handle
cfg, err := ParseDSN(fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, "_"))
if err != nil {
t.Fatalf("error parsing DSN: %v", err)
}

cfg.BeforeConnect = func(ctx context.Context, c *Config) error {
c.DBName = dbname
return nil
}

connector, err := NewConnector(cfg)
if err != nil {
t.Fatalf("error creating connector: %v", err)
}

db := sql.OpenDB(connector)
defer db.Close()

var connectedDb string
err = db.QueryRow("SELECT DATABASE();").Scan(&connectedDb)
if err != nil {
t.Fatalf("error executing query: %v", err)
}
if connectedDb != dbname {
t.Fatalf("expected to connect to DB %s, but connected to %s instead", dbname, connectedDb)
}
}

func TestSQLInjection(t *testing.T) {
createTest := func(arg string) func(dbt *DBTest) {
return func(dbt *DBTest) {
Expand Down
3 changes: 3 additions & 0 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package mysql

import (
"bytes"
"context"
"crypto/rsa"
"crypto/tls"
"errors"
Expand Down Expand Up @@ -65,6 +66,8 @@ type Config struct {
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections

BeforeConnect func(context.Context, *Config) error // Invoked before a connection is established
ItalyPaleAle marked this conversation as resolved.
Show resolved Hide resolved
}

// NewConfig creates a new Config and sets default values.
Expand Down