diff --git a/stdlib/sql.go b/stdlib/sql.go index 2d4930ee..76a511b1 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -99,9 +99,9 @@ var ErrNotPgx = errors.New("not pgx *sql.DB") func init() { pgxDriver = &Driver{ - configs: make(map[int64]*DriverConfig), - fakeTxConns: make(map[*pgx.Conn]*sql.Tx), + configs: make(map[int64]*DriverConfig), } + fakeTxConns = make(map[*pgx.Conn]*sql.Tx) sql.Register("pgx", pgxDriver) databaseSqlOIDs = make(map[pgtype.OID]bool) @@ -120,13 +120,15 @@ func init() { databaseSqlOIDs[pgtype.XIDOID] = true } +var ( + fakeTxMutex sync.Mutex + fakeTxConns map[*pgx.Conn]*sql.Tx +) + type Driver struct { configMutex sync.Mutex configCount int64 configs map[int64]*DriverConfig - - fakeTxMutex sync.Mutex - fakeTxConns map[*pgx.Conn]*sql.Tx } func (d *Driver) Open(name string) (driver.Conn, error) { @@ -571,21 +573,20 @@ func (fakeTx) Commit() error { return nil } func (fakeTx) Rollback() error { return nil } func AcquireConn(db *sql.DB) (*pgx.Conn, error) { - driver, ok := db.Driver().(*Driver) - if !ok { - return nil, ErrNotPgx - } - var conn *pgx.Conn ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn) tx, err := db.BeginTx(ctx, nil) if err != nil { return nil, err } + if conn == nil { + tx.Rollback() + return nil, ErrNotPgx + } - driver.fakeTxMutex.Lock() - driver.fakeTxConns[conn] = tx - driver.fakeTxMutex.Unlock() + fakeTxMutex.Lock() + fakeTxConns[conn] = tx + fakeTxMutex.Unlock() return conn, nil } @@ -594,14 +595,13 @@ func ReleaseConn(db *sql.DB, conn *pgx.Conn) error { var tx *sql.Tx var ok bool - driver := db.Driver().(*Driver) - driver.fakeTxMutex.Lock() - tx, ok = driver.fakeTxConns[conn] + fakeTxMutex.Lock() + tx, ok = fakeTxConns[conn] if ok { - delete(driver.fakeTxConns, conn) - driver.fakeTxMutex.Unlock() + delete(fakeTxConns, conn) + fakeTxMutex.Unlock() } else { - driver.fakeTxMutex.Unlock() + fakeTxMutex.Unlock() return errors.Errorf("can't release conn that is not acquired") }