From 7558b8d05fdaee9c47df5ce8dfa740b32aa29eb8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Apr 2019 16:09:23 -0500 Subject: [PATCH] Add AfterConnect hook to pool --- pool/pool.go | 22 +++++++++++++++++++++- pool/pool_test.go | 21 +++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/pool/pool.go b/pool/pool.go index dcd540f3..360ef24e 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -18,6 +18,7 @@ var defaultHealthCheckPeriod = time.Minute type Pool struct { p *puddle.Pool + afterConnect func(context.Context, *pgx.Conn) error beforeAcquire func(*pgx.Conn) bool afterRelease func(*pgx.Conn) bool maxConnLifetime time.Duration @@ -30,6 +31,9 @@ type Pool struct { type Config struct { ConnConfig *pgx.ConnConfig + // AfterConnect is called after a connection is established, but before it is added to the pool. + AfterConnect func(context.Context, *pgx.Conn) error + // BeforeAcquire is called before before a connection is acquired from the pool. It must return true to allow the // acquision or false to indicate that the connection should be destroyed and a different connection should be // acquired. @@ -64,6 +68,7 @@ func Connect(ctx context.Context, connString string) (*Pool, error) { // connection. func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) { p := &Pool{ + afterConnect: config.AfterConnect, beforeAcquire: config.BeforeAcquire, afterRelease: config.AfterRelease, maxConnLifetime: config.MaxConnLifetime, @@ -72,7 +77,22 @@ func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) { } p.p = puddle.NewPool( - func(ctx context.Context) (interface{}, error) { return pgx.ConnectConfig(ctx, config.ConnConfig) }, + func(ctx context.Context) (interface{}, error) { + conn, err := pgx.ConnectConfig(ctx, config.ConnConfig) + if err != nil { + return nil, err + } + + if p.afterConnect != nil { + err = p.afterConnect(ctx, conn) + if err != nil { + conn.Close(ctx) + return nil, err + } + } + + return conn, nil + }, func(value interface{}) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) value.(*pgx.Conn).Close(ctx) diff --git a/pool/pool_test.go b/pool/pool_test.go index ffaf6932..e20e9cc0 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -51,6 +51,27 @@ func TestPoolAcquireAndConnRelease(t *testing.T) { c.Release() } +func TestPoolAfterConnect(t *testing.T) { + t.Parallel() + + config, err := pool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.AfterConnect = func(ctx context.Context, c *pgx.Conn) error { + _, err := c.Prepare(ctx, "ps1", "select 1") + return err + } + + db, err := pool.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer db.Close() + + var n int32 + err = db.QueryRow(context.Background(), "ps1").Scan(&n) + require.NoError(t, err) + assert.EqualValues(t, 1, n) +} + func TestPoolBeforeAcquire(t *testing.T) { t.Parallel()