Add RegisterConnConfig to stdlib

This restored functionality lost in the v3 to v4 transition when
RegisterDriverConfig was removed.

fixes #617
pull/646/head
Jack Christensen 2019-11-16 11:06:57 -06:00
parent f3a3ee1a0e
commit 69e9c33daf
2 changed files with 89 additions and 39 deletions

View File

@ -14,27 +14,14 @@
// return err
// }
//
// A DriverConfig can be used to further configure the connection process. This
// allows configuring TLS configuration, setting a custom dialer, logging, and
// setting an AfterConnect hook.
// Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the
// pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used
// with sql.Open.
//
// driverConfig := stdlib.DriverConfig{
// ConnConfig: pgx.ConnConfig{
// Logger: logger,
// },
// AfterConnect: func(c *pgx.Conn) error {
// // Ensure all connections have this temp table available
// _, err := c.Exec("create temporary table foo(...)")
// return err
// },
// }
//
// stdlib.RegisterDriverConfig(&driverConfig)
//
// db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
// if err != nil {
// return err
// }
// connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
// connConfig.Logger = myLogger
// connStr := stdlib.RegisterConnConfig(connConfig)
// db, _ := sql.Open("pgx", connStr)
//
// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2.
// It does not support named parameters.
@ -99,7 +86,9 @@ var ctxKeyFakeTx ctxKey = 0
var ErrNotPgx = errors.New("not pgx *sql.DB")
func init() {
pgxDriver = &Driver{}
pgxDriver = &Driver{
configs: make(map[string]*pgx.ConnConfig),
}
fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
sql.Register("pgx", pgxDriver)
@ -125,12 +114,24 @@ var (
fakeTxConns map[*pgx.Conn]*sql.Tx
)
type Driver struct{}
type Driver struct {
configMutex sync.Mutex
configs map[string]*pgx.ConnConfig
}
func (d *Driver) Open(name string) (driver.Conn, error) {
connConfig, err := pgx.ParseConfig(name)
if err != nil {
return nil, err
var connConfig *pgx.ConnConfig
d.configMutex.Lock()
connConfig = d.configs[name]
d.configMutex.Unlock()
if connConfig == nil {
var err error
connConfig, err = pgx.ParseConfig(name)
if err != nil {
return nil, err
}
}
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout
@ -144,6 +145,30 @@ func (d *Driver) Open(name string) (driver.Conn, error) {
return c, nil
}
func (d *Driver) registerConnConfig(c *pgx.ConnConfig) string {
d.configMutex.Lock()
connStr := fmt.Sprintf("registeredConnConfig%d", len(d.configs))
d.configs[connStr] = c
d.configMutex.Unlock()
return connStr
}
func (d *Driver) unregisterConnConfig(connStr string) {
d.configMutex.Lock()
delete(d.configs, connStr)
d.configMutex.Unlock()
}
// RegisterConnConfig registers a ConnConfig and returns the connection string to use with Open.
func RegisterConnConfig(c *pgx.ConnConfig) string {
return pgxDriver.registerConnConfig(c)
}
// UnregisterConnConfig removes the ConnConfig registration for connStr.
func UnregisterConnConfig(connStr string) {
pgxDriver.unregisterConnConfig(connStr)
}
type Conn struct {
conn *pgx.Conn
psCount int64 // Counter used for creating unique prepared statement names

View File

@ -15,6 +15,8 @@ import (
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func closeDB(t *testing.T, db *sql.DB) {
@ -288,20 +290,6 @@ func TestConnQuery(t *testing.T) {
ensureConnValid(t, db)
}
type testLog struct {
lvl pgx.LogLevel
msg string
data map[string]interface{}
}
type testLogger struct {
logs []testLog
}
func (l *testLogger) Log(lvl pgx.LogLevel, msg string, data map[string]interface{}) {
l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
}
func TestConnQueryNull(t *testing.T) {
db := openDB(t)
defer closeDB(t, db)
@ -1132,3 +1120,40 @@ func TestScanJSONIntoJSONRawMessage(t *testing.T) {
ensureConnValid(t, db)
}
type testLog struct {
lvl pgx.LogLevel
msg string
data map[string]interface{}
}
type testLogger struct {
logs []testLog
}
func (l *testLogger) Log(ctx context.Context, lvl pgx.LogLevel, msg string, data map[string]interface{}) {
l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
}
func TestRegisterConnConfig(t *testing.T) {
connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
logger := &testLogger{}
connConfig.Logger = logger
connStr := stdlib.RegisterConnConfig(connConfig)
defer stdlib.UnregisterConnConfig(connStr)
db, err := sql.Open("pgx", connStr)
require.NoError(t, err)
defer closeDB(t, db)
var n int64
err = db.QueryRow("select 1").Scan(&n)
require.NoError(t, err)
l := logger.logs[len(logger.logs)-1]
assert.Equal(t, "Query", l.msg)
assert.Equal(t, "select 1", l.data["sql"])
}