Add stdlib AcquireConn and ReleaseConn

Also add some documentation.
batch-wip
Jack Christensen 2017-05-06 19:39:40 -05:00
parent 4cbefbb27e
commit c78d450c19
2 changed files with 150 additions and 2 deletions

View File

@ -13,6 +13,54 @@
// if err != nil {
// return err
// }
//
// A DriverConfig can be used to further configure the connection process. This
// allows configuring TLS configuration, setting a custom dialer, logging, and
// setting an AfterConnect hook.
//
// driverConfig := stdlib.DriverConfig{
// ConnConfig: ConnConfig: pgx.ConnConfig{
// Logger: logger,
// },
// AfterConnect: func(c *pgx.Conn) error {
// // Ensure all connections have this temp table available
// _, err := c.Exec("create temporary table foo(...)")
// return err
// },
// }
//
// stdlib.RegisterDriverConfig(&driverConfig)
//
// db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
// if err != nil {
// return err
// }
//
// AcquireConn and ReleaseConn acquire and release a *pgx.Conn from the standard
// database/sql.DB connection pool. This allows operations that must be
// performed on a single connection, but should not be run in a transaction or
// to use pgx specific functionality.
//
// conn, err := stdlib.AcquireConn(db)
// if err != nil {
// return err
// }
// defer stdlib.ReleaseConn(db, conn)
//
// // do stuff with pgx.Conn
//
// It also can be used to enable a fast path for pgx while preserving
// compatibility with other drivers and database.
//
// conn, err := stdlib.AcquireConn(db)
// if err == nil {
// // fast path with pgx
// // ...
// // release conn when done
// stdlib.ReleaseConn(db, conn)
// } else {
// // normal path for other drivers and databases
// }
package stdlib
import (
@ -20,6 +68,7 @@ import (
"database/sql"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"sync"
@ -34,9 +83,16 @@ var databaseSqlOids map[pgtype.Oid]bool
var pgxDriver *Driver
type ctxKey int
var ctxKeyFakeTx ctxKey = 0
var ErrNotPgx = errors.New("not pgx *sql.DB")
func init() {
pgxDriver = &Driver{
configs: make(map[int64]*DriverConfig),
configs: make(map[int64]*DriverConfig),
fakeTxConns: make(map[*pgx.Conn]*sql.Tx),
}
sql.Register("pgx", pgxDriver)
@ -60,6 +116,9 @@ type Driver struct {
configMutex sync.Mutex
configCount int64
configs map[int64]*DriverConfig
fakeTxMutex sync.Mutex
fakeTxConns map[*pgx.Conn]*sql.Tx
}
func (d *Driver) Open(name string) (driver.Conn, error) {
@ -91,7 +150,7 @@ func (d *Driver) Open(name string) (driver.Conn, error) {
}
}
c := &Conn{conn: conn}
c := &Conn{conn: conn, driver: d}
return c, nil
}
@ -145,6 +204,7 @@ func UnregisterDriverConfig(c *DriverConfig) {
type Conn struct {
conn *pgx.Conn
psCount int64 // Counter used for creating unique prepared statement names
driver *Driver
}
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
@ -178,6 +238,11 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
return nil, driver.ErrBadConn
}
if pconn, ok := ctx.Value(ctxKeyFakeTx).(**pgx.Conn); ok {
*pconn = c.conn
return fakeTx{}, nil
}
var pgxOpts pgx.TxOptions
switch sql.IsolationLevel(opts.Isolation) {
case sql.LevelDefault:
@ -407,3 +472,47 @@ func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
}
return args
}
type fakeTx struct{}
func (fakeTx) Commit() error { return nil }
func (fakeTx) Rollback() error { return nil }
func AcquireConn(db *sql.DB) (*pgx.Conn, error) {
driver, ok := db.Driver().(*Driver)
if !ok {
return nil, ErrNotPgx
}
var conn *pgx.Conn
ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn)
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
driver.fakeTxMutex.Lock()
driver.fakeTxConns[conn] = tx
driver.fakeTxMutex.Unlock()
return conn, nil
}
func ReleaseConn(db *sql.DB, conn *pgx.Conn) error {
var tx *sql.Tx
var ok bool
driver := db.Driver().(*Driver)
driver.fakeTxMutex.Lock()
tx, ok = driver.fakeTxConns[conn]
if ok {
delete(driver.fakeTxConns, conn)
driver.fakeTxMutex.Unlock()
} else {
driver.fakeTxMutex.Unlock()
return fmt.Errorf("can't release conn that is not acquired")
}
return tx.Rollback()
}

View File

@ -685,3 +685,42 @@ func TestConnBeginTxReadOnly(t *testing.T) {
ensureConnValid(t, db)
}
func TestAcquireConn(t *testing.T) {
db := openDB(t)
defer closeDB(t, db)
var conns []*pgx.Conn
for i := 1; i < 6; i++ {
conn, err := stdlib.AcquireConn(db)
if err != nil {
t.Errorf("%d. AcquireConn failed: %v", i, err)
continue
}
var n int32
err = conn.QueryRow("select 1").Scan(&n)
if err != nil {
t.Errorf("%d. QueryRow failed: %v", i, err)
}
if n != 1 {
t.Errorf("%d. n => %d, want %d", i, n, 1)
}
stats := db.Stats()
if stats.OpenConnections != i {
t.Errorf("%d. stats.OpenConnections => %d, want %d", i, stats.OpenConnections, i)
}
conns = append(conns, conn)
}
for i, conn := range conns {
if err := stdlib.ReleaseConn(db, conn); err != nil {
t.Errorf("%d. stdlib.ReleaseConn failed: %v", i, err)
}
}
ensureConnValid(t, db)
}