Skip to content

Commit

Permalink
feat: rm GetDBConnWithContext method (#6535)
Browse files Browse the repository at this point in the history
* feat: rm contextconnpool method

* feat: nil
  • Loading branch information
qqxhb committed Aug 19, 2023
1 parent bae684b commit fef4294
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 16 deletions.
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
9 changes: 6 additions & 3 deletions gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"reflect"
"sort"
"sync"
"time"
Expand Down Expand Up @@ -374,9 +375,11 @@ func (db *DB) AddError(err error) error {
// DB returns `*sql.DB`
func (db *DB) DB() (*sql.DB, error) {
connPool := db.ConnPool

if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil {
return connector.GetDBConnWithContext(db)
if db.Statement != nil && db.Statement.ConnPool != nil {
connPool = db.Statement.ConnPool
}
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
}

if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
Expand Down
6 changes: 0 additions & 6 deletions interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,6 @@ type GetDBConnector interface {
GetDBConn() (*sql.DB, error)
}

// GetDBConnectorWithContext represents SQL db connector which takes into
// account the current database context
type GetDBConnectorWithContext interface {
GetDBConnWithContext(db *DB) (*sql.DB, error)
}

// Rows rows interface
type Rows interface {
Columns() ([]string, error)
Expand Down
23 changes: 18 additions & 5 deletions prepare_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,11 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
}
}

func (db *PreparedStmtDB) GetDBConnWithContext(gormdb *DB) (*sql.DB, error) {
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil
}

if connector, ok := db.ConnPool.(GetDBConnectorWithContext); ok && connector != nil {
return connector.GetDBConnWithContext(gormdb)
}

if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
}
Expand Down Expand Up @@ -131,6 +127,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
}

beginner, ok := db.ConnPool.(ConnPoolBeginner)
if !ok {
return nil, ErrInvalidTransaction
}

connPool, err := beginner.BeginTx(ctx, opt)
if err != nil {
return nil, err
}
if tx, ok := connPool.(Tx); ok {
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
}
return nil, ErrInvalidTransaction
}

Expand Down Expand Up @@ -176,6 +185,10 @@ type PreparedStmtTX struct {
PreparedStmtDB *PreparedStmtDB
}

func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
return db.PreparedStmtDB.GetDBConn()
}

func (tx *PreparedStmtTX) Commit() error {
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
return tx.Tx.Commit()
Expand Down

0 comments on commit fef4294

Please sign in to comment.