diff --git a/stdlib/sql.go b/stdlib/sql.go index 7e635324..e70780c1 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -13,6 +13,54 @@ // if err != nil { // return err // } +// +// A DriverConfig can be used to further configure the connection process. This +// allows configuring TLS configuration, setting a custom dialer, logging, and +// setting an AfterConnect hook. +// +// driverConfig := stdlib.DriverConfig{ +// ConnConfig: ConnConfig: pgx.ConnConfig{ +// Logger: logger, +// }, +// AfterConnect: func(c *pgx.Conn) error { +// // Ensure all connections have this temp table available +// _, err := c.Exec("create temporary table foo(...)") +// return err +// }, +// } +// +// stdlib.RegisterDriverConfig(&driverConfig) +// +// db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")) +// if err != nil { +// return err +// } +// +// AcquireConn and ReleaseConn acquire and release a *pgx.Conn from the standard +// database/sql.DB connection pool. This allows operations that must be +// performed on a single connection, but should not be run in a transaction or +// to use pgx specific functionality. +// +// conn, err := stdlib.AcquireConn(db) +// if err != nil { +// return err +// } +// defer stdlib.ReleaseConn(db, conn) +// +// // do stuff with pgx.Conn +// +// It also can be used to enable a fast path for pgx while preserving +// compatibility with other drivers and database. +// +// conn, err := stdlib.AcquireConn(db) +// if err == nil { +// // fast path with pgx +// // ... +// // release conn when done +// stdlib.ReleaseConn(db, conn) +// } else { +// // normal path for other drivers and databases +// } package stdlib import ( @@ -20,6 +68,7 @@ import ( "database/sql" "database/sql/driver" "encoding/binary" + "errors" "fmt" "io" "sync" @@ -34,9 +83,16 @@ var databaseSqlOids map[pgtype.Oid]bool var pgxDriver *Driver +type ctxKey int + +var ctxKeyFakeTx ctxKey = 0 + +var ErrNotPgx = errors.New("not pgx *sql.DB") + func init() { pgxDriver = &Driver{ - configs: make(map[int64]*DriverConfig), + configs: make(map[int64]*DriverConfig), + fakeTxConns: make(map[*pgx.Conn]*sql.Tx), } sql.Register("pgx", pgxDriver) @@ -60,6 +116,9 @@ type Driver struct { configMutex sync.Mutex configCount int64 configs map[int64]*DriverConfig + + fakeTxMutex sync.Mutex + fakeTxConns map[*pgx.Conn]*sql.Tx } func (d *Driver) Open(name string) (driver.Conn, error) { @@ -91,7 +150,7 @@ func (d *Driver) Open(name string) (driver.Conn, error) { } } - c := &Conn{conn: conn} + c := &Conn{conn: conn, driver: d} return c, nil } @@ -145,6 +204,7 @@ func UnregisterDriverConfig(c *DriverConfig) { type Conn struct { conn *pgx.Conn psCount int64 // Counter used for creating unique prepared statement names + driver *Driver } func (c *Conn) Prepare(query string) (driver.Stmt, error) { @@ -178,6 +238,11 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e return nil, driver.ErrBadConn } + if pconn, ok := ctx.Value(ctxKeyFakeTx).(**pgx.Conn); ok { + *pconn = c.conn + return fakeTx{}, nil + } + var pgxOpts pgx.TxOptions switch sql.IsolationLevel(opts.Isolation) { case sql.LevelDefault: @@ -407,3 +472,47 @@ func namedValueToInterface(argsV []driver.NamedValue) []interface{} { } return args } + +type fakeTx struct{} + +func (fakeTx) Commit() error { return nil } + +func (fakeTx) Rollback() error { return nil } + +func AcquireConn(db *sql.DB) (*pgx.Conn, error) { + driver, ok := db.Driver().(*Driver) + if !ok { + return nil, ErrNotPgx + } + + var conn *pgx.Conn + ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn) + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + + driver.fakeTxMutex.Lock() + driver.fakeTxConns[conn] = tx + driver.fakeTxMutex.Unlock() + + return conn, nil +} + +func ReleaseConn(db *sql.DB, conn *pgx.Conn) error { + var tx *sql.Tx + var ok bool + + driver := db.Driver().(*Driver) + driver.fakeTxMutex.Lock() + tx, ok = driver.fakeTxConns[conn] + if ok { + delete(driver.fakeTxConns, conn) + driver.fakeTxMutex.Unlock() + } else { + driver.fakeTxMutex.Unlock() + return fmt.Errorf("can't release conn that is not acquired") + } + + return tx.Rollback() +} diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index fdc93c0a..e9fcd27b 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -685,3 +685,42 @@ func TestConnBeginTxReadOnly(t *testing.T) { ensureConnValid(t, db) } + +func TestAcquireConn(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + var conns []*pgx.Conn + + for i := 1; i < 6; i++ { + conn, err := stdlib.AcquireConn(db) + if err != nil { + t.Errorf("%d. AcquireConn failed: %v", i, err) + continue + } + + var n int32 + err = conn.QueryRow("select 1").Scan(&n) + if err != nil { + t.Errorf("%d. QueryRow failed: %v", i, err) + } + if n != 1 { + t.Errorf("%d. n => %d, want %d", i, n, 1) + } + + stats := db.Stats() + if stats.OpenConnections != i { + t.Errorf("%d. stats.OpenConnections => %d, want %d", i, stats.OpenConnections, i) + } + + conns = append(conns, conn) + } + + for i, conn := range conns { + if err := stdlib.ReleaseConn(db, conn); err != nil { + t.Errorf("%d. stdlib.ReleaseConn failed: %v", i, err) + } + } + + ensureConnValid(t, db) +}