Skip to content

Commit

Permalink
Make TimeTruncate functional option (#1552)
Browse files Browse the repository at this point in the history
  • Loading branch information
methane committed Mar 6, 2024
1 parent 097fe6e commit 6964272
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 13 deletions.
2 changes: 1 addition & 1 deletion connection.go
Expand Up @@ -251,7 +251,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
buf = append(buf, "'0000-00-00'"...)
} else {
buf = append(buf, '\'')
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.TimeTruncate)
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil {
return "", err
}
Expand Down
47 changes: 40 additions & 7 deletions dsn.go
Expand Up @@ -34,6 +34,8 @@ var (
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
type Config struct {
// non boolean fields

User string // Username
Passwd string // Password (requires User)
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
Expand All @@ -45,15 +47,15 @@ type Config struct {
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
TimeTruncate time.Duration // Truncate time.Time values to the specified duration
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Logger Logger // Logger

// boolean fields

AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
AllowFallbackToPlaintext bool // Allows fallback to unencrypted connection if server does not support TLS
Expand All @@ -66,17 +68,48 @@ type Config struct {
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections

// unexported fields. new options should be come here

pubKey *rsa.PublicKey // Server public key
timeTruncate time.Duration // Truncate time.Time values to the specified duration
}

// Functional Options Pattern
// https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis
type Option func(*Config) error

// NewConfig creates a new Config and sets default values.
func NewConfig() *Config {
return &Config{
cfg := &Config{
Loc: time.UTC,
MaxAllowedPacket: defaultMaxAllowedPacket,
Logger: defaultLogger,
AllowNativePasswords: true,
CheckConnLiveness: true,
}

return cfg
}

// Apply applies the given options to the Config object.
func (c *Config) Apply(opts ...Option) error {
for _, opt := range opts {
err := opt(c)
if err != nil {
return err
}
}
return nil
}

// TimeTruncate sets the time duration to truncate time.Time values in
// query parameters.
func TimeTruncate(d time.Duration) Option {
return func(cfg *Config) error {
cfg.timeTruncate = d
return nil
}
}

func (cfg *Config) Clone() *Config {
Expand Down Expand Up @@ -263,8 +296,8 @@ func (cfg *Config) FormatDSN() string {
writeDSNParam(&buf, &hasParam, "parseTime", "true")
}

if cfg.TimeTruncate > 0 {
writeDSNParam(&buf, &hasParam, "timeTruncate", cfg.TimeTruncate.String())
if cfg.timeTruncate > 0 {
writeDSNParam(&buf, &hasParam, "timeTruncate", cfg.timeTruncate.String())
}

if cfg.ReadTimeout > 0 {
Expand Down Expand Up @@ -509,9 +542,9 @@ func parseDSNParams(cfg *Config, params string) (err error) {

// time.Time truncation
case "timeTruncate":
cfg.TimeTruncate, err = time.ParseDuration(value)
cfg.timeTruncate, err = time.ParseDuration(value)
if err != nil {
return
return fmt.Errorf("invalid timeTruncate value: %v, error: %w", value, err)
}

// I/O read Timeout
Expand Down
2 changes: 1 addition & 1 deletion dsn_test.go
Expand Up @@ -76,7 +76,7 @@ var testDSNs = []struct {
&Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true},
}, {
"user:password@/dbname?loc=UTC&timeout=30s&parseTime=true&timeTruncate=1h",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, Timeout: 30 * time.Second, ParseTime: true, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TimeTruncate: time.Hour},
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, Timeout: 30 * time.Second, ParseTime: true, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, timeTruncate: time.Hour},
},
}

Expand Down
2 changes: 1 addition & 1 deletion packets.go
Expand Up @@ -1172,7 +1172,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if v.IsZero() {
b = append(b, "0000-00-00"...)
} else {
b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.TimeTruncate)
b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil {
return err
}
Expand Down
5 changes: 2 additions & 3 deletions result.go
Expand Up @@ -15,9 +15,8 @@ import "database/sql/driver"
// This is accessible by executing statements using sql.Conn.Raw() and
// downcasting the returned result:
//
// res, err := rawConn.Exec(...)
// res.(mysql.Result).AllRowsAffected()
//
// res, err := rawConn.Exec(...)
// res.(mysql.Result).AllRowsAffected()
type Result interface {
driver.Result
// AllRowsAffected returns a slice containing the affected rows for each
Expand Down

0 comments on commit 6964272

Please sign in to comment.