diff --git a/conn_pool.go b/conn_pool.go index 3be0c23a..6a02382e 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -19,6 +19,7 @@ type ConnPool struct { maxConnections int afterConnect func(*Conn) error logger Logger + closed bool } type ConnPoolStat struct { @@ -68,6 +69,10 @@ func (p *ConnPool) Acquire() (c *Conn, err error) { p.cond.L.Lock() defer p.cond.L.Unlock() + if p.closed { + return nil, errors.New("cannot acquire from closed pool") + } + // A connection is available if len(p.availableConnections) > 0 { c = p.availableConnections[len(p.availableConnections)-1] @@ -122,13 +127,25 @@ func (p *ConnPool) Release(conn *Conn) { p.cond.Signal() } -// Close ends the use of a connection pool by closing all underlying connections. +// Close ends the use of a connection pool. It prevents any new connections +// from being acquired, waits until all acquired connections are released, +// then closes all underlying connections. func (p *ConnPool) Close() { - for i := 0; i < p.maxConnections; i++ { - if c, err := p.Acquire(); err == nil { - _ = c.Close() + p.cond.L.Lock() + defer p.cond.L.Unlock() + + p.closed = true + + // Wait until all connections are released + if len(p.availableConnections) != len(p.allConnections) { + for len(p.availableConnections) != len(p.allConnections) { + p.cond.Wait() } } + + for _, c := range p.allConnections { + _ = c.Close() + } } // Stat returns connection pool statistics diff --git a/conn_pool_test.go b/conn_pool_test.go index 461ffc91..6164ff44 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -416,8 +416,12 @@ func TestConnPoolQueryConcurrentLoad(t *testing.T) { pool := createConnPool(t, 10) defer pool.Close() - for i := 0; i < 100; i++ { + n := 100 + done := make(chan bool) + + for i := 0; i < n; i++ { go func() { + defer func() { done <- true }() var rowCount int32 rows, err := pool.Query("select generate_series(1,$1)", 1000) @@ -447,6 +451,10 @@ func TestConnPoolQueryConcurrentLoad(t *testing.T) { } }() } + + for i := 0; i < n; i++ { + <-done + } } func TestConnPoolQueryRow(t *testing.T) {