From f20f026b7d5539bf31c1917dd699695962abc02d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Aug 2019 08:57:03 -0500 Subject: [PATCH] Pool BeforeAcquire hook takes context --- pgxpool/pool.go | 10 +++++----- pgxpool/pool_test.go | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 1e3d80c5..dd5a48b1 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -68,7 +68,7 @@ func (cr *connResource) getPoolRows(c *Conn, r pgx.Rows) *poolRows { type Pool struct { p *puddle.Pool afterConnect func(context.Context, *pgx.Conn) error - beforeAcquire func(*pgx.Conn) bool + beforeAcquire func(context.Context, *pgx.Conn) bool afterRelease func(*pgx.Conn) bool maxConnLifetime time.Duration healthCheckPeriod time.Duration @@ -86,7 +86,7 @@ type Config struct { // BeforeAcquire is called before before a connection is acquired from the pool. It must return true to allow the // acquision or false to indicate that the connection should be destroyed and a different connection should be // acquired. - BeforeAcquire func(*pgx.Conn) bool + BeforeAcquire func(context.Context, *pgx.Conn) bool // AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to // return the connection to the pool or false to destroy the connection. @@ -289,7 +289,7 @@ func (p *Pool) Acquire(ctx context.Context) (*Conn, error) { } cr := res.Value().(*connResource) - if p.beforeAcquire == nil || p.beforeAcquire(cr.conn) { + if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) { return cr.getConn(p, res), nil } @@ -299,12 +299,12 @@ func (p *Pool) Acquire(ctx context.Context) (*Conn, error) { // AcquireAllIdle atomically acquires all currently idle connections. Its intended use is for health check and // keep-alive functionality. It does not update pool statistics. -func (p *Pool) AcquireAllIdle() []*Conn { +func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn { resources := p.p.AcquireAllIdle() conns := make([]*Conn, 0, len(resources)) for _, res := range resources { cr := res.Value().(*connResource) - if p.beforeAcquire == nil || p.beforeAcquire(cr.conn) { + if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) { conns = append(conns, cr.getConn(p, res)) } else { res.Destroy() diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 3c229bb9..d6640a71 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -88,7 +88,7 @@ func TestPoolBeforeAcquire(t *testing.T) { acquireAttempts := 0 - config.BeforeAcquire = func(c *pgx.Conn) bool { + config.BeforeAcquire = func(ctx context.Context, c *pgx.Conn) bool { acquireAttempts += 1 return acquireAttempts%2 == 0 } @@ -110,7 +110,7 @@ func TestPoolBeforeAcquire(t *testing.T) { assert.EqualValues(t, 8, acquireAttempts) - conns = db.AcquireAllIdle() + conns = db.AcquireAllIdle(context.Background()) assert.Len(t, conns, 2) for _, c := range conns { @@ -158,7 +158,7 @@ func TestPoolAcquireAllIdle(t *testing.T) { require.NoError(t, err) defer db.Close() - conns := db.AcquireAllIdle() + conns := db.AcquireAllIdle(context.Background()) assert.Len(t, conns, 1) for _, c := range conns { @@ -179,7 +179,7 @@ func TestPoolAcquireAllIdle(t *testing.T) { } waitForReleaseToComplete() - conns = db.AcquireAllIdle() + conns = db.AcquireAllIdle(context.Background()) assert.Len(t, conns, 3) for _, c := range conns {