Pool BeforeAcquire hook takes context

pull/594/head
Jack Christensen 2019-08-31 08:57:03 -05:00
parent 486d64daed
commit f20f026b7d
2 changed files with 9 additions and 9 deletions

View File

@ -68,7 +68,7 @@ func (cr *connResource) getPoolRows(c *Conn, r pgx.Rows) *poolRows {
type Pool struct { type Pool struct {
p *puddle.Pool p *puddle.Pool
afterConnect func(context.Context, *pgx.Conn) error afterConnect func(context.Context, *pgx.Conn) error
beforeAcquire func(*pgx.Conn) bool beforeAcquire func(context.Context, *pgx.Conn) bool
afterRelease func(*pgx.Conn) bool afterRelease func(*pgx.Conn) bool
maxConnLifetime time.Duration maxConnLifetime time.Duration
healthCheckPeriod time.Duration healthCheckPeriod time.Duration
@ -86,7 +86,7 @@ type Config struct {
// BeforeAcquire is called before before a connection is acquired from the pool. It must return true to allow the // BeforeAcquire is called before before a connection is acquired from the pool. It must return true to allow the
// acquision or false to indicate that the connection should be destroyed and a different connection should be // acquision or false to indicate that the connection should be destroyed and a different connection should be
// acquired. // acquired.
BeforeAcquire func(*pgx.Conn) bool BeforeAcquire func(context.Context, *pgx.Conn) bool
// AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to // AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to
// return the connection to the pool or false to destroy the connection. // return the connection to the pool or false to destroy the connection.
@ -289,7 +289,7 @@ func (p *Pool) Acquire(ctx context.Context) (*Conn, error) {
} }
cr := res.Value().(*connResource) cr := res.Value().(*connResource)
if p.beforeAcquire == nil || p.beforeAcquire(cr.conn) { if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) {
return cr.getConn(p, res), nil return cr.getConn(p, res), nil
} }
@ -299,12 +299,12 @@ func (p *Pool) Acquire(ctx context.Context) (*Conn, error) {
// AcquireAllIdle atomically acquires all currently idle connections. Its intended use is for health check and // AcquireAllIdle atomically acquires all currently idle connections. Its intended use is for health check and
// keep-alive functionality. It does not update pool statistics. // keep-alive functionality. It does not update pool statistics.
func (p *Pool) AcquireAllIdle() []*Conn { func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn {
resources := p.p.AcquireAllIdle() resources := p.p.AcquireAllIdle()
conns := make([]*Conn, 0, len(resources)) conns := make([]*Conn, 0, len(resources))
for _, res := range resources { for _, res := range resources {
cr := res.Value().(*connResource) cr := res.Value().(*connResource)
if p.beforeAcquire == nil || p.beforeAcquire(cr.conn) { if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) {
conns = append(conns, cr.getConn(p, res)) conns = append(conns, cr.getConn(p, res))
} else { } else {
res.Destroy() res.Destroy()

View File

@ -88,7 +88,7 @@ func TestPoolBeforeAcquire(t *testing.T) {
acquireAttempts := 0 acquireAttempts := 0
config.BeforeAcquire = func(c *pgx.Conn) bool { config.BeforeAcquire = func(ctx context.Context, c *pgx.Conn) bool {
acquireAttempts += 1 acquireAttempts += 1
return acquireAttempts%2 == 0 return acquireAttempts%2 == 0
} }
@ -110,7 +110,7 @@ func TestPoolBeforeAcquire(t *testing.T) {
assert.EqualValues(t, 8, acquireAttempts) assert.EqualValues(t, 8, acquireAttempts)
conns = db.AcquireAllIdle() conns = db.AcquireAllIdle(context.Background())
assert.Len(t, conns, 2) assert.Len(t, conns, 2)
for _, c := range conns { for _, c := range conns {
@ -158,7 +158,7 @@ func TestPoolAcquireAllIdle(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer db.Close() defer db.Close()
conns := db.AcquireAllIdle() conns := db.AcquireAllIdle(context.Background())
assert.Len(t, conns, 1) assert.Len(t, conns, 1)
for _, c := range conns { for _, c := range conns {
@ -179,7 +179,7 @@ func TestPoolAcquireAllIdle(t *testing.T) {
} }
waitForReleaseToComplete() waitForReleaseToComplete()
conns = db.AcquireAllIdle() conns = db.AcquireAllIdle(context.Background())
assert.Len(t, conns, 3) assert.Len(t, conns, 3)
for _, c := range conns { for _, c := range conns {