Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make TimeTruncate functional option #1552

Merged
merged 7 commits into from Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion connection.go
Expand Up @@ -251,7 +251,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
buf = append(buf, "'0000-00-00'"...)
} else {
buf = append(buf, '\'')
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.TimeTruncate)
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil {
return "", err
}
Expand Down
47 changes: 40 additions & 7 deletions dsn.go
Expand Up @@ -34,6 +34,8 @@ var (
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
type Config struct {
// non boolean fields

User string // Username
Passwd string // Password (requires User)
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
Expand All @@ -45,15 +47,15 @@ type Config struct {
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
TimeTruncate time.Duration // Truncate time.Time values to the specified duration
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Logger Logger // Logger

// boolean fields

AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
AllowFallbackToPlaintext bool // Allows fallback to unencrypted connection if server does not support TLS
Expand All @@ -66,17 +68,48 @@ type Config struct {
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections

// unexported fields. new options should be come here

pubKey *rsa.PublicKey // Server public key
timeTruncate time.Duration // Truncate time.Time values to the specified duration
}

// Functional Options Pattern
// https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis
type Option func(*Config) error

// NewConfig creates a new Config and sets default values.
func NewConfig() *Config {
return &Config{
cfg := &Config{
Loc: time.UTC,
MaxAllowedPacket: defaultMaxAllowedPacket,
Logger: defaultLogger,
AllowNativePasswords: true,
CheckConnLiveness: true,
}

return cfg
}

// Apply applies the given options to the Config object.
func (c *Config) Apply(opts ...Option) error {
methane marked this conversation as resolved.
Show resolved Hide resolved
for _, opt := range opts {
err := opt(c)
if err != nil {
return err
}
}
return nil
}

// TimeTruncate sets the time duration to truncate time.Time values in
// query parameters.
func TimeTruncate(d time.Duration) Option {
return func(cfg *Config) error {
cfg.timeTruncate = d
return nil
}
}

func (cfg *Config) Clone() *Config {
Expand Down Expand Up @@ -263,8 +296,8 @@ func (cfg *Config) FormatDSN() string {
writeDSNParam(&buf, &hasParam, "parseTime", "true")
}

if cfg.TimeTruncate > 0 {
writeDSNParam(&buf, &hasParam, "timeTruncate", cfg.TimeTruncate.String())
if cfg.timeTruncate > 0 {
writeDSNParam(&buf, &hasParam, "timeTruncate", cfg.timeTruncate.String())
}

if cfg.ReadTimeout > 0 {
Expand Down Expand Up @@ -509,9 +542,9 @@ func parseDSNParams(cfg *Config, params string) (err error) {

// time.Time truncation
case "timeTruncate":
cfg.TimeTruncate, err = time.ParseDuration(value)
cfg.timeTruncate, err = time.ParseDuration(value)
methane marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return
return fmt.Errorf("invalid timeTruncate value: %v, error: %w", value, err)
}

// I/O read Timeout
Expand Down
2 changes: 1 addition & 1 deletion dsn_test.go
Expand Up @@ -76,7 +76,7 @@ var testDSNs = []struct {
&Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true},
}, {
"user:password@/dbname?loc=UTC&timeout=30s&parseTime=true&timeTruncate=1h",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, Timeout: 30 * time.Second, ParseTime: true, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TimeTruncate: time.Hour},
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, Timeout: 30 * time.Second, ParseTime: true, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, timeTruncate: time.Hour},
},
}

methane marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
2 changes: 1 addition & 1 deletion packets.go
Expand Up @@ -1172,7 +1172,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if v.IsZero() {
b = append(b, "0000-00-00"...)
} else {
b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.TimeTruncate)
b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil {
return err
}
Expand Down
5 changes: 2 additions & 3 deletions result.go
Expand Up @@ -15,9 +15,8 @@ import "database/sql/driver"
// This is accessible by executing statements using sql.Conn.Raw() and
// downcasting the returned result:
//
// res, err := rawConn.Exec(...)
// res.(mysql.Result).AllRowsAffected()
//
// res, err := rawConn.Exec(...)
methane marked this conversation as resolved.
Show resolved Hide resolved
// res.(mysql.Result).AllRowsAffected()
type Result interface {
driver.Result
// AllRowsAffected returns a slice containing the affected rows for each
Expand Down