diff --git a/conn_pool.go b/conn_pool.go index 612b2ce5..6ee767b2 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -121,9 +121,10 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { } // If there is a deadline then start a timeout timer + var timer *time.Timer if deadline != nil { - timer := time.AfterFunc(deadline.Sub(time.Now()), func() { - p.cond.Signal() + timer = time.AfterFunc(deadline.Sub(time.Now()), func() { + p.cond.Broadcast() }) defer timer.Stop() } @@ -147,29 +148,34 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { // there is no room in the list of allConnections // (invalidateAcquired may remove our placeholder), try to re-acquire // the connection. - if len(p.allConnections) >= p.maxConnections { - c.Close() - return p.acquire(deadline) + if len(p.allConnections) < p.maxConnections { + // Put the new connection to the list. + c.poolResetCount = p.resetCount + p.allConnections = append(p.allConnections, c) + return c, nil } - // Put the new connection to the list. - c.poolResetCount = p.resetCount - p.allConnections = append(p.allConnections, c) - return c, nil - } - - // All connections are in use and we cannot create more - if p.logLevel >= LogLevelWarn { - p.logger.Warn("All connections in pool are busy - waiting...") - } - - // Wait until there is an available connection OR room to create a new connection - for len(p.availableConnections) == 0 && len(p.allConnections) == p.maxConnections { - if p.deadlinePassed(deadline) { - return nil, errors.New("Timeout: All connections in pool are busy") + // There is no room for the just created connection. + // Close it and try to re-acquire. + c.Close() + } else { + // All connections are in use and we cannot create more + if p.logLevel >= LogLevelWarn { + p.logger.Warn("All connections in pool are busy - waiting...") + } + + // Wait until there is an available connection OR room to create a new connection + for len(p.availableConnections) == 0 && len(p.allConnections) == p.maxConnections { + if p.deadlinePassed(deadline) { + return nil, errors.New("Timeout: All connections in pool are busy") + } + p.cond.Wait() } - p.cond.Wait() } + // Stop the timer so that we do not spawn it on every acquire call. + if timer != nil { + timer.Stop() + } return p.acquire(deadline) } @@ -311,9 +317,8 @@ func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { // afterConnectionCreated executes (if it is) afterConnect() callback and prepares // all the known statements for the new connection. func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) { - var err error if p.afterConnect != nil { - err = p.afterConnect(c) + err := p.afterConnect(c) if err != nil { c.die(err) return nil, err diff --git a/conn_pool_private_test.go b/conn_pool_private_test.go new file mode 100644 index 00000000..ef0ec1de --- /dev/null +++ b/conn_pool_private_test.go @@ -0,0 +1,44 @@ +package pgx + +import ( + "testing" +) + +func compareConnSlices(slice1, slice2 []*Conn) bool { + if len(slice1) != len(slice2) { + return false + } + for i, c := range slice1 { + if c != slice2[i] { + return false + } + } + return true +} + +func TestConnPoolRemoveFromAllConnections(t *testing.T) { + t.Parallel() + pool := ConnPool{} + conn1 := &Conn{} + conn2 := &Conn{} + conn3 := &Conn{} + + // First element + pool.allConnections = []*Conn{conn1, conn2, conn3} + pool.removeFromAllConnections(conn1) + if !compareConnSlices(pool.allConnections, []*Conn{conn2, conn3}) { + t.Fatal("First element test failed") + } + // Element somewhere in the middle + pool.allConnections = []*Conn{conn1, conn2, conn3} + pool.removeFromAllConnections(conn2) + if !compareConnSlices(pool.allConnections, []*Conn{conn1, conn3}) { + t.Fatal("Middle element test failed") + } + // Last element + pool.allConnections = []*Conn{conn1, conn2, conn3} + pool.removeFromAllConnections(conn3) + if !compareConnSlices(pool.allConnections, []*Conn{conn1, conn2}) { + t.Fatal("Last element test failed") + } +} diff --git a/conn_pool_test.go b/conn_pool_test.go index 3c4acc1f..c04f0d47 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -176,23 +176,20 @@ func TestPoolNonBlockingConections(t *testing.T) { t.Fatalf("Expected NewConnPool not to fail, instead it failed with") } - done := make(chan bool) + var wg sync.WaitGroup + wg.Add(maxConnections) startedAt := time.Now() for i := 0; i < maxConnections; i++ { go func() { _, err := pool.Acquire() - done <- true + wg.Done() if err == nil { t.Fatal("Acquire() expected to fail but it did not") } }() } - - // Wait for all the channels to succeedd - for i := 0; i < maxConnections; i++ { - <-done - } + wg.Wait() // Prior to createConnectionUnlocked() use the test took // maxConnections * openTimeout seconds to complete.