Add BeforeConnect callback to pgxpool.Config.

This allows for connection settings to be updated without having to create
a new pool. The callback is passed a copy of the pgx.ConnConfig and will
not impact existing live connections.
pull/901/head
Robert Froehlich 2021-01-02 15:08:59 -08:00
parent b664891853
commit 210a217818
2 changed files with 37 additions and 1 deletions

View File

@ -70,6 +70,7 @@ func (cr *connResource) getPoolRows(c *Conn, r pgx.Rows) *poolRows {
type Pool struct {
p *puddle.Pool
config *Config
beforeConnect func(context.Context, *pgx.ConnConfig) error
afterConnect func(context.Context, *pgx.Conn) error
beforeAcquire func(context.Context, *pgx.Conn) bool
afterRelease func(*pgx.Conn) bool
@ -85,6 +86,10 @@ type Pool struct {
type Config struct {
ConnConfig *pgx.ConnConfig
// BeforeConnect is called before a new connection is made. It is passed a copy of the underlying pgx.ConnConfig and
// will not impact any existing open connections.
BeforeConnect func(context.Context, *pgx.ConnConfig) error
// AfterConnect is called after a connection is established, but before it is added to the pool.
AfterConnect func(context.Context, *pgx.Conn) error
@ -155,6 +160,7 @@ func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) {
p := &Pool{
config: config,
beforeConnect: config.BeforeConnect,
afterConnect: config.AfterConnect,
beforeAcquire: config.BeforeAcquire,
afterRelease: config.AfterRelease,
@ -167,7 +173,16 @@ func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) {
p.p = puddle.NewPool(
func(ctx context.Context) (interface{}, error) {
conn, err := pgx.ConnectConfig(ctx, config.ConnConfig)
connConfig := p.config.ConnConfig
if p.beforeConnect != nil {
connConfig = p.config.ConnConfig.Copy()
if err := p.beforeConnect(ctx, connConfig); err != nil {
return nil, err
}
}
conn, err := pgx.ConnectConfig(ctx, connConfig)
if err != nil {
return nil, err
}

View File

@ -112,6 +112,27 @@ func TestPoolAcquireAndConnRelease(t *testing.T) {
c.Release()
}
func TestPoolBeforeConnect(t *testing.T) {
t.Parallel()
config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error {
cfg.Config.RuntimeParams["application_name"] = "pgx"
return nil
}
db, err := pgxpool.ConnectConfig(context.Background(), config)
require.NoError(t, err)
defer db.Close()
var str string
err = db.QueryRow(context.Background(), "SHOW application_name").Scan(&str)
require.NoError(t, err)
assert.EqualValues(t, "pgx", str)
}
func TestPoolAfterConnect(t *testing.T) {
t.Parallel()