stdlib: allow nested database/sql/driver.Drivers

database/sql/driver.Driver implementations can be nested, with each
layer adding additional functionality. If pgx/stdlib.Driver is wrapped
in another driver.Driver implementation, AcquireConn will error,
detecting that the *sql.DB's driver is not (directly) pgx.Driver.

It looks like it should be possible to support the current functionality
without requiring that the top-level Driver be pgx/stdlib.Driver, but
it requires using a global map of fakeTxConns instead of a per-Driver
map of fakeTxConns.

Is this reasonable?
pull/410/head
Jackson Owens 2018-04-11 19:20:00 -07:00
parent 6556ef67cb
commit 14f1f2aa01
1 changed files with 19 additions and 19 deletions

View File

@ -100,8 +100,8 @@ var ErrNotPgx = errors.New("not pgx *sql.DB")
func init() { func init() {
pgxDriver = &Driver{ pgxDriver = &Driver{
configs: make(map[int64]*DriverConfig), configs: make(map[int64]*DriverConfig),
fakeTxConns: make(map[*pgx.Conn]*sql.Tx),
} }
fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
sql.Register("pgx", pgxDriver) sql.Register("pgx", pgxDriver)
databaseSqlOIDs = make(map[pgtype.OID]bool) databaseSqlOIDs = make(map[pgtype.OID]bool)
@ -120,13 +120,15 @@ func init() {
databaseSqlOIDs[pgtype.XIDOID] = true databaseSqlOIDs[pgtype.XIDOID] = true
} }
var (
fakeTxMutex sync.Mutex
fakeTxConns map[*pgx.Conn]*sql.Tx
)
type Driver struct { type Driver struct {
configMutex sync.Mutex configMutex sync.Mutex
configCount int64 configCount int64
configs map[int64]*DriverConfig configs map[int64]*DriverConfig
fakeTxMutex sync.Mutex
fakeTxConns map[*pgx.Conn]*sql.Tx
} }
func (d *Driver) Open(name string) (driver.Conn, error) { func (d *Driver) Open(name string) (driver.Conn, error) {
@ -571,21 +573,20 @@ func (fakeTx) Commit() error { return nil }
func (fakeTx) Rollback() error { return nil } func (fakeTx) Rollback() error { return nil }
func AcquireConn(db *sql.DB) (*pgx.Conn, error) { func AcquireConn(db *sql.DB) (*pgx.Conn, error) {
driver, ok := db.Driver().(*Driver)
if !ok {
return nil, ErrNotPgx
}
var conn *pgx.Conn var conn *pgx.Conn
ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn) ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn)
tx, err := db.BeginTx(ctx, nil) tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if conn == nil {
tx.Rollback()
return nil, ErrNotPgx
}
driver.fakeTxMutex.Lock() fakeTxMutex.Lock()
driver.fakeTxConns[conn] = tx fakeTxConns[conn] = tx
driver.fakeTxMutex.Unlock() fakeTxMutex.Unlock()
return conn, nil return conn, nil
} }
@ -594,14 +595,13 @@ func ReleaseConn(db *sql.DB, conn *pgx.Conn) error {
var tx *sql.Tx var tx *sql.Tx
var ok bool var ok bool
driver := db.Driver().(*Driver) fakeTxMutex.Lock()
driver.fakeTxMutex.Lock() tx, ok = fakeTxConns[conn]
tx, ok = driver.fakeTxConns[conn]
if ok { if ok {
delete(driver.fakeTxConns, conn) delete(fakeTxConns, conn)
driver.fakeTxMutex.Unlock() fakeTxMutex.Unlock()
} else { } else {
driver.fakeTxMutex.Unlock() fakeTxMutex.Unlock()
return errors.Errorf("can't release conn that is not acquired") return errors.Errorf("can't release conn that is not acquired")
} }