From 7f2bb9595f7009f6a0ba40e12d20123173907864 Mon Sep 17 00:00:00 2001 From: Evan Cordell Date: Mon, 1 May 2023 09:28:03 -0400 Subject: [PATCH] add BeforeClose to pgxpool.Pool --- pgxpool/pool.go | 10 +++++++++- pgxpool/pool_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 7649488e..cc837017 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -85,6 +85,7 @@ type Pool struct { afterConnect func(context.Context, *pgx.Conn) error beforeAcquire func(context.Context, *pgx.Conn) bool afterRelease func(*pgx.Conn) bool + beforeClose func(*pgx.Conn) minConns int32 maxConns int32 maxConnLifetime time.Duration @@ -111,7 +112,7 @@ type Config struct { AfterConnect func(context.Context, *pgx.Conn) error // BeforeAcquire is called before a connection is acquired from the pool. It must return true to allow the - // acquision or false to indicate that the connection should be destroyed and a different connection should be + // acquisition or false to indicate that the connection should be destroyed and a different connection should be // acquired. BeforeAcquire func(context.Context, *pgx.Conn) bool @@ -119,6 +120,9 @@ type Config struct { // return the connection to the pool or false to destroy the connection. AfterRelease func(*pgx.Conn) bool + // BeforeClose is called right before a connection is closed and removed from the pool. + BeforeClose func(*pgx.Conn) + // MaxConnLifetime is the duration since creation after which a connection will be automatically closed. MaxConnLifetime time.Duration @@ -180,6 +184,7 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { afterConnect: config.AfterConnect, beforeAcquire: config.BeforeAcquire, afterRelease: config.AfterRelease, + beforeClose: config.BeforeClose, minConns: config.MinConns, maxConns: config.MaxConns, maxConnLifetime: config.MaxConnLifetime, @@ -236,6 +241,9 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { Destructor: func(value *connResource) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) conn := value.conn + if p.beforeClose != nil { + p.beforeClose(conn) + } conn.Close(ctx) select { case <-conn.PgConn().CleanupDone(): diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 2ceb33cf..315897d6 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -347,6 +347,49 @@ func TestPoolAfterRelease(t *testing.T) { assert.EqualValues(t, 5, len(connPIDs)) } +func TestPoolBeforeClose(t *testing.T) { + t.Parallel() + + func() { + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + err = pool.AcquireFunc(context.Background(), func(conn *pgxpool.Conn) error { + if conn.Conn().PgConn().ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support backend PID") + } + return nil + }) + require.NoError(t, err) + }() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + connPIDs := make(chan uint32, 5) + config.BeforeClose = func(c *pgx.Conn) { + connPIDs <- c.PgConn().PID() + } + + db, err := pgxpool.NewWithConfig(context.Background(), config) + require.NoError(t, err) + defer db.Close() + + acquiredPIDs := make([]uint32, 0, 5) + closedPIDs := make([]uint32, 0, 5) + for i := 0; i < 5; i++ { + conn, err := db.Acquire(context.Background()) + assert.NoError(t, err) + acquiredPIDs = append(acquiredPIDs, conn.Conn().PgConn().PID()) + conn.Release() + db.Reset() + closedPIDs = append(closedPIDs, <-connPIDs) + } + + assert.ElementsMatch(t, acquiredPIDs, closedPIDs) +} + func TestPoolAcquireAllIdle(t *testing.T) { t.Parallel()