diff --git a/pool/conn.go b/pool/conn.go index 05231790..e334719e 100644 --- a/pool/conn.go +++ b/pool/conn.go @@ -2,7 +2,6 @@ package pool import ( "context" - "time" "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" @@ -12,6 +11,7 @@ import ( // Conn is an acquired *pgx.Conn from a Pool. type Conn struct { res *puddle.Resource + p *Pool } // Release returns c to the pool it was acquired from. Once Release has been called, other methods must not be called. @@ -26,22 +26,12 @@ func (c *Conn) Release() { c.res = nil go func() { - if !conn.IsAlive() { + if !conn.IsAlive() || conn.PgConn().TxStatus != 'I' { res.Destroy() return } - if conn.PgConn().TxStatus != 'I' { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - _, err := conn.Exec(ctx, "rollback") - cancel() - if err != nil { - res.Destroy() - return - } - } - - if conn.IsAlive() { + if c.p.afterRelease == nil || c.p.afterRelease(conn) { res.Release() } else { res.Destroy() diff --git a/pool/pool.go b/pool/pool.go index d90f8e26..af2665ff 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -16,6 +16,7 @@ const defaultMaxConns = 5 type Pool struct { p *puddle.Pool beforeAcquire func(*pgx.Conn) bool + afterRelease func(*pgx.Conn) bool } type Config struct { @@ -26,6 +27,10 @@ type Config struct { // acquired. BeforeAcquire func(*pgx.Conn) bool + // AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to + // return the connection to the pool or false to destroy the connection. + AfterRelease func(*pgx.Conn) bool + MaxConns int32 } @@ -45,6 +50,7 @@ func Connect(ctx context.Context, connString string) (*Pool, error) { func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) { p := &Pool{ beforeAcquire: config.BeforeAcquire, + afterRelease: config.AfterRelease, } p.p = puddle.NewPool( @@ -102,7 +108,7 @@ func (p *Pool) Acquire(ctx context.Context) (*Conn, error) { } if p.beforeAcquire == nil || p.beforeAcquire(res.Value().(*pgx.Conn)) { - return &Conn{res: res}, nil + return &Conn{res: res, p: p}, nil } res.Destroy() @@ -116,7 +122,7 @@ func (p *Pool) AcquireAllIdle() []*Conn { conns := make([]*Conn, 0, len(resources)) for _, res := range resources { if p.beforeAcquire == nil || p.beforeAcquire(res.Value().(*pgx.Conn)) { - conns = append(conns, &Conn{res: res}) + conns = append(conns, &Conn{res: res, p: p}) } else { res.Destroy() } diff --git a/pool/pool_test.go b/pool/pool_test.go index 45464c18..7ddb94d3 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -92,6 +92,36 @@ func TestPoolBeforeAcquire(t *testing.T) { assert.EqualValues(t, 12, acquireAttempts) } +func TestPoolAfterRelease(t *testing.T) { + t.Parallel() + + config, err := pool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + afterReleaseCount := 0 + + config.AfterRelease = func(c *pgx.Conn) bool { + afterReleaseCount += 1 + return afterReleaseCount%2 == 1 + } + + db, err := pool.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer db.Close() + + connPIDs := map[uint32]struct{}{} + + for i := 0; i < 10; i++ { + conn, err := db.Acquire(context.Background()) + assert.NoError(t, err) + connPIDs[conn.Conn().PgConn().PID()] = struct{}{} + conn.Release() + waitForReleaseToComplete() + } + + assert.EqualValues(t, 5, len(connPIDs)) +} + func TestPoolAcquireAllIdle(t *testing.T) { t.Parallel() @@ -243,7 +273,7 @@ func TestPoolCopyFrom(t *testing.T) { assert.Equal(t, inputRows, outputRows) } -func TestConnReleaseRollsBackFailedTransaction(t *testing.T) { +func TestConnReleaseClosesConnInFailedTransaction(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -275,13 +305,13 @@ func TestConnReleaseRollsBackFailedTransaction(t *testing.T) { c, err = pool.Acquire(ctx) require.NoError(t, err) - assert.Equal(t, pid, c.Conn().PgConn().PID()) + assert.NotEqual(t, pid, c.Conn().PgConn().PID()) assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus) c.Release() } -func TestConnReleaseRollsBackInTransaction(t *testing.T) { +func TestConnReleaseClosesConnInTransaction(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -308,7 +338,7 @@ func TestConnReleaseRollsBackInTransaction(t *testing.T) { c, err = pool.Acquire(ctx) require.NoError(t, err) - assert.Equal(t, pid, c.Conn().PgConn().PID()) + assert.NotEqual(t, pid, c.Conn().PgConn().PID()) assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus) c.Release()