mirror of
https://github.com/jackc/pgx.git
synced 2025-05-31 11:42:24 +00:00
Add RegisterConnConfig to stdlib
This restored functionality lost in the v3 to v4 transition when RegisterDriverConfig was removed. fixes #617
This commit is contained in:
parent
f3a3ee1a0e
commit
69e9c33daf
@ -14,27 +14,14 @@
|
|||||||
// return err
|
// return err
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
// A DriverConfig can be used to further configure the connection process. This
|
// Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the
|
||||||
// allows configuring TLS configuration, setting a custom dialer, logging, and
|
// pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used
|
||||||
// setting an AfterConnect hook.
|
// with sql.Open.
|
||||||
//
|
//
|
||||||
// driverConfig := stdlib.DriverConfig{
|
// connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
|
||||||
// ConnConfig: pgx.ConnConfig{
|
// connConfig.Logger = myLogger
|
||||||
// Logger: logger,
|
// connStr := stdlib.RegisterConnConfig(connConfig)
|
||||||
// },
|
// db, _ := sql.Open("pgx", connStr)
|
||||||
// AfterConnect: func(c *pgx.Conn) error {
|
|
||||||
// // Ensure all connections have this temp table available
|
|
||||||
// _, err := c.Exec("create temporary table foo(...)")
|
|
||||||
// return err
|
|
||||||
// },
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// stdlib.RegisterDriverConfig(&driverConfig)
|
|
||||||
//
|
|
||||||
// db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
|
|
||||||
// if err != nil {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
//
|
//
|
||||||
// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2.
|
// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2.
|
||||||
// It does not support named parameters.
|
// It does not support named parameters.
|
||||||
@ -99,7 +86,9 @@ var ctxKeyFakeTx ctxKey = 0
|
|||||||
var ErrNotPgx = errors.New("not pgx *sql.DB")
|
var ErrNotPgx = errors.New("not pgx *sql.DB")
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
pgxDriver = &Driver{}
|
pgxDriver = &Driver{
|
||||||
|
configs: make(map[string]*pgx.ConnConfig),
|
||||||
|
}
|
||||||
fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
|
fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
|
||||||
sql.Register("pgx", pgxDriver)
|
sql.Register("pgx", pgxDriver)
|
||||||
|
|
||||||
@ -125,12 +114,24 @@ var (
|
|||||||
fakeTxConns map[*pgx.Conn]*sql.Tx
|
fakeTxConns map[*pgx.Conn]*sql.Tx
|
||||||
)
|
)
|
||||||
|
|
||||||
type Driver struct{}
|
type Driver struct {
|
||||||
|
configMutex sync.Mutex
|
||||||
|
configs map[string]*pgx.ConnConfig
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Driver) Open(name string) (driver.Conn, error) {
|
func (d *Driver) Open(name string) (driver.Conn, error) {
|
||||||
connConfig, err := pgx.ParseConfig(name)
|
var connConfig *pgx.ConnConfig
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
d.configMutex.Lock()
|
||||||
|
connConfig = d.configs[name]
|
||||||
|
d.configMutex.Unlock()
|
||||||
|
|
||||||
|
if connConfig == nil {
|
||||||
|
var err error
|
||||||
|
connConfig, err = pgx.ParseConfig(name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout
|
||||||
@ -144,6 +145,30 @@ func (d *Driver) Open(name string) (driver.Conn, error) {
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Driver) registerConnConfig(c *pgx.ConnConfig) string {
|
||||||
|
d.configMutex.Lock()
|
||||||
|
connStr := fmt.Sprintf("registeredConnConfig%d", len(d.configs))
|
||||||
|
d.configs[connStr] = c
|
||||||
|
d.configMutex.Unlock()
|
||||||
|
return connStr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) unregisterConnConfig(connStr string) {
|
||||||
|
d.configMutex.Lock()
|
||||||
|
delete(d.configs, connStr)
|
||||||
|
d.configMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterConnConfig registers a ConnConfig and returns the connection string to use with Open.
|
||||||
|
func RegisterConnConfig(c *pgx.ConnConfig) string {
|
||||||
|
return pgxDriver.registerConnConfig(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnregisterConnConfig removes the ConnConfig registration for connStr.
|
||||||
|
func UnregisterConnConfig(connStr string) {
|
||||||
|
pgxDriver.unregisterConnConfig(connStr)
|
||||||
|
}
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
conn *pgx.Conn
|
conn *pgx.Conn
|
||||||
psCount int64 // Counter used for creating unique prepared statement names
|
psCount int64 // Counter used for creating unique prepared statement names
|
||||||
|
@ -15,6 +15,8 @@ import (
|
|||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgconn"
|
||||||
"github.com/jackc/pgx/v4"
|
"github.com/jackc/pgx/v4"
|
||||||
"github.com/jackc/pgx/v4/stdlib"
|
"github.com/jackc/pgx/v4/stdlib"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func closeDB(t *testing.T, db *sql.DB) {
|
func closeDB(t *testing.T, db *sql.DB) {
|
||||||
@ -288,20 +290,6 @@ func TestConnQuery(t *testing.T) {
|
|||||||
ensureConnValid(t, db)
|
ensureConnValid(t, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
type testLog struct {
|
|
||||||
lvl pgx.LogLevel
|
|
||||||
msg string
|
|
||||||
data map[string]interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
type testLogger struct {
|
|
||||||
logs []testLog
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *testLogger) Log(lvl pgx.LogLevel, msg string, data map[string]interface{}) {
|
|
||||||
l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnQueryNull(t *testing.T) {
|
func TestConnQueryNull(t *testing.T) {
|
||||||
db := openDB(t)
|
db := openDB(t)
|
||||||
defer closeDB(t, db)
|
defer closeDB(t, db)
|
||||||
@ -1132,3 +1120,40 @@ func TestScanJSONIntoJSONRawMessage(t *testing.T) {
|
|||||||
|
|
||||||
ensureConnValid(t, db)
|
ensureConnValid(t, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type testLog struct {
|
||||||
|
lvl pgx.LogLevel
|
||||||
|
msg string
|
||||||
|
data map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testLogger struct {
|
||||||
|
logs []testLog
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *testLogger) Log(ctx context.Context, lvl pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||||
|
l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterConnConfig(t *testing.T) {
|
||||||
|
connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
logger := &testLogger{}
|
||||||
|
connConfig.Logger = logger
|
||||||
|
|
||||||
|
connStr := stdlib.RegisterConnConfig(connConfig)
|
||||||
|
defer stdlib.UnregisterConnConfig(connStr)
|
||||||
|
|
||||||
|
db, err := sql.Open("pgx", connStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeDB(t, db)
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
err = db.QueryRow("select 1").Scan(&n)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
l := logger.logs[len(logger.logs)-1]
|
||||||
|
assert.Equal(t, "Query", l.msg)
|
||||||
|
assert.Equal(t, "select 1", l.data["sql"])
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user