diff --git a/conn_test.go b/conn_test.go
index 6d7f8434..2ead63ce 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -3,6 +3,7 @@ package pgx_test
 import (
 	"bytes"
 	"context"
+	"log"
 	"os"
 	"strings"
 	"sync"
@@ -784,6 +785,37 @@ func TestLogPassesContext(t *testing.T) {
 	}
 }
 
+func TestLoggerFunc(t *testing.T) {
+	t.Parallel()
+
+	const testMsg = "foo"
+
+	buf := bytes.Buffer{}
+	logger := log.New(&buf, "", 0)
+
+	createAdapterFn := func(logger *log.Logger) pgx.LoggerFunc {
+		return func(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
+			logger.Printf("%s", testMsg)
+		}
+	}
+
+	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
+	config.Logger = createAdapterFn(logger)
+
+	conn := mustConnect(t, config)
+	defer closeConn(t, conn)
+
+	buf.Reset() // Clear logs written when establishing connection
+
+	if _, err := conn.Exec(context.TODO(), ";"); err != nil {
+		t.Fatal(err)
+	}
+
+	if strings.TrimSpace(buf.String()) != testMsg {
+		t.Errorf("Expected logger function to return '%s', but it was '%s'", testMsg, buf.String())
+	}
+}
+
 func TestIdentifierSanitize(t *testing.T) {
 	t.Parallel()
 
diff --git a/copy_from.go b/copy_from.go
index c1a66d52..c5e9aae8 100644
--- a/copy_from.go
+++ b/copy_from.go
@@ -153,13 +153,13 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
 	<-doneChan
 
 	rowsAffected := commandTag.RowsAffected()
+	endTime := time.Now()
 	if err == nil {
 		if ct.conn.shouldLog(LogLevelInfo) {
-			endTime := time.Now()
 			ct.conn.log(ctx, LogLevelInfo, "CopyFrom", map[string]any{"tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime), "rowCount": rowsAffected})
 		}
 	} else if ct.conn.shouldLog(LogLevelError) {
-		ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]any{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames})
+		ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]any{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime)})
 	}
 
 	return rowsAffected, err
diff --git a/logger.go b/logger.go
index 02a1e8e4..cd1255d1 100644
--- a/logger.go
+++ b/logger.go
@@ -47,6 +47,14 @@ type Logger interface {
 	Log(ctx context.Context, level LogLevel, msg string, data map[string]any)
 }
 
+// LoggerFunc is a wrapper around a function to satisfy the pgx.Logger interface
+type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{})
+
+// Log delegates the logging request to the wrapped function
+func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) {
+	f(ctx, level, msg, data)
+}
+
 // LogLevelFromString converts log level string to constant
 //
 // Valid levels:
diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go
index b1b7e951..685cbb62 100644
--- a/pgtype/pgtype_test.go
+++ b/pgtype/pgtype_test.go
@@ -211,17 +211,17 @@ func (ci *pgCustomInt) Scan(src interface{}) error {
 }
 
 func TestScanPlanBinaryInt32ScanScanner(t *testing.T) {
-	ci := pgtype.NewMap()
+	m := pgtype.NewMap()
 	src := []byte{0, 42}
 	var v pgCustomInt
 
-	plan := ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &v)
+	plan := m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &v)
 	err := plan.Scan(src, &v)
 	require.NoError(t, err)
 	require.EqualValues(t, 42, v)
 
 	ptr := new(pgCustomInt)
-	plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
+	plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
 	err = plan.Scan(src, &ptr)
 	require.NoError(t, err)
 	require.EqualValues(t, 42, *ptr)
@@ -232,13 +232,13 @@ func TestScanPlanBinaryInt32ScanScanner(t *testing.T) {
 	assert.Nil(t, ptr)
 
 	ptr = nil
-	plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
+	plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
 	err = plan.Scan(src, &ptr)
 	require.NoError(t, err)
 	require.EqualValues(t, 42, *ptr)
 
 	ptr = nil
-	plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
+	plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
 	err = plan.Scan(nil, &ptr)
 	require.NoError(t, err)
 	assert.Nil(t, ptr)
@@ -246,10 +246,10 @@ func TestScanPlanBinaryInt32ScanScanner(t *testing.T) {
 
 // Test for https://github.com/jackc/pgtype/issues/164
 func TestScanPlanInterface(t *testing.T) {
-	ci := pgtype.NewMap()
+	m := pgtype.NewMap()
 	src := []byte{0, 42}
 	var v interface{}
-	plan := ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, v)
+	plan := m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, v)
 	err := plan.Scan(src, v)
 	assert.Error(t, err)
 }
diff --git a/pgxpool/conn.go b/pgxpool/conn.go
index b9ff29dc..802026e2 100644
--- a/pgxpool/conn.go
+++ b/pgxpool/conn.go
@@ -2,7 +2,7 @@ package pgxpool
 
 import (
 	"context"
-	"time"
+	"sync/atomic"
 
 	"github.com/jackc/pgx/v5"
 	"github.com/jackc/pgx/v5/pgconn"
@@ -26,9 +26,23 @@ func (c *Conn) Release() {
 	res := c.res
 	c.res = nil
 
-	now := time.Now()
-	if conn.IsClosed() || conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' || (now.Sub(res.CreationTime()) > c.p.maxConnLifetime) {
+	if conn.IsClosed() || conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' {
 		res.Destroy()
+		// Signal to the health check to run since we just destroyed a connection
+		// and we might be below minConns now
+		c.p.triggerHealthCheck()
+		return
+	}
+
+	// If the pool is consistently being used, we might never get to check the
+	// lifetime of a connection since we only check idle connections in checkConnsHealth
+	// so we also check the lifetime here and force a health check
+	if c.p.isExpired(res) {
+		atomic.AddInt64(&c.p.lifetimeDestroyCount, 1)
+		res.Destroy()
+		// Signal to the health check to run since we just destroyed a connection
+		// and we might be below minConns now
+		c.p.triggerHealthCheck()
 		return
 	}
 
@@ -42,6 +56,9 @@ func (c *Conn) Release() {
 			res.Release()
 		} else {
 			res.Destroy()
+			// Signal to the health check to run since we just destroyed a connection
+			// and we might be below minConns now
+			c.p.triggerHealthCheck()
 		}
 	}()
 }
diff --git a/pgxpool/pool.go b/pgxpool/pool.go
index 65f1cb42..9872e670 100644
--- a/pgxpool/pool.go
+++ b/pgxpool/pool.go
@@ -3,9 +3,11 @@ package pgxpool
 import (
 	"context"
 	"fmt"
+	"math/rand"
 	"runtime"
 	"strconv"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/jackc/pgx/v5"
@@ -70,16 +72,24 @@ func (cr *connResource) getPoolRows(c *Conn, r pgx.Rows) *poolRows {
 
 // Pool allows for connection reuse.
 type Pool struct {
-	p                 *puddle.Pool[*connResource]
-	config            *Config
-	beforeConnect     func(context.Context, *pgx.ConnConfig) error
-	afterConnect      func(context.Context, *pgx.Conn) error
-	beforeAcquire     func(context.Context, *pgx.Conn) bool
-	afterRelease      func(*pgx.Conn) bool
-	minConns          int32
-	maxConnLifetime   time.Duration
-	maxConnIdleTime   time.Duration
-	healthCheckPeriod time.Duration
+	p                     *puddle.Pool[*connResource]
+	config                *Config
+	beforeConnect         func(context.Context, *pgx.ConnConfig) error
+	afterConnect          func(context.Context, *pgx.Conn) error
+	beforeAcquire         func(context.Context, *pgx.Conn) bool
+	afterRelease          func(*pgx.Conn) bool
+	minConns              int32
+	maxConns              int32
+	maxConnLifetime       time.Duration
+	maxConnLifetimeJitter time.Duration
+	maxConnIdleTime       time.Duration
+	healthCheckPeriod     time.Duration
+
+	healthCheckChan chan struct{}
+
+	newConnsCount        int64
+	lifetimeDestroyCount int64
+	idleDestroyCount     int64
 
 	closeOnce sync.Once
 	closeChan chan struct{}
@@ -109,14 +119,19 @@ type Config struct {
 	// MaxConnLifetime is the duration since creation after which a connection will be automatically closed.
 	MaxConnLifetime time.Duration
 
+	// MaxConnLifetimeJitter is the duration after MaxConnLifetime to randomly decide to close a connection.
+	// This helps prevent all connections from being closed at the exact same time, starving the pool.
+	MaxConnLifetimeJitter time.Duration
+
 	// MaxConnIdleTime is the duration after which an idle connection will be automatically closed by the health check.
 	MaxConnIdleTime time.Duration
 
 	// MaxConns is the maximum size of the pool. The default is the greater of 4 or runtime.NumCPU().
 	MaxConns int32
 
-	// MinConns is the minimum size of the pool. The health check will increase the number of connections to this
-	// amount if it had dropped below.
+	// MinConns is the minimum size of the pool. After a connection closes, the pool might dip below MinConns. A low
+	// number of MinConns might mean the pool is empty after MaxConnLifetime until the health check has a chance
+	// to create new connections.
 	MinConns int32
 
 	// HealthCheckPeriod is the duration between checks of the health of idle connections.
@@ -157,16 +172,19 @@ func NewConfig(ctx context.Context, config *Config) (*Pool, error) {
 	}
 
 	p := &Pool{
-		config:            config,
-		beforeConnect:     config.BeforeConnect,
-		afterConnect:      config.AfterConnect,
-		beforeAcquire:     config.BeforeAcquire,
-		afterRelease:      config.AfterRelease,
-		minConns:          config.MinConns,
-		maxConnLifetime:   config.MaxConnLifetime,
-		maxConnIdleTime:   config.MaxConnIdleTime,
-		healthCheckPeriod: config.HealthCheckPeriod,
-		closeChan:         make(chan struct{}),
+		config:                config,
+		beforeConnect:         config.BeforeConnect,
+		afterConnect:          config.AfterConnect,
+		beforeAcquire:         config.BeforeAcquire,
+		afterRelease:          config.AfterRelease,
+		minConns:              config.MinConns,
+		maxConns:              config.MaxConns,
+		maxConnLifetime:       config.MaxConnLifetime,
+		maxConnLifetimeJitter: config.MaxConnLifetimeJitter,
+		maxConnIdleTime:       config.MaxConnIdleTime,
+		healthCheckPeriod:     config.HealthCheckPeriod,
+		healthCheckChan:       make(chan struct{}, 1),
+		closeChan:             make(chan struct{}),
 	}
 
 	p.p = puddle.NewPool(
@@ -216,7 +234,7 @@ func NewConfig(ctx context.Context, config *Config) (*Pool, error) {
 	)
 
 	go func() {
-		p.checkMinConns() // reach min conns as soon as possible
+		p.createIdleResources(ctx, int(p.minConns))
 		p.backgroundHealthCheck()
 	}()
 
@@ -231,6 +249,7 @@ func NewConfig(ctx context.Context, config *Config) (*Pool, error) {
 // pool_max_conn_lifetime: duration string
 // pool_max_conn_idle_time: duration string
 // pool_health_check_period: duration string
+// pool_max_conn_lifetime_jitter: duration string
 //
 // See Config for definitions of these arguments.
 //
@@ -311,6 +330,15 @@ func ParseConfig(connString string) (*Config, error) {
 		config.HealthCheckPeriod = defaultHealthCheckPeriod
 	}
 
+	if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime_jitter"]; ok {
+		delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime_jitter")
+		d, err := time.ParseDuration(s)
+		if err != nil {
+			return nil, fmt.Errorf("invalid pool_max_conn_lifetime_jitter: %w", err)
+		}
+		config.MaxConnLifetimeJitter = d
+	}
+
 	return config, nil
 }
 
@@ -323,44 +351,105 @@ func (p *Pool) Close() {
 	})
 }
 
+func (p *Pool) isExpired(res *puddle.Resource[*connResource]) bool {
+	now := time.Now()
+	// Small optimization to avoid rand. If it's over lifetime AND jitter, immediately
+	// return true.
+	if now.Sub(res.CreationTime()) > p.maxConnLifetime+p.maxConnLifetimeJitter {
+		return true
+	}
+	if p.maxConnLifetimeJitter == 0 {
+		return false
+	}
+	jitterSecs := rand.Float64() * p.maxConnLifetimeJitter.Seconds()
+	return now.Sub(res.CreationTime()) > p.maxConnLifetime+(time.Duration(jitterSecs)*time.Second)
+}
+
+func (p *Pool) triggerHealthCheck() {
+	go func() {
+		// Destroy is asynchronous so we give it time to actually remove itself from
+		// the pool otherwise we might try to check the pool size too soon
+		time.Sleep(500 * time.Millisecond)
+		select {
+		case p.healthCheckChan <- struct{}{}:
+		default:
+		}
+	}()
+}
+
 func (p *Pool) backgroundHealthCheck() {
 	ticker := time.NewTicker(p.healthCheckPeriod)
-
+	defer ticker.Stop()
 	for {
 		select {
 		case <-p.closeChan:
-			ticker.Stop()
 			return
+		case <-p.healthCheckChan:
+			p.checkHealth()
 		case <-ticker.C:
-			p.checkIdleConnsHealth()
-			p.checkMinConns()
+			p.checkHealth()
 		}
 	}
 }
 
-func (p *Pool) checkIdleConnsHealth() {
-	resources := p.p.AcquireAllIdle()
+func (p *Pool) checkHealth() {
+	for {
+		// If checkMinConns failed we don't destroy any connections since we couldn't
+		// even get to minConns
+		if err := p.checkMinConns(); err != nil {
+			// Should we log this error somewhere?
+			break
+		}
+		if !p.checkConnsHealth() {
+			// Since we didn't destroy any connections we can stop looping
+			break
+		}
+		// Technically Destroy is asynchronous but 500ms should be enough for it to
+		// remove it from the underlying pool
+		select {
+		case <-p.closeChan:
+			return
+		case <-time.After(500 * time.Millisecond):
+		}
+	}
+}
 
-	now := time.Now()
+// checkConnsHealth will check all idle connections, destroy a connection if
+// it's been idle too long or is too old, and return true if any were destroyed
+func (p *Pool) checkConnsHealth() bool {
+	var destroyed bool
+	totalConns := p.Stat().TotalConns()
+	resources := p.p.AcquireAllIdle()
 	for _, res := range resources {
-		if now.Sub(res.CreationTime()) > p.maxConnLifetime {
+		// We're okay going under minConns if the lifetime is up
+		if p.isExpired(res) && totalConns >= p.minConns {
+			atomic.AddInt64(&p.lifetimeDestroyCount, 1)
 			res.Destroy()
-		} else if res.IdleDuration() > p.maxConnIdleTime {
+			destroyed = true
+			// Since Destroy is async we manually decrement totalConns.
+			totalConns--
+		} else if res.IdleDuration() > p.maxConnIdleTime && totalConns > p.minConns {
+			atomic.AddInt64(&p.idleDestroyCount, 1)
 			res.Destroy()
+			destroyed = true
+			// Since Destroy is async we manually decrement totalConns.
+			totalConns--
 		} else {
 			res.ReleaseUnused()
 		}
 	}
+	return destroyed
 }
 
-func (p *Pool) checkMinConns() {
-	for i := p.minConns - p.Stat().TotalConns(); i > 0; i-- {
-		go func() {
-			ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
-			defer cancel()
-			p.p.CreateResource(ctx)
-		}()
+func (p *Pool) checkMinConns() error {
+	// TotalConns can include ones that are being destroyed but we should have
+	// sleep(500ms) around all of the destroys to help prevent that from throwing
+	// off this check
+	toCreate := p.minConns - p.Stat().TotalConns()
+	if toCreate > 0 {
+		return p.createIdleResources(context.Background(), int(toCreate))
 	}
+	return nil
 }
 
 func (p *Pool) createIdleResources(parentCtx context.Context, targetResources int) error {
@@ -371,6 +460,7 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in
 
 	for i := 0; i < targetResources; i++ {
 		go func() {
+			atomic.AddInt64(&p.newConnsCount, 1)
 			err := p.p.CreateResource(ctx)
 			errs <- err
 		}()
@@ -449,7 +539,12 @@ func (p *Pool) Config() *Config { return p.config.Copy() }
 
 // Stat returns a pgxpool.Stat struct with a snapshot of Pool statistics.
 func (p *Pool) Stat() *Stat {
-	return &Stat{s: p.p.Stat()}
+	return &Stat{
+		s:                    p.p.Stat(),
+		newConnsCount:        atomic.LoadInt64(&p.newConnsCount),
+		lifetimeDestroyCount: atomic.LoadInt64(&p.lifetimeDestroyCount),
+		idleDestroyCount:     atomic.LoadInt64(&p.idleDestroyCount),
+	}
 }
 
 // Exec acquires a connection from the Pool and executes the given SQL.
diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go
index f296d819..0e4d8acf 100644
--- a/pgxpool/pool_test.go
+++ b/pgxpool/pool_test.go
@@ -389,6 +389,14 @@ func TestConnReleaseClosesBusyConn(t *testing.T) {
 	c.Release()
 	waitForReleaseToComplete()
 
+	// wait for the connection to actually be destroyed
+	for i := 0; i < 1000; i++ {
+		if db.Stat().TotalConns() == 0 {
+			break
+		}
+		time.Sleep(time.Millisecond)
+	}
+
 	stats := db.Stat()
 	assert.EqualValues(t, 0, stats.TotalConns())
 }
@@ -413,6 +421,8 @@ func TestPoolBackgroundChecksMaxConnLifetime(t *testing.T) {
 
 	stats := db.Stat()
 	assert.EqualValues(t, 0, stats.TotalConns())
+	assert.EqualValues(t, 0, stats.MaxIdleDestroyCount())
+	assert.EqualValues(t, 1, stats.MaxLifetimeDestroyCount())
 }
 
 func TestPoolBackgroundChecksMaxConnIdleTime(t *testing.T) {
@@ -443,6 +453,8 @@ func TestPoolBackgroundChecksMaxConnIdleTime(t *testing.T) {
 
 	stats := db.Stat()
 	assert.EqualValues(t, 0, stats.TotalConns())
+	assert.EqualValues(t, 1, stats.MaxIdleDestroyCount())
+	assert.EqualValues(t, 0, stats.MaxLifetimeDestroyCount())
 }
 
 func TestPoolBackgroundChecksMinConns(t *testing.T) {
@@ -460,6 +472,21 @@ func TestPoolBackgroundChecksMinConns(t *testing.T) {
 
 	stats := db.Stat()
 	assert.EqualValues(t, 2, stats.TotalConns())
+	assert.EqualValues(t, 0, stats.MaxLifetimeDestroyCount())
+	assert.EqualValues(t, 2, stats.NewConnsCount())
+
+	c, err := db.Acquire(context.Background())
+	require.NoError(t, err)
+	err = c.Conn().Close(context.Background())
+	require.NoError(t, err)
+	c.Release()
+
+	time.Sleep(config.HealthCheckPeriod + 500*time.Millisecond)
+
+	stats = db.Stat()
+	assert.EqualValues(t, 2, stats.TotalConns())
+	assert.EqualValues(t, 0, stats.MaxIdleDestroyCount())
+	assert.EqualValues(t, 3, stats.NewConnsCount())
 }
 
 func TestPoolExec(t *testing.T) {
@@ -696,6 +723,14 @@ func TestConnReleaseDestroysClosedConn(t *testing.T) {
 	c.Release()
 	waitForReleaseToComplete()
 
+	// wait for the connection to actually be destroyed
+	for i := 0; i < 1000; i++ {
+		if pool.Stat().TotalConns() == 0 {
+			break
+		}
+		time.Sleep(time.Millisecond)
+	}
+
 	assert.EqualValues(t, 0, pool.Stat().TotalConns())
 }
 
@@ -784,7 +819,7 @@ func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
 			require.NoError(t, err)
 			return nil
 		})
-
+		require.NoError(t, err)
 		return nil
 	})
 	require.NoError(t, err)
@@ -834,6 +869,7 @@ func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
 
 		return nil
 	})
+	require.NoError(t, err)
 
 	var n int64
 	err = db.QueryRow(context.Background(), "select count(*) from pgxpooltx").Scan(&n)
diff --git a/pgxpool/stat.go b/pgxpool/stat.go
index 336be42d..47342be4 100644
--- a/pgxpool/stat.go
+++ b/pgxpool/stat.go
@@ -8,7 +8,10 @@ import (
 
 // Stat is a snapshot of Pool statistics.
 type Stat struct {
-	s *puddle.Stat
+	s                    *puddle.Stat
+	newConnsCount        int64
+	lifetimeDestroyCount int64
+	idleDestroyCount     int64
 }
 
 // AcquireCount returns the cumulative count of successful acquires from the pool.
@@ -62,3 +65,20 @@ func (s *Stat) MaxConns() int32 {
 func (s *Stat) TotalConns() int32 {
 	return s.s.TotalResources()
 }
+
+// NewConnsCount returns the cumulative count of new connections opened.
+func (s *Stat) NewConnsCount() int64 {
+	return s.newConnsCount
+}
+
+// MaxLifetimeDestroyCount returns the cumulative count of connections destroyed
+// because they exceeded MaxConnLifetime.
+func (s *Stat) MaxLifetimeDestroyCount() int64 {
+	return s.lifetimeDestroyCount
+}
+
+// MaxIdleDestroyCount returns the cumulative count of connections destroyed because
+// they exceeded MaxConnIdleTime.
+func (s *Stat) MaxIdleDestroyCount() int64 {
+	return s.idleDestroyCount
+}
diff --git a/rows.go b/rows.go
index 90a24d28..c91f3aff 100644
--- a/rows.go
+++ b/rows.go
@@ -162,14 +162,15 @@ func (rows *baseRows) Close() {
 	}
 
 	if rows.logger != nil {
+		endTime := time.Now()
+
 		if rows.err == nil {
 			if rows.logger.shouldLog(LogLevelInfo) {
-				endTime := time.Now()
 				rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]any{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
 			}
 		} else {
 			if rows.logger.shouldLog(LogLevelError) {
-				rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]any{"err": rows.err, "sql": rows.sql, "args": logQueryArgs(rows.args)})
+				rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]any{"err": rows.err, "sql": rows.sql, "time": endTime.Sub(rows.startTime), "args": logQueryArgs(rows.args)})
			}
 		}
 	}
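
Note for reviewers: a minimal sketch of how the new pgx.LoggerFunc adapter might be used outside the test suite. It assumes a DATABASE_URL environment variable and a standard library logger writing to stderr; both are illustrative and not part of this patch.

package main

import (
	"context"
	"log"
	"os"

	"github.com/jackc/pgx/v5"
)

func main() {
	// Wrap a *log.Logger with pgx.LoggerFunc instead of declaring a struct
	// that implements the pgx.Logger interface.
	stdLogger := log.New(os.Stderr, "pgx ", log.LstdFlags)

	config, err := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
	if err != nil {
		log.Fatal(err)
	}
	config.Logger = pgx.LoggerFunc(func(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
		stdLogger.Printf("level=%v msg=%s data=%v", level, msg, data)
	})

	conn, err := pgx.ConnectConfig(context.Background(), config)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(context.Background())

	// Any statement executed on conn is now logged through stdLogger.
	if _, err := conn.Exec(context.Background(), "select 1"); err != nil {
		log.Fatal(err)
	}
}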
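A rough sketch of the new pool knobs and counters working together follows; the DSN, durations, and log format are illustrative only, and the pool constructor call is left out since its exact name varies across v5 pre-releases.

package main

import (
	"log"
	"os"
	"time"

	"github.com/jackc/pgx/v5/pgxpool"
)

// buildPoolConfig spreads connection recycling over a window instead of a single
// instant: with the values below each connection is closed somewhere between 30 and
// 35 minutes after creation. The jitter can equivalently be supplied in the
// connection string via pool_max_conn_lifetime_jitter.
func buildPoolConfig() (*pgxpool.Config, error) {
	config, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL"))
	if err != nil {
		return nil, err
	}
	config.MinConns = 2
	config.MaxConnLifetime = 30 * time.Minute
	config.MaxConnLifetimeJitter = 5 * time.Minute
	return config, nil
}

// reportPoolStats logs the counters added by this patch for a pool built from the
// config above with whichever constructor this branch exposes.
func reportPoolStats(pool *pgxpool.Pool) {
	stat := pool.Stat()
	log.Printf("opened=%d lifetime_destroyed=%d idle_destroyed=%d",
		stat.NewConnsCount(), stat.MaxLifetimeDestroyCount(), stat.MaxIdleDestroyCount())
}

func main() {
	config, err := buildPoolConfig()
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("lifetime=%s jitter=%s min_conns=%d",
		config.MaxConnLifetime, config.MaxConnLifetimeJitter, config.MinConns)
}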