mirror of https://github.com/jackc/pgx.git
stdlib: allow nested database/sql/driver.Drivers
database/sql/driver.Driver implementations can be nested, with each layer adding additional functionality. If pgx/stdlib.Driver is wrapped in another driver.Driver implementation, AcquireConn will error, detecting that the *sql.DB's driver is not (directly) pgx.Driver. It looks like it should be possible to support the current functionality without requiring that the top-level Driver be pgx/stdlib.Driver, but it requires using a global map of fakeTxConns instead of a per-Driver map of fakeTxConns. Is this reasonable?pull/410/head
parent
6556ef67cb
commit
14f1f2aa01
|
@ -99,9 +99,9 @@ var ErrNotPgx = errors.New("not pgx *sql.DB")
|
|||
|
||||
func init() {
|
||||
pgxDriver = &Driver{
|
||||
configs: make(map[int64]*DriverConfig),
|
||||
fakeTxConns: make(map[*pgx.Conn]*sql.Tx),
|
||||
configs: make(map[int64]*DriverConfig),
|
||||
}
|
||||
fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
|
||||
sql.Register("pgx", pgxDriver)
|
||||
|
||||
databaseSqlOIDs = make(map[pgtype.OID]bool)
|
||||
|
@ -120,13 +120,15 @@ func init() {
|
|||
databaseSqlOIDs[pgtype.XIDOID] = true
|
||||
}
|
||||
|
||||
var (
|
||||
fakeTxMutex sync.Mutex
|
||||
fakeTxConns map[*pgx.Conn]*sql.Tx
|
||||
)
|
||||
|
||||
type Driver struct {
|
||||
configMutex sync.Mutex
|
||||
configCount int64
|
||||
configs map[int64]*DriverConfig
|
||||
|
||||
fakeTxMutex sync.Mutex
|
||||
fakeTxConns map[*pgx.Conn]*sql.Tx
|
||||
}
|
||||
|
||||
func (d *Driver) Open(name string) (driver.Conn, error) {
|
||||
|
@ -571,21 +573,20 @@ func (fakeTx) Commit() error { return nil }
|
|||
func (fakeTx) Rollback() error { return nil }
|
||||
|
||||
func AcquireConn(db *sql.DB) (*pgx.Conn, error) {
|
||||
driver, ok := db.Driver().(*Driver)
|
||||
if !ok {
|
||||
return nil, ErrNotPgx
|
||||
}
|
||||
|
||||
var conn *pgx.Conn
|
||||
ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn)
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conn == nil {
|
||||
tx.Rollback()
|
||||
return nil, ErrNotPgx
|
||||
}
|
||||
|
||||
driver.fakeTxMutex.Lock()
|
||||
driver.fakeTxConns[conn] = tx
|
||||
driver.fakeTxMutex.Unlock()
|
||||
fakeTxMutex.Lock()
|
||||
fakeTxConns[conn] = tx
|
||||
fakeTxMutex.Unlock()
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
@ -594,14 +595,13 @@ func ReleaseConn(db *sql.DB, conn *pgx.Conn) error {
|
|||
var tx *sql.Tx
|
||||
var ok bool
|
||||
|
||||
driver := db.Driver().(*Driver)
|
||||
driver.fakeTxMutex.Lock()
|
||||
tx, ok = driver.fakeTxConns[conn]
|
||||
fakeTxMutex.Lock()
|
||||
tx, ok = fakeTxConns[conn]
|
||||
if ok {
|
||||
delete(driver.fakeTxConns, conn)
|
||||
driver.fakeTxMutex.Unlock()
|
||||
delete(fakeTxConns, conn)
|
||||
fakeTxMutex.Unlock()
|
||||
} else {
|
||||
driver.fakeTxMutex.Unlock()
|
||||
fakeTxMutex.Unlock()
|
||||
return errors.Errorf("can't release conn that is not acquired")
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue