mirror of https://github.com/jackc/pgx.git
Add pool AfterRelease hook
Also, just close returned connections that are in a transaction rather than automatically rolling back.pull/483/head
parent
48ea620c93
commit
ac618f105b
16
pool/conn.go
16
pool/conn.go
|
@ -2,7 +2,6 @@ package pool
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgconn"
|
||||||
"github.com/jackc/pgx/v4"
|
"github.com/jackc/pgx/v4"
|
||||||
|
@ -12,6 +11,7 @@ import (
|
||||||
// Conn is an acquired *pgx.Conn from a Pool.
|
// Conn is an acquired *pgx.Conn from a Pool.
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
res *puddle.Resource
|
res *puddle.Resource
|
||||||
|
p *Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Release returns c to the pool it was acquired from. Once Release has been called, other methods must not be called.
|
// Release returns c to the pool it was acquired from. Once Release has been called, other methods must not be called.
|
||||||
|
@ -26,22 +26,12 @@ func (c *Conn) Release() {
|
||||||
c.res = nil
|
c.res = nil
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if !conn.IsAlive() {
|
if !conn.IsAlive() || conn.PgConn().TxStatus != 'I' {
|
||||||
res.Destroy()
|
res.Destroy()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.PgConn().TxStatus != 'I' {
|
if c.p.afterRelease == nil || c.p.afterRelease(conn) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
_, err := conn.Exec(ctx, "rollback")
|
|
||||||
cancel()
|
|
||||||
if err != nil {
|
|
||||||
res.Destroy()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if conn.IsAlive() {
|
|
||||||
res.Release()
|
res.Release()
|
||||||
} else {
|
} else {
|
||||||
res.Destroy()
|
res.Destroy()
|
||||||
|
|
10
pool/pool.go
10
pool/pool.go
|
@ -16,6 +16,7 @@ const defaultMaxConns = 5
|
||||||
type Pool struct {
|
type Pool struct {
|
||||||
p *puddle.Pool
|
p *puddle.Pool
|
||||||
beforeAcquire func(*pgx.Conn) bool
|
beforeAcquire func(*pgx.Conn) bool
|
||||||
|
afterRelease func(*pgx.Conn) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
@ -26,6 +27,10 @@ type Config struct {
|
||||||
// acquired.
|
// acquired.
|
||||||
BeforeAcquire func(*pgx.Conn) bool
|
BeforeAcquire func(*pgx.Conn) bool
|
||||||
|
|
||||||
|
// AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to
|
||||||
|
// return the connection to the pool or false to destroy the connection.
|
||||||
|
AfterRelease func(*pgx.Conn) bool
|
||||||
|
|
||||||
MaxConns int32
|
MaxConns int32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,6 +50,7 @@ func Connect(ctx context.Context, connString string) (*Pool, error) {
|
||||||
func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) {
|
func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) {
|
||||||
p := &Pool{
|
p := &Pool{
|
||||||
beforeAcquire: config.BeforeAcquire,
|
beforeAcquire: config.BeforeAcquire,
|
||||||
|
afterRelease: config.AfterRelease,
|
||||||
}
|
}
|
||||||
|
|
||||||
p.p = puddle.NewPool(
|
p.p = puddle.NewPool(
|
||||||
|
@ -102,7 +108,7 @@ func (p *Pool) Acquire(ctx context.Context) (*Conn, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.beforeAcquire == nil || p.beforeAcquire(res.Value().(*pgx.Conn)) {
|
if p.beforeAcquire == nil || p.beforeAcquire(res.Value().(*pgx.Conn)) {
|
||||||
return &Conn{res: res}, nil
|
return &Conn{res: res, p: p}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
res.Destroy()
|
res.Destroy()
|
||||||
|
@ -116,7 +122,7 @@ func (p *Pool) AcquireAllIdle() []*Conn {
|
||||||
conns := make([]*Conn, 0, len(resources))
|
conns := make([]*Conn, 0, len(resources))
|
||||||
for _, res := range resources {
|
for _, res := range resources {
|
||||||
if p.beforeAcquire == nil || p.beforeAcquire(res.Value().(*pgx.Conn)) {
|
if p.beforeAcquire == nil || p.beforeAcquire(res.Value().(*pgx.Conn)) {
|
||||||
conns = append(conns, &Conn{res: res})
|
conns = append(conns, &Conn{res: res, p: p})
|
||||||
} else {
|
} else {
|
||||||
res.Destroy()
|
res.Destroy()
|
||||||
}
|
}
|
||||||
|
|
|
@ -92,6 +92,36 @@ func TestPoolBeforeAcquire(t *testing.T) {
|
||||||
assert.EqualValues(t, 12, acquireAttempts)
|
assert.EqualValues(t, 12, acquireAttempts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPoolAfterRelease(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
config, err := pool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
afterReleaseCount := 0
|
||||||
|
|
||||||
|
config.AfterRelease = func(c *pgx.Conn) bool {
|
||||||
|
afterReleaseCount += 1
|
||||||
|
return afterReleaseCount%2 == 1
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := pool.ConnectConfig(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
connPIDs := map[uint32]struct{}{}
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
conn, err := db.Acquire(context.Background())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
connPIDs[conn.Conn().PgConn().PID()] = struct{}{}
|
||||||
|
conn.Release()
|
||||||
|
waitForReleaseToComplete()
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.EqualValues(t, 5, len(connPIDs))
|
||||||
|
}
|
||||||
|
|
||||||
func TestPoolAcquireAllIdle(t *testing.T) {
|
func TestPoolAcquireAllIdle(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -243,7 +273,7 @@ func TestPoolCopyFrom(t *testing.T) {
|
||||||
assert.Equal(t, inputRows, outputRows)
|
assert.Equal(t, inputRows, outputRows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnReleaseRollsBackFailedTransaction(t *testing.T) {
|
func TestConnReleaseClosesConnInFailedTransaction(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
@ -275,13 +305,13 @@ func TestConnReleaseRollsBackFailedTransaction(t *testing.T) {
|
||||||
c, err = pool.Acquire(ctx)
|
c, err = pool.Acquire(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, pid, c.Conn().PgConn().PID())
|
assert.NotEqual(t, pid, c.Conn().PgConn().PID())
|
||||||
assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus)
|
assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus)
|
||||||
|
|
||||||
c.Release()
|
c.Release()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnReleaseRollsBackInTransaction(t *testing.T) {
|
func TestConnReleaseClosesConnInTransaction(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
@ -308,7 +338,7 @@ func TestConnReleaseRollsBackInTransaction(t *testing.T) {
|
||||||
c, err = pool.Acquire(ctx)
|
c, err = pool.Acquire(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, pid, c.Conn().PgConn().PID())
|
assert.NotEqual(t, pid, c.Conn().PgConn().PID())
|
||||||
assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus)
|
assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus)
|
||||||
|
|
||||||
c.Release()
|
c.Release()
|
||||||
|
|
Loading…
Reference in New Issue