mirror of https://github.com/jackc/pgx.git
Add pool AfterRelease hook
Also, just close returned connections that are in a transaction rather than automatically rolling back.pull/483/head
parent
48ea620c93
commit
ac618f105b
16
pool/conn.go
16
pool/conn.go
|
@ -2,7 +2,6 @@ package pool
|
|||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx/v4"
|
||||
|
@ -12,6 +11,7 @@ import (
|
|||
// Conn is an acquired *pgx.Conn from a Pool.
|
||||
type Conn struct {
|
||||
res *puddle.Resource
|
||||
p *Pool
|
||||
}
|
||||
|
||||
// Release returns c to the pool it was acquired from. Once Release has been called, other methods must not be called.
|
||||
|
@ -26,22 +26,12 @@ func (c *Conn) Release() {
|
|||
c.res = nil
|
||||
|
||||
go func() {
|
||||
if !conn.IsAlive() {
|
||||
if !conn.IsAlive() || conn.PgConn().TxStatus != 'I' {
|
||||
res.Destroy()
|
||||
return
|
||||
}
|
||||
|
||||
if conn.PgConn().TxStatus != 'I' {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_, err := conn.Exec(ctx, "rollback")
|
||||
cancel()
|
||||
if err != nil {
|
||||
res.Destroy()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if conn.IsAlive() {
|
||||
if c.p.afterRelease == nil || c.p.afterRelease(conn) {
|
||||
res.Release()
|
||||
} else {
|
||||
res.Destroy()
|
||||
|
|
10
pool/pool.go
10
pool/pool.go
|
@ -16,6 +16,7 @@ const defaultMaxConns = 5
|
|||
type Pool struct {
|
||||
p *puddle.Pool
|
||||
beforeAcquire func(*pgx.Conn) bool
|
||||
afterRelease func(*pgx.Conn) bool
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
|
@ -26,6 +27,10 @@ type Config struct {
|
|||
// acquired.
|
||||
BeforeAcquire func(*pgx.Conn) bool
|
||||
|
||||
// AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to
|
||||
// return the connection to the pool or false to destroy the connection.
|
||||
AfterRelease func(*pgx.Conn) bool
|
||||
|
||||
MaxConns int32
|
||||
}
|
||||
|
||||
|
@ -45,6 +50,7 @@ func Connect(ctx context.Context, connString string) (*Pool, error) {
|
|||
func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) {
|
||||
p := &Pool{
|
||||
beforeAcquire: config.BeforeAcquire,
|
||||
afterRelease: config.AfterRelease,
|
||||
}
|
||||
|
||||
p.p = puddle.NewPool(
|
||||
|
@ -102,7 +108,7 @@ func (p *Pool) Acquire(ctx context.Context) (*Conn, error) {
|
|||
}
|
||||
|
||||
if p.beforeAcquire == nil || p.beforeAcquire(res.Value().(*pgx.Conn)) {
|
||||
return &Conn{res: res}, nil
|
||||
return &Conn{res: res, p: p}, nil
|
||||
}
|
||||
|
||||
res.Destroy()
|
||||
|
@ -116,7 +122,7 @@ func (p *Pool) AcquireAllIdle() []*Conn {
|
|||
conns := make([]*Conn, 0, len(resources))
|
||||
for _, res := range resources {
|
||||
if p.beforeAcquire == nil || p.beforeAcquire(res.Value().(*pgx.Conn)) {
|
||||
conns = append(conns, &Conn{res: res})
|
||||
conns = append(conns, &Conn{res: res, p: p})
|
||||
} else {
|
||||
res.Destroy()
|
||||
}
|
||||
|
|
|
@ -92,6 +92,36 @@ func TestPoolBeforeAcquire(t *testing.T) {
|
|||
assert.EqualValues(t, 12, acquireAttempts)
|
||||
}
|
||||
|
||||
func TestPoolAfterRelease(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config, err := pool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
|
||||
afterReleaseCount := 0
|
||||
|
||||
config.AfterRelease = func(c *pgx.Conn) bool {
|
||||
afterReleaseCount += 1
|
||||
return afterReleaseCount%2 == 1
|
||||
}
|
||||
|
||||
db, err := pool.ConnectConfig(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
connPIDs := map[uint32]struct{}{}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
conn, err := db.Acquire(context.Background())
|
||||
assert.NoError(t, err)
|
||||
connPIDs[conn.Conn().PgConn().PID()] = struct{}{}
|
||||
conn.Release()
|
||||
waitForReleaseToComplete()
|
||||
}
|
||||
|
||||
assert.EqualValues(t, 5, len(connPIDs))
|
||||
}
|
||||
|
||||
func TestPoolAcquireAllIdle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -243,7 +273,7 @@ func TestPoolCopyFrom(t *testing.T) {
|
|||
assert.Equal(t, inputRows, outputRows)
|
||||
}
|
||||
|
||||
func TestConnReleaseRollsBackFailedTransaction(t *testing.T) {
|
||||
func TestConnReleaseClosesConnInFailedTransaction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
|
@ -275,13 +305,13 @@ func TestConnReleaseRollsBackFailedTransaction(t *testing.T) {
|
|||
c, err = pool.Acquire(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, pid, c.Conn().PgConn().PID())
|
||||
assert.NotEqual(t, pid, c.Conn().PgConn().PID())
|
||||
assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus)
|
||||
|
||||
c.Release()
|
||||
}
|
||||
|
||||
func TestConnReleaseRollsBackInTransaction(t *testing.T) {
|
||||
func TestConnReleaseClosesConnInTransaction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
|
@ -308,7 +338,7 @@ func TestConnReleaseRollsBackInTransaction(t *testing.T) {
|
|||
c, err = pool.Acquire(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, pid, c.Conn().PgConn().PID())
|
||||
assert.NotEqual(t, pid, c.Conn().PgConn().PID())
|
||||
assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus)
|
||||
|
||||
c.Release()
|
||||
|
|
Loading…
Reference in New Issue