Add pool AfterRelease hook

Also, just close returned connections that are in a transaction rather
than automatically rolling back.
pull/483/head
Jack Christensen 2019-04-27 09:01:32 -05:00
parent 48ea620c93
commit ac618f105b
3 changed files with 45 additions and 19 deletions

View File

@ -2,7 +2,6 @@ package pool
import (
"context"
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
@ -12,6 +11,7 @@ import (
// Conn is an acquired *pgx.Conn from a Pool.
type Conn struct {
res *puddle.Resource
p *Pool
}
// Release returns c to the pool it was acquired from. Once Release has been called, other methods must not be called.
@ -26,22 +26,12 @@ func (c *Conn) Release() {
c.res = nil
go func() {
if !conn.IsAlive() {
if !conn.IsAlive() || conn.PgConn().TxStatus != 'I' {
res.Destroy()
return
}
if conn.PgConn().TxStatus != 'I' {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_, err := conn.Exec(ctx, "rollback")
cancel()
if err != nil {
res.Destroy()
return
}
}
if conn.IsAlive() {
if c.p.afterRelease == nil || c.p.afterRelease(conn) {
res.Release()
} else {
res.Destroy()

View File

@ -16,6 +16,7 @@ const defaultMaxConns = 5
type Pool struct {
p *puddle.Pool
beforeAcquire func(*pgx.Conn) bool
afterRelease func(*pgx.Conn) bool
}
type Config struct {
@ -26,6 +27,10 @@ type Config struct {
// acquired.
BeforeAcquire func(*pgx.Conn) bool
// AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to
// return the connection to the pool or false to destroy the connection.
AfterRelease func(*pgx.Conn) bool
MaxConns int32
}
@ -45,6 +50,7 @@ func Connect(ctx context.Context, connString string) (*Pool, error) {
func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) {
p := &Pool{
beforeAcquire: config.BeforeAcquire,
afterRelease: config.AfterRelease,
}
p.p = puddle.NewPool(
@ -102,7 +108,7 @@ func (p *Pool) Acquire(ctx context.Context) (*Conn, error) {
}
if p.beforeAcquire == nil || p.beforeAcquire(res.Value().(*pgx.Conn)) {
return &Conn{res: res}, nil
return &Conn{res: res, p: p}, nil
}
res.Destroy()
@ -116,7 +122,7 @@ func (p *Pool) AcquireAllIdle() []*Conn {
conns := make([]*Conn, 0, len(resources))
for _, res := range resources {
if p.beforeAcquire == nil || p.beforeAcquire(res.Value().(*pgx.Conn)) {
conns = append(conns, &Conn{res: res})
conns = append(conns, &Conn{res: res, p: p})
} else {
res.Destroy()
}

View File

@ -92,6 +92,36 @@ func TestPoolBeforeAcquire(t *testing.T) {
assert.EqualValues(t, 12, acquireAttempts)
}
func TestPoolAfterRelease(t *testing.T) {
t.Parallel()
config, err := pool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
afterReleaseCount := 0
config.AfterRelease = func(c *pgx.Conn) bool {
afterReleaseCount += 1
return afterReleaseCount%2 == 1
}
db, err := pool.ConnectConfig(context.Background(), config)
require.NoError(t, err)
defer db.Close()
connPIDs := map[uint32]struct{}{}
for i := 0; i < 10; i++ {
conn, err := db.Acquire(context.Background())
assert.NoError(t, err)
connPIDs[conn.Conn().PgConn().PID()] = struct{}{}
conn.Release()
waitForReleaseToComplete()
}
assert.EqualValues(t, 5, len(connPIDs))
}
func TestPoolAcquireAllIdle(t *testing.T) {
t.Parallel()
@ -243,7 +273,7 @@ func TestPoolCopyFrom(t *testing.T) {
assert.Equal(t, inputRows, outputRows)
}
func TestConnReleaseRollsBackFailedTransaction(t *testing.T) {
func TestConnReleaseClosesConnInFailedTransaction(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
@ -275,13 +305,13 @@ func TestConnReleaseRollsBackFailedTransaction(t *testing.T) {
c, err = pool.Acquire(ctx)
require.NoError(t, err)
assert.Equal(t, pid, c.Conn().PgConn().PID())
assert.NotEqual(t, pid, c.Conn().PgConn().PID())
assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus)
c.Release()
}
func TestConnReleaseRollsBackInTransaction(t *testing.T) {
func TestConnReleaseClosesConnInTransaction(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
@ -308,7 +338,7 @@ func TestConnReleaseRollsBackInTransaction(t *testing.T) {
c, err = pool.Acquire(ctx)
require.NoError(t, err)
assert.Equal(t, pid, c.Conn().PgConn().PID())
assert.NotEqual(t, pid, c.Conn().PgConn().PID())
assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus)
c.Release()