From 14f1f2aa0184ccec7bedbc493a18d81af08520c6 Mon Sep 17 00:00:00 2001 From: Jackson Owens Date: Wed, 11 Apr 2018 19:20:00 -0700 Subject: [PATCH] stdlib: allow nested database/sql/driver.Drivers database/sql/driver.Driver implementations can be nested, with each layer adding additional functionality. If pgx/stdlib.Driver is wrapped in another driver.Driver implementation, AcquireConn will error, detecting that the *sql.DB's driver is not (directly) pgx.Driver. It looks like it should be possible to support the current functionality without requiring that the top-level Driver be pgx/stdlib.Driver, but it requires using a global map of fakeTxConns instead of a per-Driver map of fakeTxConns. Is this reasonable? --- stdlib/sql.go | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index 2d4930ee..76a511b1 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -99,9 +99,9 @@ var ErrNotPgx = errors.New("not pgx *sql.DB") func init() { pgxDriver = &Driver{ - configs: make(map[int64]*DriverConfig), - fakeTxConns: make(map[*pgx.Conn]*sql.Tx), + configs: make(map[int64]*DriverConfig), } + fakeTxConns = make(map[*pgx.Conn]*sql.Tx) sql.Register("pgx", pgxDriver) databaseSqlOIDs = make(map[pgtype.OID]bool) @@ -120,13 +120,15 @@ func init() { databaseSqlOIDs[pgtype.XIDOID] = true } +var ( + fakeTxMutex sync.Mutex + fakeTxConns map[*pgx.Conn]*sql.Tx +) + type Driver struct { configMutex sync.Mutex configCount int64 configs map[int64]*DriverConfig - - fakeTxMutex sync.Mutex - fakeTxConns map[*pgx.Conn]*sql.Tx } func (d *Driver) Open(name string) (driver.Conn, error) { @@ -571,21 +573,20 @@ func (fakeTx) Commit() error { return nil } func (fakeTx) Rollback() error { return nil } func AcquireConn(db *sql.DB) (*pgx.Conn, error) { - driver, ok := db.Driver().(*Driver) - if !ok { - return nil, ErrNotPgx - } - var conn *pgx.Conn ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn) tx, err := db.BeginTx(ctx, nil) if err != nil { return nil, err } + if conn == nil { + tx.Rollback() + return nil, ErrNotPgx + } - driver.fakeTxMutex.Lock() - driver.fakeTxConns[conn] = tx - driver.fakeTxMutex.Unlock() + fakeTxMutex.Lock() + fakeTxConns[conn] = tx + fakeTxMutex.Unlock() return conn, nil } @@ -594,14 +595,13 @@ func ReleaseConn(db *sql.DB, conn *pgx.Conn) error { var tx *sql.Tx var ok bool - driver := db.Driver().(*Driver) - driver.fakeTxMutex.Lock() - tx, ok = driver.fakeTxConns[conn] + fakeTxMutex.Lock() + tx, ok = fakeTxConns[conn] if ok { - delete(driver.fakeTxConns, conn) - driver.fakeTxMutex.Unlock() + delete(fakeTxConns, conn) + fakeTxMutex.Unlock() } else { - driver.fakeTxMutex.Unlock() + fakeTxMutex.Unlock() return errors.Errorf("can't release conn that is not acquired") }