Add pool AfterRelease hook

Also, just close returned connections that are in a transaction rather
than automatically rolling back.
pull/483/head
Jack Christensen 2019-04-27 09:01:32 -05:00
parent 48ea620c93
commit ac618f105b
3 changed files with 45 additions and 19 deletions

View File

@ -2,7 +2,6 @@ package pool
import ( import (
"context" "context"
"time"
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4"
@ -12,6 +11,7 @@ import (
// Conn is an acquired *pgx.Conn from a Pool. // Conn is an acquired *pgx.Conn from a Pool.
type Conn struct { type Conn struct {
res *puddle.Resource res *puddle.Resource
p *Pool
} }
// Release returns c to the pool it was acquired from. Once Release has been called, other methods must not be called. // Release returns c to the pool it was acquired from. Once Release has been called, other methods must not be called.
@ -26,22 +26,12 @@ func (c *Conn) Release() {
c.res = nil c.res = nil
go func() { go func() {
if !conn.IsAlive() { if !conn.IsAlive() || conn.PgConn().TxStatus != 'I' {
res.Destroy() res.Destroy()
return return
} }
if conn.PgConn().TxStatus != 'I' { if c.p.afterRelease == nil || c.p.afterRelease(conn) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_, err := conn.Exec(ctx, "rollback")
cancel()
if err != nil {
res.Destroy()
return
}
}
if conn.IsAlive() {
res.Release() res.Release()
} else { } else {
res.Destroy() res.Destroy()

View File

@ -16,6 +16,7 @@ const defaultMaxConns = 5
type Pool struct { type Pool struct {
p *puddle.Pool p *puddle.Pool
beforeAcquire func(*pgx.Conn) bool beforeAcquire func(*pgx.Conn) bool
afterRelease func(*pgx.Conn) bool
} }
type Config struct { type Config struct {
@ -26,6 +27,10 @@ type Config struct {
// acquired. // acquired.
BeforeAcquire func(*pgx.Conn) bool BeforeAcquire func(*pgx.Conn) bool
// AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to
// return the connection to the pool or false to destroy the connection.
AfterRelease func(*pgx.Conn) bool
MaxConns int32 MaxConns int32
} }
@ -45,6 +50,7 @@ func Connect(ctx context.Context, connString string) (*Pool, error) {
func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) { func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) {
p := &Pool{ p := &Pool{
beforeAcquire: config.BeforeAcquire, beforeAcquire: config.BeforeAcquire,
afterRelease: config.AfterRelease,
} }
p.p = puddle.NewPool( p.p = puddle.NewPool(
@ -102,7 +108,7 @@ func (p *Pool) Acquire(ctx context.Context) (*Conn, error) {
} }
if p.beforeAcquire == nil || p.beforeAcquire(res.Value().(*pgx.Conn)) { if p.beforeAcquire == nil || p.beforeAcquire(res.Value().(*pgx.Conn)) {
return &Conn{res: res}, nil return &Conn{res: res, p: p}, nil
} }
res.Destroy() res.Destroy()
@ -116,7 +122,7 @@ func (p *Pool) AcquireAllIdle() []*Conn {
conns := make([]*Conn, 0, len(resources)) conns := make([]*Conn, 0, len(resources))
for _, res := range resources { for _, res := range resources {
if p.beforeAcquire == nil || p.beforeAcquire(res.Value().(*pgx.Conn)) { if p.beforeAcquire == nil || p.beforeAcquire(res.Value().(*pgx.Conn)) {
conns = append(conns, &Conn{res: res}) conns = append(conns, &Conn{res: res, p: p})
} else { } else {
res.Destroy() res.Destroy()
} }

View File

@ -92,6 +92,36 @@ func TestPoolBeforeAcquire(t *testing.T) {
assert.EqualValues(t, 12, acquireAttempts) assert.EqualValues(t, 12, acquireAttempts)
} }
func TestPoolAfterRelease(t *testing.T) {
t.Parallel()
config, err := pool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
afterReleaseCount := 0
config.AfterRelease = func(c *pgx.Conn) bool {
afterReleaseCount += 1
return afterReleaseCount%2 == 1
}
db, err := pool.ConnectConfig(context.Background(), config)
require.NoError(t, err)
defer db.Close()
connPIDs := map[uint32]struct{}{}
for i := 0; i < 10; i++ {
conn, err := db.Acquire(context.Background())
assert.NoError(t, err)
connPIDs[conn.Conn().PgConn().PID()] = struct{}{}
conn.Release()
waitForReleaseToComplete()
}
assert.EqualValues(t, 5, len(connPIDs))
}
func TestPoolAcquireAllIdle(t *testing.T) { func TestPoolAcquireAllIdle(t *testing.T) {
t.Parallel() t.Parallel()
@ -243,7 +273,7 @@ func TestPoolCopyFrom(t *testing.T) {
assert.Equal(t, inputRows, outputRows) assert.Equal(t, inputRows, outputRows)
} }
func TestConnReleaseRollsBackFailedTransaction(t *testing.T) { func TestConnReleaseClosesConnInFailedTransaction(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
@ -275,13 +305,13 @@ func TestConnReleaseRollsBackFailedTransaction(t *testing.T) {
c, err = pool.Acquire(ctx) c, err = pool.Acquire(ctx)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, pid, c.Conn().PgConn().PID()) assert.NotEqual(t, pid, c.Conn().PgConn().PID())
assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus) assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus)
c.Release() c.Release()
} }
func TestConnReleaseRollsBackInTransaction(t *testing.T) { func TestConnReleaseClosesConnInTransaction(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
@ -308,7 +338,7 @@ func TestConnReleaseRollsBackInTransaction(t *testing.T) {
c, err = pool.Acquire(ctx) c, err = pool.Acquire(ctx)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, pid, c.Conn().PgConn().PID()) assert.NotEqual(t, pid, c.Conn().PgConn().PID())
assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus) assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus)
c.Release() c.Release()