Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exec() now provides access to status of multiple statements. #1309

Merged
merged 5 commits into from May 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 16 additions & 0 deletions README.md
Expand Up @@ -305,6 +305,22 @@ Allow multiple statements in one query. This can be used to bach multiple querie

When `multiStatements` is used, `?` parameters must only be used in the first statement. [interpolateParams](#interpolateparams) can be used to avoid this limitation unless prepared statement is used explicitly.

It's possible to access the last inserted ID and number of affected rows for multiple statements by using `sql.Conn.Raw()` and the `mysql.Result`. For example:

```go
conn, _ := db.Conn(ctx)
conn.Raw(func(conn interface{}) error {
ex := conn.(driver.Execer)
res, err := ex.Exec(`
UPDATE point SET x = 1 WHERE y = 2;
UPDATE point SET x = 2 WHERE y = 3;
`, nil)
// Both slices have 2 elements.
log.Print(res.(mysql.Result).AllRowsAffected())
log.Print(res.(mysql.Result).AllLastInsertIds())
})
```

##### `parseTime`

```
Expand Down
6 changes: 3 additions & 3 deletions auth.go
Expand Up @@ -346,7 +346,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
case 1:
switch authData[0] {
case cachingSha2PasswordFastAuthSuccess:
if err = mc.readResultOK(); err == nil {
if err = mc.resultUnchanged().readResultOK(); err == nil {
return nil // auth successful
}

Expand Down Expand Up @@ -397,7 +397,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
return err
}
}
return mc.readResultOK()
return mc.resultUnchanged().readResultOK()

default:
return ErrMalformPkt
Expand Down Expand Up @@ -426,7 +426,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
if err != nil {
return err
}
return mc.readResultOK()
return mc.resultUnchanged().readResultOK()
}

default:
Expand Down
29 changes: 15 additions & 14 deletions connection.go
Expand Up @@ -23,9 +23,8 @@ import (
type mysqlConn struct {
buf buffer
netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection.
affectedRows uint64
insertId uint64
rawConn net.Conn // underlying connection when netConn is TLS connection.
result mysqlResult // managed by clearResult() and handleOkPacket().
cfg *Config
connector *connector
maxAllowedPacket int
Expand Down Expand Up @@ -155,6 +154,7 @@ func (mc *mysqlConn) cleanup() {
if err := mc.netConn.Close(); err != nil {
mc.cfg.Logger.Print(err)
}
mc.clearResult()
}

func (mc *mysqlConn) error() error {
Expand Down Expand Up @@ -316,28 +316,25 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
}
query = prepared
}
mc.affectedRows = 0
mc.insertId = 0

err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, err
copied := mc.result
return &copied, err
}
return nil, mc.markBadConn(err)
}

// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
handleOk := mc.clearResult()
// Send command
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
return mc.markBadConn(err)
}

// Read Result
resLen, err := mc.readResultSetHeaderPacket()
resLen, err := handleOk.readResultSetHeaderPacket()
if err != nil {
return err
}
Expand All @@ -354,14 +351,16 @@ func (mc *mysqlConn) exec(query string) error {
}
}

return mc.discardResults()
return handleOk.discardResults()
}

func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
return mc.query(query, args)
}

func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
handleOk := mc.clearResult()

if mc.closed.Load() {
mc.cfg.Logger.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
Expand All @@ -382,7 +381,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
resLen, err = handleOk.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
Expand Down Expand Up @@ -410,12 +409,13 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
// The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
// Send command
handleOk := mc.clearResult()
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
return nil, err
}

// Read Result
resLen, err := mc.readResultSetHeaderPacket()
resLen, err := handleOk.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
Expand Down Expand Up @@ -466,11 +466,12 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
}
defer mc.finish()

handleOk := mc.clearResult()
if err = mc.writeCommandPacket(comPing); err != nil {
return mc.markBadConn(err)
}

return mc.readResultOK()
return handleOk.readResultOK()
}

// BeginTx implements driver.ConnBeginTx interface
Expand Down
112 changes: 112 additions & 0 deletions driver_test.go
Expand Up @@ -2154,11 +2154,51 @@ func TestRejectReadOnly(t *testing.T) {
}

func TestPing(t *testing.T) {
ctx := context.Background()
runTests(t, dsn, func(dbt *DBTest) {
if err := dbt.db.Ping(); err != nil {
dbt.fail("Ping", "Ping", err)
}
})

runTests(t, dsn, func(dbt *DBTest) {
conn, err := dbt.db.Conn(ctx)
if err != nil {
dbt.fail("db", "Conn", err)
}

// Check that affectedRows and insertIds are cleared after each call.
conn.Raw(func(conn interface{}) error {
c := conn.(*mysqlConn)

// Issue a query that sets affectedRows and insertIds.
q, err := c.Query(`SELECT 1`, nil)
if err != nil {
dbt.fail("Conn", "Query", err)
}
if got, want := c.result.affectedRows, []int64{0}; !reflect.DeepEqual(got, want) {
dbt.Fatalf("bad affectedRows: got %v, want=%v", got, want)
}
if got, want := c.result.insertIds, []int64{0}; !reflect.DeepEqual(got, want) {
dbt.Fatalf("bad insertIds: got %v, want=%v", got, want)
}
q.Close()

// Verify that Ping() clears both fields.
for i := 0; i < 2; i++ {
if err := c.Ping(ctx); err != nil {
dbt.fail("Pinger", "Ping", err)
}
if got, want := c.result.affectedRows, []int64(nil); !reflect.DeepEqual(got, want) {
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
}
if got, want := c.result.insertIds, []int64(nil); !reflect.DeepEqual(got, want) {
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
}
}
return nil
})
})
}

// See Issue #799
Expand Down Expand Up @@ -2378,6 +2418,42 @@ func TestMultiResultSetNoSelect(t *testing.T) {
})
}

func TestExecMultipleResults(t *testing.T) {
ctx := context.Background()
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
dbt.mustExec(`
CREATE TABLE test (
id INT NOT NULL AUTO_INCREMENT,
value VARCHAR(255),
PRIMARY KEY (id)
)`)
conn, err := dbt.db.Conn(ctx)
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
conn.Raw(func(conn interface{}) error {
ex := conn.(driver.Execer)
res, err := ex.Exec(`
INSERT INTO test (value) VALUES ('a'), ('b');
INSERT INTO test (value) VALUES ('c'), ('d'), ('e');
`, nil)
if err != nil {
t.Fatalf("insert statements failed: %v", err)
}
mres := res.(Result)
if got, want := mres.AllRowsAffected(), []int64{2, 3}; !reflect.DeepEqual(got, want) {
t.Errorf("bad AllRowsAffected: got %v, want=%v", got, want)
}
// For INSERTs containing multiple rows, LAST_INSERT_ID() returns the
// first inserted ID, not the last.
if got, want := mres.AllLastInsertIds(), []int64{1, 3}; !reflect.DeepEqual(got, want) {
t.Errorf("bad AllLastInsertIds: got %v, want %v", got, want)
}
return nil
})
})
}

// tests if rows are set in a proper state if some results were ignored before
// calling rows.NextResultSet.
func TestSkipResults(t *testing.T) {
Expand All @@ -2399,6 +2475,42 @@ func TestSkipResults(t *testing.T) {
})
}

func TestQueryMultipleResults(t *testing.T) {
ctx := context.Background()
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
dbt.mustExec(`
CREATE TABLE test (
id INT NOT NULL AUTO_INCREMENT,
value VARCHAR(255),
PRIMARY KEY (id)
)`)
conn, err := dbt.db.Conn(ctx)
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
conn.Raw(func(conn interface{}) error {
qr := conn.(driver.Queryer)

c := conn.(*mysqlConn)

// Demonstrate that repeated queries reset the affectedRows
for i := 0; i < 2; i++ {
_, err := qr.Query(`
INSERT INTO test (value) VALUES ('a'), ('b');
INSERT INTO test (value) VALUES ('c'), ('d'), ('e');
`, nil)
if err != nil {
t.Fatalf("insert statements failed: %v", err)
}
if got, want := c.result.affectedRows, []int64{2, 3}; !reflect.DeepEqual(got, want) {
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
}
}
return nil
})
})
}

func TestPingContext(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
ctx, cancel := context.WithCancel(context.Background())
Expand Down
8 changes: 4 additions & 4 deletions infile.go
Expand Up @@ -93,7 +93,7 @@ func deferredClose(err *error, closer io.Closer) {

const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP

func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
func (mc *okHandler) handleInFileRequest(name string) (err error) {
var rdr io.Reader
var data []byte
packetSize := defaultPacketSize
Expand Down Expand Up @@ -154,7 +154,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
for err == nil {
n, err = rdr.Read(data[4:])
if n > 0 {
if ioErr := mc.writePacket(data[:4+n]); ioErr != nil {
if ioErr := mc.conn().writePacket(data[:4+n]); ioErr != nil {
return ioErr
}
}
Expand All @@ -168,7 +168,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
if data == nil {
data = make([]byte, 4)
}
if ioErr := mc.writePacket(data[:4]); ioErr != nil {
if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil {
return ioErr
}

Expand All @@ -177,6 +177,6 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
return mc.readResultOK()
}

mc.readPacket()
mc.conn().readPacket()
return err
}