mirror of https://github.com/jackc/pgx.git
Add DriverConfig system to stdlib
parent
8b6c32d13a
commit
78d344d1ab
|
@ -54,6 +54,7 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -72,9 +73,13 @@ var (
|
|||
// binary, anything else will be forced to text format
|
||||
var databaseSqlOids map[pgtype.Oid]bool
|
||||
|
||||
var pgxDriver *Driver
|
||||
|
||||
func init() {
|
||||
d := &Driver{}
|
||||
sql.Register("pgx", d)
|
||||
pgxDriver = &Driver{
|
||||
configs: make(map[int64]*DriverConfig),
|
||||
}
|
||||
sql.Register("pgx", pgxDriver)
|
||||
|
||||
databaseSqlOids = make(map[pgtype.Oid]bool)
|
||||
databaseSqlOids[pgtype.BoolOid] = true
|
||||
|
@ -94,6 +99,10 @@ func init() {
|
|||
|
||||
type Driver struct {
|
||||
Pool *pgx.ConnPool
|
||||
|
||||
configMutex sync.Mutex
|
||||
configCount int64
|
||||
configs map[int64]*DriverConfig
|
||||
}
|
||||
|
||||
func (d *Driver) Open(name string) (driver.Conn, error) {
|
||||
|
@ -106,20 +115,85 @@ func (d *Driver) Open(name string) (driver.Conn, error) {
|
|||
return &Conn{conn: conn, pool: d.Pool}, nil
|
||||
}
|
||||
|
||||
connConfig, err := pgx.ParseConnectionString(name)
|
||||
var connConfig pgx.ConnConfig
|
||||
var afterConnect func(*pgx.Conn) error
|
||||
if len(name) >= 9 && name[0] == 0 {
|
||||
idBuf := []byte(name)[1:9]
|
||||
id := int64(binary.BigEndian.Uint64(idBuf))
|
||||
connConfig = d.configs[id].ConnConfig
|
||||
afterConnect = d.configs[id].AfterConnect
|
||||
name = name[9:]
|
||||
}
|
||||
|
||||
parsedConfig, err := pgx.ParseConnectionString(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connConfig = connConfig.Merge(parsedConfig)
|
||||
|
||||
conn, err := pgx.Connect(connConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if afterConnect != nil {
|
||||
err = afterConnect(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
c := &Conn{conn: conn}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
type DriverConfig struct {
|
||||
pgx.ConnConfig
|
||||
AfterConnect func(*pgx.Conn) error // function to call on every new connection
|
||||
driver *Driver
|
||||
id int64
|
||||
}
|
||||
|
||||
// ConnectionString encodes the DriverConfig into the original connection
|
||||
// string. DriverConfig must be registered before calling ConnectionString.
|
||||
func (c *DriverConfig) ConnectionString(original string) string {
|
||||
if c.driver == nil {
|
||||
panic("DriverConfig must be registered before calling ConnectionString")
|
||||
}
|
||||
|
||||
buf := make([]byte, 9)
|
||||
binary.BigEndian.PutUint64(buf[1:], uint64(c.id))
|
||||
buf = append(buf, original...)
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
func (d *Driver) registerDriverConfig(c *DriverConfig) {
|
||||
d.configMutex.Lock()
|
||||
|
||||
c.driver = d
|
||||
c.id = d.configCount
|
||||
d.configs[d.configCount] = c
|
||||
d.configCount++
|
||||
|
||||
d.configMutex.Unlock()
|
||||
}
|
||||
|
||||
func (d *Driver) unregisterDriverConfig(c *DriverConfig) {
|
||||
d.configMutex.Lock()
|
||||
delete(d.configs, c.id)
|
||||
d.configMutex.Unlock()
|
||||
}
|
||||
|
||||
// RegisterDriverConfig registers a DriverConfig for use with Open.
|
||||
func RegisterDriverConfig(c *DriverConfig) {
|
||||
pgxDriver.registerDriverConfig(c)
|
||||
}
|
||||
|
||||
// UnregisterDriverConfig removes a DriverConfig registration.
|
||||
func UnregisterDriverConfig(c *DriverConfig) {
|
||||
pgxDriver.unregisterDriverConfig(c)
|
||||
}
|
||||
|
||||
// OpenFromConnPool takes the existing *pgx.ConnPool pool and returns a *sql.DB
|
||||
// with pool as the backend. This enables full control over the connection
|
||||
// process and configuration while maintaining compatibility with the
|
||||
|
|
|
@ -202,6 +202,32 @@ func TestOpenFromConnPoolRace(t *testing.T) {
|
|||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestOpenWithDriverConfigAfterConnect(t *testing.T) {
|
||||
driverConfig := stdlib.DriverConfig{
|
||||
AfterConnect: func(c *pgx.Conn) error {
|
||||
_, err := c.Exec("create temporary sequence pgx")
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
stdlib.RegisterDriverConfig(&driverConfig)
|
||||
defer stdlib.UnregisterDriverConfig(&driverConfig)
|
||||
|
||||
db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open failed: %v", err)
|
||||
}
|
||||
|
||||
var n int64
|
||||
err = db.QueryRow("select nextval('pgx')").Scan(&n)
|
||||
if err != nil {
|
||||
t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Fatalf("n => %d, want %d", n, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStmtExec(t *testing.T) {
|
||||
db := openDB(t)
|
||||
defer closeDB(t, db)
|
||||
|
|
Loading…
Reference in New Issue