diff --git a/pool/conn.go b/pool/conn.go index e334719e..8194945c 100644 --- a/pool/conn.go +++ b/pool/conn.go @@ -2,6 +2,7 @@ package pool import ( "context" + "time" "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" @@ -26,7 +27,8 @@ func (c *Conn) Release() { c.res = nil go func() { - if !conn.IsAlive() || conn.PgConn().TxStatus != 'I' { + now := time.Now() + if !conn.IsAlive() || conn.PgConn().TxStatus != 'I' || (now.Sub(res.CreationTime()) > c.p.maxConnLifetime) { res.Destroy() return } diff --git a/pool/pool.go b/pool/pool.go index b74abf09..eeb3ea26 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -2,7 +2,6 @@ package pool import ( "context" - "fmt" "runtime" "strconv" "time" @@ -10,14 +9,20 @@ import ( "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" "github.com/jackc/puddle" + errors "golang.org/x/xerrors" ) -const defaultMaxConns = 5 +var defaultMinMaxConns = int32(4) +var defaultMaxConnLifetime = time.Hour +var defaultHealthCheckPeriod = time.Minute type Pool struct { - p *puddle.Pool - beforeAcquire func(*pgx.Conn) bool - afterRelease func(*pgx.Conn) bool + p *puddle.Pool + beforeAcquire func(*pgx.Conn) bool + afterRelease func(*pgx.Conn) bool + maxConnLifetime time.Duration + healthCheckPeriod time.Duration + closeChan chan struct{} } type Config struct { @@ -32,7 +37,14 @@ type Config struct { // return the connection to the pool or false to destroy the connection. AfterRelease func(*pgx.Conn) bool + // MaxConnLifetime is the duration after which a connection will be automatically closed. + MaxConnLifetime time.Duration + + // MaxConns is the maximum size of the pool. MaxConns int32 + + // HealthCheckPeriod is the duration between checks of the health of idle connections. + HealthCheckPeriod time.Duration } // Connect creates a new Pool and immediately establishes one connection. ctx can be used to cancel this initial @@ -50,8 +62,11 @@ func Connect(ctx context.Context, connString string) (*Pool, error) { // connection. func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) { p := &Pool{ - beforeAcquire: config.BeforeAcquire, - afterRelease: config.AfterRelease, + beforeAcquire: config.BeforeAcquire, + afterRelease: config.AfterRelease, + maxConnLifetime: config.MaxConnLifetime, + healthCheckPeriod: config.HealthCheckPeriod, + closeChan: make(chan struct{}), } p.p = puddle.NewPool( @@ -64,6 +79,8 @@ func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) { config.MaxConns, ) + go p.backgroundHealthCheck() + // Initially establish one connection res, err := p.p.Acquire(ctx) if err != nil { @@ -87,25 +104,78 @@ func ParseConfig(connString string) (*Config, error) { delete(connConfig.Config.RuntimeParams, "pool_max_conns") n, err := strconv.ParseInt(s, 10, 32) if err != nil { - return nil, fmt.Errorf("invalid pool_max_conns: %v", err) + return nil, errors.Errorf("cannot parse pool_max_conns: %w", err) + } + if n < 1 { + return nil, errors.Errorf("pool_max_conns too small: %d", n) } config.MaxConns = int32(n) } else { - config.MaxConns = 4 - if int32(runtime.NumCPU()) > config.MaxConns { - config.MaxConns = runtime.NumCPU() + config.MaxConns = defaultMinMaxConns + if numCPU := int32(runtime.NumCPU()); numCPU > config.MaxConns { + config.MaxConns = numCPU } } + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime") + d, err := time.ParseDuration(s) + if err != nil { + return nil, errors.Errorf("invalid pool_max_conn_lifetime: %w", err) + } + config.MaxConnLifetime = d + } else { + config.MaxConnLifetime = defaultMaxConnLifetime + } + + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_health_check_period"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_health_check_period") + d, err := time.ParseDuration(s) + if err != nil { + return nil, errors.Errorf("invalid pool_health_check_period: %w", err) + } + config.HealthCheckPeriod = d + } else { + config.HealthCheckPeriod = defaultHealthCheckPeriod + } + return config, nil } // Close closes all connections in the pool and rejects future Acquire calls. Blocks until all connections are returned // to pool and closed. func (p *Pool) Close() { + close(p.closeChan) p.p.Close() } +func (p *Pool) backgroundHealthCheck() { + ticker := time.NewTicker(p.healthCheckPeriod) + + for { + select { + case <-p.closeChan: + ticker.Stop() + return + case <-ticker.C: + p.checkIdleConnsHealth() + } + } +} + +func (p *Pool) checkIdleConnsHealth() { + resources := p.p.AcquireAllIdle() + + now := time.Now() + for _, res := range resources { + if now.Sub(res.CreationTime()) > p.maxConnLifetime { + res.Destroy() + } else { + res.Release() + } + } +} + func (p *Pool) Acquire(ctx context.Context) (*Conn, error) { for { res, err := p.p.Acquire(ctx) diff --git a/pool/pool_test.go b/pool/pool_test.go index 7ddb94d3..ffaf6932 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -158,6 +158,50 @@ func TestPoolAcquireAllIdle(t *testing.T) { } } +func TestConnReleaseChecksMaxConnLifetime(t *testing.T) { + t.Parallel() + + config, err := pool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.MaxConnLifetime = 250 * time.Millisecond + + db, err := pool.ConnectConfig(context.Background(), config) + defer db.Close() + + c, err := db.Acquire(context.Background()) + require.NoError(t, err) + + time.Sleep(config.MaxConnLifetime) + + c.Release() + waitForReleaseToComplete() + + stats := db.Stat() + assert.EqualValues(t, 0, stats.TotalConns()) +} + +func TestPoolBackgroundChecksMaxConnLifetime(t *testing.T) { + t.Parallel() + + config, err := pool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.MaxConnLifetime = 100 * time.Millisecond + config.HealthCheckPeriod = 100 * time.Millisecond + + db, err := pool.ConnectConfig(context.Background(), config) + defer db.Close() + + c, err := db.Acquire(context.Background()) + require.NoError(t, err) + c.Release() + time.Sleep(config.MaxConnLifetime + 50*time.Millisecond) + + stats := db.Stat() + assert.EqualValues(t, 0, stats.TotalConns()) +} + func TestPoolExec(t *testing.T) { t.Parallel()