Add DriverConfig system to stdlib

batch-wip
Jack Christensen 2017-05-06 15:28:16 -05:00
parent 8b6c32d13a
commit 78d344d1ab
2 changed files with 103 additions and 3 deletions

View File

@ -54,6 +54,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -72,9 +73,13 @@ var (
// binary, anything else will be forced to text format // binary, anything else will be forced to text format
var databaseSqlOids map[pgtype.Oid]bool var databaseSqlOids map[pgtype.Oid]bool
var pgxDriver *Driver
func init() { func init() {
d := &Driver{} pgxDriver = &Driver{
sql.Register("pgx", d) configs: make(map[int64]*DriverConfig),
}
sql.Register("pgx", pgxDriver)
databaseSqlOids = make(map[pgtype.Oid]bool) databaseSqlOids = make(map[pgtype.Oid]bool)
databaseSqlOids[pgtype.BoolOid] = true databaseSqlOids[pgtype.BoolOid] = true
@ -94,6 +99,10 @@ func init() {
type Driver struct { type Driver struct {
Pool *pgx.ConnPool Pool *pgx.ConnPool
configMutex sync.Mutex
configCount int64
configs map[int64]*DriverConfig
} }
func (d *Driver) Open(name string) (driver.Conn, error) { func (d *Driver) Open(name string) (driver.Conn, error) {
@ -106,20 +115,85 @@ func (d *Driver) Open(name string) (driver.Conn, error) {
return &Conn{conn: conn, pool: d.Pool}, nil return &Conn{conn: conn, pool: d.Pool}, nil
} }
connConfig, err := pgx.ParseConnectionString(name) var connConfig pgx.ConnConfig
var afterConnect func(*pgx.Conn) error
if len(name) >= 9 && name[0] == 0 {
idBuf := []byte(name)[1:9]
id := int64(binary.BigEndian.Uint64(idBuf))
connConfig = d.configs[id].ConnConfig
afterConnect = d.configs[id].AfterConnect
name = name[9:]
}
parsedConfig, err := pgx.ParseConnectionString(name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
connConfig = connConfig.Merge(parsedConfig)
conn, err := pgx.Connect(connConfig) conn, err := pgx.Connect(connConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if afterConnect != nil {
err = afterConnect(conn)
if err != nil {
return nil, err
}
}
c := &Conn{conn: conn} c := &Conn{conn: conn}
return c, nil return c, nil
} }
type DriverConfig struct {
pgx.ConnConfig
AfterConnect func(*pgx.Conn) error // function to call on every new connection
driver *Driver
id int64
}
// ConnectionString encodes the DriverConfig into the original connection
// string. DriverConfig must be registered before calling ConnectionString.
func (c *DriverConfig) ConnectionString(original string) string {
if c.driver == nil {
panic("DriverConfig must be registered before calling ConnectionString")
}
buf := make([]byte, 9)
binary.BigEndian.PutUint64(buf[1:], uint64(c.id))
buf = append(buf, original...)
return string(buf)
}
func (d *Driver) registerDriverConfig(c *DriverConfig) {
d.configMutex.Lock()
c.driver = d
c.id = d.configCount
d.configs[d.configCount] = c
d.configCount++
d.configMutex.Unlock()
}
func (d *Driver) unregisterDriverConfig(c *DriverConfig) {
d.configMutex.Lock()
delete(d.configs, c.id)
d.configMutex.Unlock()
}
// RegisterDriverConfig registers a DriverConfig for use with Open.
func RegisterDriverConfig(c *DriverConfig) {
pgxDriver.registerDriverConfig(c)
}
// UnregisterDriverConfig removes a DriverConfig registration.
func UnregisterDriverConfig(c *DriverConfig) {
pgxDriver.unregisterDriverConfig(c)
}
// OpenFromConnPool takes the existing *pgx.ConnPool pool and returns a *sql.DB // OpenFromConnPool takes the existing *pgx.ConnPool pool and returns a *sql.DB
// with pool as the backend. This enables full control over the connection // with pool as the backend. This enables full control over the connection
// process and configuration while maintaining compatibility with the // process and configuration while maintaining compatibility with the

View File

@ -202,6 +202,32 @@ func TestOpenFromConnPoolRace(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestOpenWithDriverConfigAfterConnect(t *testing.T) {
driverConfig := stdlib.DriverConfig{
AfterConnect: func(c *pgx.Conn) error {
_, err := c.Exec("create temporary sequence pgx")
return err
},
}
stdlib.RegisterDriverConfig(&driverConfig)
defer stdlib.UnregisterDriverConfig(&driverConfig)
db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
var n int64
err = db.QueryRow("select nextval('pgx')").Scan(&n)
if err != nil {
t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
}
if n != 1 {
t.Fatalf("n => %d, want %d", n, 1)
}
}
func TestStmtExec(t *testing.T) { func TestStmtExec(t *testing.T) {
db := openDB(t) db := openDB(t)
defer closeDB(t, db) defer closeDB(t, db)