Add AfterConnect hook to pool

pull/483/head
Jack Christensen 2019-04-27 16:09:23 -05:00
parent 9008387300
commit 7558b8d05f
2 changed files with 42 additions and 1 deletions

View File

@ -18,6 +18,7 @@ var defaultHealthCheckPeriod = time.Minute
type Pool struct {
p *puddle.Pool
afterConnect func(context.Context, *pgx.Conn) error
beforeAcquire func(*pgx.Conn) bool
afterRelease func(*pgx.Conn) bool
maxConnLifetime time.Duration
@ -30,6 +31,9 @@ type Pool struct {
type Config struct {
ConnConfig *pgx.ConnConfig
// AfterConnect is called after a connection is established, but before it is added to the pool.
AfterConnect func(context.Context, *pgx.Conn) error
// BeforeAcquire is called before before a connection is acquired from the pool. It must return true to allow the
// acquision or false to indicate that the connection should be destroyed and a different connection should be
// acquired.
@ -64,6 +68,7 @@ func Connect(ctx context.Context, connString string) (*Pool, error) {
// connection.
func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) {
p := &Pool{
afterConnect: config.AfterConnect,
beforeAcquire: config.BeforeAcquire,
afterRelease: config.AfterRelease,
maxConnLifetime: config.MaxConnLifetime,
@ -72,7 +77,22 @@ func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) {
}
p.p = puddle.NewPool(
func(ctx context.Context) (interface{}, error) { return pgx.ConnectConfig(ctx, config.ConnConfig) },
func(ctx context.Context) (interface{}, error) {
conn, err := pgx.ConnectConfig(ctx, config.ConnConfig)
if err != nil {
return nil, err
}
if p.afterConnect != nil {
err = p.afterConnect(ctx, conn)
if err != nil {
conn.Close(ctx)
return nil, err
}
}
return conn, nil
},
func(value interface{}) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
value.(*pgx.Conn).Close(ctx)

View File

@ -51,6 +51,27 @@ func TestPoolAcquireAndConnRelease(t *testing.T) {
c.Release()
}
func TestPoolAfterConnect(t *testing.T) {
t.Parallel()
config, err := pool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
config.AfterConnect = func(ctx context.Context, c *pgx.Conn) error {
_, err := c.Prepare(ctx, "ps1", "select 1")
return err
}
db, err := pool.ConnectConfig(context.Background(), config)
require.NoError(t, err)
defer db.Close()
var n int32
err = db.QueryRow(context.Background(), "ps1").Scan(&n)
require.NoError(t, err)
assert.EqualValues(t, 1, n)
}
func TestPoolBeforeAcquire(t *testing.T) {
t.Parallel()