mirror of https://github.com/jackc/pgx.git
Add DriverConfig system to stdlib
parent
8b6c32d13a
commit
78d344d1ab
|
@ -54,6 +54,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -72,9 +73,13 @@ var (
|
||||||
// binary, anything else will be forced to text format
|
// binary, anything else will be forced to text format
|
||||||
var databaseSqlOids map[pgtype.Oid]bool
|
var databaseSqlOids map[pgtype.Oid]bool
|
||||||
|
|
||||||
|
var pgxDriver *Driver
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
d := &Driver{}
|
pgxDriver = &Driver{
|
||||||
sql.Register("pgx", d)
|
configs: make(map[int64]*DriverConfig),
|
||||||
|
}
|
||||||
|
sql.Register("pgx", pgxDriver)
|
||||||
|
|
||||||
databaseSqlOids = make(map[pgtype.Oid]bool)
|
databaseSqlOids = make(map[pgtype.Oid]bool)
|
||||||
databaseSqlOids[pgtype.BoolOid] = true
|
databaseSqlOids[pgtype.BoolOid] = true
|
||||||
|
@ -94,6 +99,10 @@ func init() {
|
||||||
|
|
||||||
type Driver struct {
|
type Driver struct {
|
||||||
Pool *pgx.ConnPool
|
Pool *pgx.ConnPool
|
||||||
|
|
||||||
|
configMutex sync.Mutex
|
||||||
|
configCount int64
|
||||||
|
configs map[int64]*DriverConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Driver) Open(name string) (driver.Conn, error) {
|
func (d *Driver) Open(name string) (driver.Conn, error) {
|
||||||
|
@ -106,20 +115,85 @@ func (d *Driver) Open(name string) (driver.Conn, error) {
|
||||||
return &Conn{conn: conn, pool: d.Pool}, nil
|
return &Conn{conn: conn, pool: d.Pool}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
connConfig, err := pgx.ParseConnectionString(name)
|
var connConfig pgx.ConnConfig
|
||||||
|
var afterConnect func(*pgx.Conn) error
|
||||||
|
if len(name) >= 9 && name[0] == 0 {
|
||||||
|
idBuf := []byte(name)[1:9]
|
||||||
|
id := int64(binary.BigEndian.Uint64(idBuf))
|
||||||
|
connConfig = d.configs[id].ConnConfig
|
||||||
|
afterConnect = d.configs[id].AfterConnect
|
||||||
|
name = name[9:]
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedConfig, err := pgx.ParseConnectionString(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
connConfig = connConfig.Merge(parsedConfig)
|
||||||
|
|
||||||
conn, err := pgx.Connect(connConfig)
|
conn, err := pgx.Connect(connConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if afterConnect != nil {
|
||||||
|
err = afterConnect(conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c := &Conn{conn: conn}
|
c := &Conn{conn: conn}
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DriverConfig struct {
|
||||||
|
pgx.ConnConfig
|
||||||
|
AfterConnect func(*pgx.Conn) error // function to call on every new connection
|
||||||
|
driver *Driver
|
||||||
|
id int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionString encodes the DriverConfig into the original connection
|
||||||
|
// string. DriverConfig must be registered before calling ConnectionString.
|
||||||
|
func (c *DriverConfig) ConnectionString(original string) string {
|
||||||
|
if c.driver == nil {
|
||||||
|
panic("DriverConfig must be registered before calling ConnectionString")
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 9)
|
||||||
|
binary.BigEndian.PutUint64(buf[1:], uint64(c.id))
|
||||||
|
buf = append(buf, original...)
|
||||||
|
return string(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) registerDriverConfig(c *DriverConfig) {
|
||||||
|
d.configMutex.Lock()
|
||||||
|
|
||||||
|
c.driver = d
|
||||||
|
c.id = d.configCount
|
||||||
|
d.configs[d.configCount] = c
|
||||||
|
d.configCount++
|
||||||
|
|
||||||
|
d.configMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) unregisterDriverConfig(c *DriverConfig) {
|
||||||
|
d.configMutex.Lock()
|
||||||
|
delete(d.configs, c.id)
|
||||||
|
d.configMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterDriverConfig registers a DriverConfig for use with Open.
|
||||||
|
func RegisterDriverConfig(c *DriverConfig) {
|
||||||
|
pgxDriver.registerDriverConfig(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnregisterDriverConfig removes a DriverConfig registration.
|
||||||
|
func UnregisterDriverConfig(c *DriverConfig) {
|
||||||
|
pgxDriver.unregisterDriverConfig(c)
|
||||||
|
}
|
||||||
|
|
||||||
// OpenFromConnPool takes the existing *pgx.ConnPool pool and returns a *sql.DB
|
// OpenFromConnPool takes the existing *pgx.ConnPool pool and returns a *sql.DB
|
||||||
// with pool as the backend. This enables full control over the connection
|
// with pool as the backend. This enables full control over the connection
|
||||||
// process and configuration while maintaining compatibility with the
|
// process and configuration while maintaining compatibility with the
|
||||||
|
|
|
@ -202,6 +202,32 @@ func TestOpenFromConnPoolRace(t *testing.T) {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenWithDriverConfigAfterConnect(t *testing.T) {
|
||||||
|
driverConfig := stdlib.DriverConfig{
|
||||||
|
AfterConnect: func(c *pgx.Conn) error {
|
||||||
|
_, err := c.Exec("create temporary sequence pgx")
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
stdlib.RegisterDriverConfig(&driverConfig)
|
||||||
|
defer stdlib.UnregisterDriverConfig(&driverConfig)
|
||||||
|
|
||||||
|
db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sql.Open failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
err = db.QueryRow("select nextval('pgx')").Scan(&n)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
if n != 1 {
|
||||||
|
t.Fatalf("n => %d, want %d", n, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStmtExec(t *testing.T) {
|
func TestStmtExec(t *testing.T) {
|
||||||
db := openDB(t)
|
db := openDB(t)
|
||||||
defer closeDB(t, db)
|
defer closeDB(t, db)
|
||||||
|
|
Loading…
Reference in New Issue