From d4258bb47f9f702b7ecea66147c5dcf47c0c2732 Mon Sep 17 00:00:00 2001 From: konstantin Date: Mon, 11 Apr 2016 13:35:16 -0700 Subject: [PATCH] Add AcquireTimeout support --- conn_pool.go | 32 +++++++++-- conn_pool_test.go | 139 +++++++++++++++++++++++++++++++++++++++------- 2 files changed, 148 insertions(+), 23 deletions(-) diff --git a/conn_pool.go b/conn_pool.go index 2540d0d8..e6b44934 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -3,12 +3,14 @@ package pgx import ( "errors" "sync" + "time" ) type ConnPoolConfig struct { ConnConfig MaxConnections int // max simultaneous connections to use, default 5, must be at least 2 AfterConnect func(*Conn) error // function to call on every new connection + AcquireTimeout time.Duration // max wait time when all connections are busy (0 means no timeout) } type ConnPool struct { @@ -23,6 +25,7 @@ type ConnPool struct { logLevel int closed bool preparedStatements map[string]*PreparedStatement + acquireTimeout time.Duration } type ConnPoolStat struct { @@ -43,6 +46,10 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { if p.maxConnections < 1 { return nil, errors.New("MaxConnections must be at least 1") } + p.acquireTimeout = config.AcquireTimeout + if p.acquireTimeout < 0 { + return nil, errors.New("AcquireTimeout must be equal to or greater than 0") + } p.afterConnect = config.AfterConnect @@ -77,13 +84,13 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { // Acquire takes exclusive use of a connection until it is released. func (p *ConnPool) Acquire() (*Conn, error) { p.cond.L.Lock() - c, err := p.acquire() + c, err := p.acquire(nil) p.cond.L.Unlock() return c, err } // acquire performs acquision assuming pool is already locked -func (p *ConnPool) acquire() (*Conn, error) { +func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { if p.closed { return nil, errors.New("cannot acquire from closed pool") } @@ -112,12 +119,29 @@ func (p *ConnPool) acquire() (*Conn, error) { p.logger.Warn("All connections in pool are busy - waiting...") } + // Set initial timeout/deadline value. If the method (acquire) happens to + // recursively call itself the deadline should retain its value. + if deadline == nil && p.acquireTimeout > 0 { + tmp := time.Now().Add(p.acquireTimeout) + deadline = &tmp + } + // If there is a deadline then start a timeout timer + if deadline != nil { + timer := time.AfterFunc(deadline.Sub(time.Now()), func() { + p.cond.Signal() + }) + defer timer.Stop() + } + // Wait until there is an available connection OR room to create a new connection for len(p.availableConnections) == 0 && len(p.allConnections) == p.maxConnections { + if deadline != nil && time.Now().After(*deadline) { + return nil, errors.New("Timeout: All connections in pool are busy") + } p.cond.Wait() } - return p.acquire() + return p.acquire(deadline) } // Release gives up use of a connection. @@ -307,7 +331,7 @@ func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) { return ps, nil } - c, err := p.acquire() + c, err := p.acquire(nil) if err != nil { return nil, err } diff --git a/conn_pool_test.go b/conn_pool_test.go index 6d3154df..959dde41 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -3,9 +3,11 @@ package pgx_test import ( "errors" "fmt" - "github.com/jackc/pgx" "sync" "testing" + "time" + + "github.com/jackc/pgx" ) func createConnPool(t *testing.T, maxConnections int) *pgx.ConnPool { @@ -17,6 +19,29 @@ func createConnPool(t *testing.T, maxConnections int) *pgx.ConnPool { return pool } +func acquireAllConnections(t *testing.T, pool *pgx.ConnPool, maxConnections int) []*pgx.Conn { + connections := make([]*pgx.Conn, maxConnections) + for i := 0; i < maxConnections; i++ { + var err error + if connections[i], err = pool.Acquire(); err != nil { + t.Fatalf("Unable to acquire connection: %v", err) + } + } + return connections +} + +func releaseAllConnections(pool *pgx.ConnPool, connections []*pgx.Conn) { + for _, c := range connections { + pool.Release(c) + } +} + +func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error) { + startTime := time.Now() + c, err := pool.Acquire() + return c, time.Now().Sub(startTime), err +} + func TestNewConnPool(t *testing.T) { t.Parallel() @@ -76,27 +101,14 @@ func TestPoolAcquireAndReleaseCycle(t *testing.T) { pool := createConnPool(t, maxConnections) defer pool.Close() - acquireAll := func() (connections []*pgx.Conn) { - connections = make([]*pgx.Conn, maxConnections) - for i := 0; i < maxConnections; i++ { - var err error - if connections[i], err = pool.Acquire(); err != nil { - t.Fatalf("Unable to acquire connection: %v", err) - } - } - return - } - - allConnections := acquireAll() + allConnections := acquireAllConnections(t, pool, maxConnections) for _, c := range allConnections { mustExec(t, c, "create temporary table t(counter integer not null)") mustExec(t, c, "insert into t(counter) values(0);") } - for _, c := range allConnections { - pool.Release(c) - } + releaseAllConnections(pool, allConnections) f := func() { conn, err := pool.Acquire() @@ -121,7 +133,7 @@ func TestPoolAcquireAndReleaseCycle(t *testing.T) { // Check that temp table in each connection has been incremented some number of times actualCount := int32(0) - allConnections = acquireAll() + allConnections = acquireAllConnections(t, pool, maxConnections) for _, c := range allConnections { var n int32 @@ -138,8 +150,97 @@ func TestPoolAcquireAndReleaseCycle(t *testing.T) { t.Error("Wrong number of increments") } - for _, c := range allConnections { - pool.Release(c) + releaseAllConnections(pool, allConnections) +} + +func TestAcquireTimeoutSanity(t *testing.T) { + t.Parallel() + + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + } + + // case 1: default 0 value + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Expected NewConnPool with default config.AcquireTimeout not to fail, instead it failed with '%v'", err) + } + pool.Close() + + // case 2: negative value + config.AcquireTimeout = -1 * time.Second + _, err = pgx.NewConnPool(config) + if err == nil { + t.Fatal("Expected NewConnPool with negative config.AcquireTimeout to fail, instead it did not") + } + + // case 3: positive value + config.AcquireTimeout = 1 * time.Second + pool, err = pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Expected NewConnPool with positive config.AcquireTimeout not to fail, instead it failed with '%v'", err) + } + defer pool.Close() +} + +func TestPoolWithAcquireTimeoutSet(t *testing.T) { + t.Parallel() + + connAllocTimeout := 2 * time.Second + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + AcquireTimeout: connAllocTimeout, + } + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, config.MaxConnections) + defer releaseAllConnections(pool, allConnections) + + // ... then try to consume 1 more. It should fail after a short timeout. + _, timeTaken, err := acquireWithTimeTaken(pool) + + if err == nil || err.Error() != "Timeout: All connections in pool are busy" { + t.Fatalf("Expected error to be 'Timeout: All connections in pool are busy', instead it was '%v'", err) + } + if timeTaken < connAllocTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", connAllocTimeout, timeTaken) + } +} + +func TestPoolWithoutAcquireTimeoutSet(t *testing.T) { + t.Parallel() + + maxConnections := 1 + pool := createConnPool(t, maxConnections) + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, maxConnections) + + // ... then try to consume 1 more. It should hang forever. + // To unblock it we release the previously taken connection in a goroutine. + stopDeadWaitTimeout := 5 * time.Second + timer := time.AfterFunc(stopDeadWaitTimeout, func() { + releaseAllConnections(pool, allConnections) + }) + defer timer.Stop() + + conn, timeTaken, err := acquireWithTimeTaken(pool) + if err == nil { + pool.Release(conn) + } else { + t.Fatalf("Expected error to be nil, instead it was '%v'", err) + } + if timeTaken < stopDeadWaitTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", stopDeadWaitTimeout, timeTaken) } }