Skip to content

Commit

Permalink
Use PathEscape for dbname in DSN. (#1432)
Browse files Browse the repository at this point in the history
Support for slashes in database names via url escape codes.
On the other hand, '%' in DSN is now treated as percent-encoding.

Co-authored-by: Brian Hendriks <brian@dolthub.com>
  • Loading branch information
methane and bheni committed May 25, 2023
1 parent 924f833 commit d3e4fe6
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 31 deletions.
2 changes: 2 additions & 0 deletions AUTHORS
Expand Up @@ -110,6 +110,7 @@ Xuehong Chan <chanxuehong at gmail.com>
Zhenye Xie <xiezhenye at gmail.com>
Zhixin Wen <john.wenzhixin at gmail.com>
Ziheng Lyu <zihenglv at gmail.com>
Brian Hendriks <brian at dolthub.com>

# Organizations

Expand All @@ -127,3 +128,4 @@ Percona LLC
Pivotal Inc.
Stripe Inc.
Zendesk Inc.
Dolthub Inc.
6 changes: 6 additions & 0 deletions README.md
Expand Up @@ -114,6 +114,12 @@ This has the same effect as an empty DSN string:
```

`dbname` is escaped by [PathEscape()]()https://pkg.go.dev/net/url#PathEscape) since v1.8.0. If your database name is `dbname/withslash`, it becomes:

```
/dbname%2Fwithslash
```

Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mysql#Config.FormatDSN) can be used to create a DSN string by filling a struct.

#### Password
Expand Down
8 changes: 6 additions & 2 deletions dsn.go
Expand Up @@ -203,7 +203,7 @@ func (cfg *Config) FormatDSN() string {

// /dbname
buf.WriteByte('/')
buf.WriteString(cfg.DBName)
buf.WriteString(url.PathEscape(cfg.DBName))

// [?param1=value1&...&paramN=valueN]
hasParam := false
Expand Down Expand Up @@ -365,7 +365,11 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
break
}
}
cfg.DBName = dsn[i+1 : j]

dbname := dsn[i+1 : j]
if cfg.DBName, err = url.PathUnescape(dbname); err != nil {
return nil, fmt.Errorf("invalid dbname %q: %w", dbname, err)
}

break
}
Expand Down
66 changes: 37 additions & 29 deletions dsn_test.go
Expand Up @@ -50,6 +50,9 @@ var testDSNs = []struct {
}, {
"/dbname",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true},
}, {
"/dbname%2Fwithslash",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname/withslash", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true},
}, {
"@/",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true},
Expand All @@ -76,17 +79,20 @@ var testDSNs = []struct {

func TestDSNParser(t *testing.T) {
for i, tst := range testDSNs {
cfg, err := ParseDSN(tst.in)
if err != nil {
t.Error(err.Error())
}
t.Run(tst.in, func(t *testing.T) {
cfg, err := ParseDSN(tst.in)
if err != nil {
t.Error(err.Error())
return
}

// pointer not static
cfg.TLS = nil
// pointer not static
cfg.TLS = nil

if !reflect.DeepEqual(cfg, tst.out) {
t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out)
}
if !reflect.DeepEqual(cfg, tst.out) {
t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out)
}
})
}
}

Expand All @@ -113,27 +119,29 @@ func TestDSNParserInvalid(t *testing.T) {

func TestDSNReformat(t *testing.T) {
for i, tst := range testDSNs {
dsn1 := tst.in
cfg1, err := ParseDSN(dsn1)
if err != nil {
t.Error(err.Error())
continue
}
cfg1.TLS = nil // pointer not static
res1 := fmt.Sprintf("%+v", cfg1)

dsn2 := cfg1.FormatDSN()
cfg2, err := ParseDSN(dsn2)
if err != nil {
t.Error(err.Error())
continue
}
cfg2.TLS = nil // pointer not static
res2 := fmt.Sprintf("%+v", cfg2)
t.Run(tst.in, func(t *testing.T) {
dsn1 := tst.in
cfg1, err := ParseDSN(dsn1)
if err != nil {
t.Error(err.Error())
return
}
cfg1.TLS = nil // pointer not static
res1 := fmt.Sprintf("%+v", cfg1)

if res1 != res2 {
t.Errorf("%d. %q does not match %q", i, res2, res1)
}
dsn2 := cfg1.FormatDSN()
cfg2, err := ParseDSN(dsn2)
if err != nil {
t.Error(err.Error())
return
}
cfg2.TLS = nil // pointer not static
res2 := fmt.Sprintf("%+v", cfg2)

if res1 != res2 {
t.Errorf("%d. %q does not match %q", i, res2, res1)
}
})
}
}

Expand Down

0 comments on commit d3e4fe6

Please sign in to comment.