diff --git a/stdlib/sql.go b/stdlib/sql.go index 1218d5f2..7e1d64fc 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -14,27 +14,14 @@ // return err // } // -// A DriverConfig can be used to further configure the connection process. This -// allows configuring TLS configuration, setting a custom dialer, logging, and -// setting an AfterConnect hook. +// Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the +// pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used +// with sql.Open. // -// driverConfig := stdlib.DriverConfig{ -// ConnConfig: pgx.ConnConfig{ -// Logger: logger, -// }, -// AfterConnect: func(c *pgx.Conn) error { -// // Ensure all connections have this temp table available -// _, err := c.Exec("create temporary table foo(...)") -// return err -// }, -// } -// -// stdlib.RegisterDriverConfig(&driverConfig) -// -// db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")) -// if err != nil { -// return err -// } +// connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL")) +// connConfig.Logger = myLogger +// connStr := stdlib.RegisterConnConfig(connConfig) +// db, _ := sql.Open("pgx", connStr) // // pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2. // It does not support named parameters. @@ -99,7 +86,9 @@ var ctxKeyFakeTx ctxKey = 0 var ErrNotPgx = errors.New("not pgx *sql.DB") func init() { - pgxDriver = &Driver{} + pgxDriver = &Driver{ + configs: make(map[string]*pgx.ConnConfig), + } fakeTxConns = make(map[*pgx.Conn]*sql.Tx) sql.Register("pgx", pgxDriver) @@ -125,12 +114,24 @@ var ( fakeTxConns map[*pgx.Conn]*sql.Tx ) -type Driver struct{} +type Driver struct { + configMutex sync.Mutex + configs map[string]*pgx.ConnConfig +} func (d *Driver) Open(name string) (driver.Conn, error) { - connConfig, err := pgx.ParseConfig(name) - if err != nil { - return nil, err + var connConfig *pgx.ConnConfig + + d.configMutex.Lock() + connConfig = d.configs[name] + d.configMutex.Unlock() + + if connConfig == nil { + var err error + connConfig, err = pgx.ParseConfig(name) + if err != nil { + return nil, err + } } ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout @@ -144,6 +145,30 @@ func (d *Driver) Open(name string) (driver.Conn, error) { return c, nil } +func (d *Driver) registerConnConfig(c *pgx.ConnConfig) string { + d.configMutex.Lock() + connStr := fmt.Sprintf("registeredConnConfig%d", len(d.configs)) + d.configs[connStr] = c + d.configMutex.Unlock() + return connStr +} + +func (d *Driver) unregisterConnConfig(connStr string) { + d.configMutex.Lock() + delete(d.configs, connStr) + d.configMutex.Unlock() +} + +// RegisterConnConfig registers a ConnConfig and returns the connection string to use with Open. +func RegisterConnConfig(c *pgx.ConnConfig) string { + return pgxDriver.registerConnConfig(c) +} + +// UnregisterConnConfig removes the ConnConfig registration for connStr. +func UnregisterConnConfig(connStr string) { + pgxDriver.unregisterConnConfig(connStr) +} + type Conn struct { conn *pgx.Conn psCount int64 // Counter used for creating unique prepared statement names diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 4d188ad9..99842e7f 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -15,6 +15,8 @@ import ( "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/stdlib" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func closeDB(t *testing.T, db *sql.DB) { @@ -288,20 +290,6 @@ func TestConnQuery(t *testing.T) { ensureConnValid(t, db) } -type testLog struct { - lvl pgx.LogLevel - msg string - data map[string]interface{} -} - -type testLogger struct { - logs []testLog -} - -func (l *testLogger) Log(lvl pgx.LogLevel, msg string, data map[string]interface{}) { - l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data}) -} - func TestConnQueryNull(t *testing.T) { db := openDB(t) defer closeDB(t, db) @@ -1132,3 +1120,40 @@ func TestScanJSONIntoJSONRawMessage(t *testing.T) { ensureConnValid(t, db) } + +type testLog struct { + lvl pgx.LogLevel + msg string + data map[string]interface{} +} + +type testLogger struct { + logs []testLog +} + +func (l *testLogger) Log(ctx context.Context, lvl pgx.LogLevel, msg string, data map[string]interface{}) { + l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data}) +} + +func TestRegisterConnConfig(t *testing.T) { + connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + logger := &testLogger{} + connConfig.Logger = logger + + connStr := stdlib.RegisterConnConfig(connConfig) + defer stdlib.UnregisterConnConfig(connStr) + + db, err := sql.Open("pgx", connStr) + require.NoError(t, err) + defer closeDB(t, db) + + var n int64 + err = db.QueryRow("select 1").Scan(&n) + require.NoError(t, err) + + l := logger.logs[len(logger.logs)-1] + assert.Equal(t, "Query", l.msg) + assert.Equal(t, "select 1", l.data["sql"]) +}