From 78d344d1abebb939f5ac4cc9a88c4e072f0efddd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 May 2017 15:28:16 -0500 Subject: [PATCH] Add DriverConfig system to stdlib --- stdlib/sql.go | 80 ++++++++++++++++++++++++++++++++++++++++++++-- stdlib/sql_test.go | 26 +++++++++++++++ 2 files changed, 103 insertions(+), 3 deletions(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index 80a559af..19d96260 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -54,6 +54,7 @@ import ( "context" "database/sql" "database/sql/driver" + "encoding/binary" "errors" "fmt" "io" @@ -72,9 +73,13 @@ var ( // binary, anything else will be forced to text format var databaseSqlOids map[pgtype.Oid]bool +var pgxDriver *Driver + func init() { - d := &Driver{} - sql.Register("pgx", d) + pgxDriver = &Driver{ + configs: make(map[int64]*DriverConfig), + } + sql.Register("pgx", pgxDriver) databaseSqlOids = make(map[pgtype.Oid]bool) databaseSqlOids[pgtype.BoolOid] = true @@ -94,6 +99,10 @@ func init() { type Driver struct { Pool *pgx.ConnPool + + configMutex sync.Mutex + configCount int64 + configs map[int64]*DriverConfig } func (d *Driver) Open(name string) (driver.Conn, error) { @@ -106,20 +115,85 @@ func (d *Driver) Open(name string) (driver.Conn, error) { return &Conn{conn: conn, pool: d.Pool}, nil } - connConfig, err := pgx.ParseConnectionString(name) + var connConfig pgx.ConnConfig + var afterConnect func(*pgx.Conn) error + if len(name) >= 9 && name[0] == 0 { + idBuf := []byte(name)[1:9] + id := int64(binary.BigEndian.Uint64(idBuf)) + connConfig = d.configs[id].ConnConfig + afterConnect = d.configs[id].AfterConnect + name = name[9:] + } + + parsedConfig, err := pgx.ParseConnectionString(name) if err != nil { return nil, err } + connConfig = connConfig.Merge(parsedConfig) conn, err := pgx.Connect(connConfig) if err != nil { return nil, err } + if afterConnect != nil { + err = afterConnect(conn) + if err != nil { + return nil, err + } + } + c := &Conn{conn: conn} return c, nil } +type DriverConfig struct { + pgx.ConnConfig + AfterConnect func(*pgx.Conn) error // function to call on every new connection + driver *Driver + id int64 +} + +// ConnectionString encodes the DriverConfig into the original connection +// string. DriverConfig must be registered before calling ConnectionString. +func (c *DriverConfig) ConnectionString(original string) string { + if c.driver == nil { + panic("DriverConfig must be registered before calling ConnectionString") + } + + buf := make([]byte, 9) + binary.BigEndian.PutUint64(buf[1:], uint64(c.id)) + buf = append(buf, original...) + return string(buf) +} + +func (d *Driver) registerDriverConfig(c *DriverConfig) { + d.configMutex.Lock() + + c.driver = d + c.id = d.configCount + d.configs[d.configCount] = c + d.configCount++ + + d.configMutex.Unlock() +} + +func (d *Driver) unregisterDriverConfig(c *DriverConfig) { + d.configMutex.Lock() + delete(d.configs, c.id) + d.configMutex.Unlock() +} + +// RegisterDriverConfig registers a DriverConfig for use with Open. +func RegisterDriverConfig(c *DriverConfig) { + pgxDriver.registerDriverConfig(c) +} + +// UnregisterDriverConfig removes a DriverConfig registration. +func UnregisterDriverConfig(c *DriverConfig) { + pgxDriver.unregisterDriverConfig(c) +} + // OpenFromConnPool takes the existing *pgx.ConnPool pool and returns a *sql.DB // with pool as the backend. This enables full control over the connection // process and configuration while maintaining compatibility with the diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index ba74560d..e4fbcb0c 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -202,6 +202,32 @@ func TestOpenFromConnPoolRace(t *testing.T) { wg.Wait() } +func TestOpenWithDriverConfigAfterConnect(t *testing.T) { + driverConfig := stdlib.DriverConfig{ + AfterConnect: func(c *pgx.Conn) error { + _, err := c.Exec("create temporary sequence pgx") + return err + }, + } + + stdlib.RegisterDriverConfig(&driverConfig) + defer stdlib.UnregisterDriverConfig(&driverConfig) + + db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + + var n int64 + err = db.QueryRow("select nextval('pgx')").Scan(&n) + if err != nil { + t.Fatalf("db.QueryRow unexpectedly failed: %v", err) + } + if n != 1 { + t.Fatalf("n => %d, want %d", n, 1) + } +} + func TestStmtExec(t *testing.T) { db := openDB(t) defer closeDB(t, db)