diff --git a/conn_pool.go b/conn_pool.go
index b27074cd..a273e8a3 100644
--- a/conn_pool.go
+++ b/conn_pool.go
@@ -92,6 +92,11 @@ func (p *ConnPool) Acquire() (*Conn, error) {
 	return c, err
 }
 
+// deadlinePassed returns true if the given deadline has passed.
+func (p *ConnPool) deadlinePassed(deadline *time.Time) bool {
+	return deadline != nil && time.Now().After(*deadline)
+}
+
 // acquire performs acquision assuming pool is already locked
 func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
 	if p.closed {
@@ -106,44 +111,74 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
 		return c, nil
 	}
 
-	// No connections are available, but we can create more
-	if len(p.allConnections) < p.maxConnections {
-		c, err := p.createConnection()
-		if err != nil {
-			return nil, err
-		}
-		c.poolResetCount = p.resetCount
-		p.allConnections = append(p.allConnections, c)
-		return c, nil
-	}
-
-	// All connections are in use and we cannot create more
-	if p.logLevel >= LogLevelWarn {
-		p.logger.Warn("All connections in pool are busy - waiting...")
-	}
-
 	// Set initial timeout/deadline value. If the method (acquire) happens to
 	// recursively call itself the deadline should retain its value.
 	if deadline == nil && p.acquireTimeout > 0 {
 		tmp := time.Now().Add(p.acquireTimeout)
 		deadline = &tmp
 	}
+
+	// Make sure the deadline (if any) has not already passed.
+	if p.deadlinePassed(deadline) {
+		return nil, errors.New("Timeout: Acquire connection timeout")
+	}
+	// If there is a deadline, start a timeout timer.
+	var timer *time.Timer
 	if deadline != nil {
-		timer := time.AfterFunc(deadline.Sub(time.Now()), func() {
-			p.cond.Signal()
+		timer = time.AfterFunc(deadline.Sub(time.Now()), func() {
+			p.cond.Broadcast()
 		})
 		defer timer.Stop()
 	}
 
-	// Wait until there is an available connection OR room to create a new connection
-	for len(p.availableConnections) == 0 && len(p.allConnections) == p.maxConnections {
-		if deadline != nil && time.Now().After(*deadline) {
-			return nil, errors.New("Timeout: All connections in pool are busy")
+	// No connections are available, but we can create more
+	if len(p.allConnections) < p.maxConnections {
+		// Create a placeholder connection.
+		placeholderConn := &Conn{}
+		p.allConnections = append(p.allConnections, placeholderConn)
+		// Create a new connection.
+		// Careful here: createConnectionUnlocked() releases the pool lock,
+		// creates a connection, and then re-acquires the lock.
+		c, err := p.createConnectionUnlocked()
+		// Take the placeholder out of the list of connections.
+		p.removeFromAllConnections(placeholderConn)
+		// Make sure connection creation did not fail.
+		if err != nil {
+			return nil, err
+		}
+		// If resetCount was updated while we were connecting, or there is
+		// no room left in allConnections (invalidateAcquired may have
+		// removed our placeholder), the new connection cannot be kept and
+		// we have to re-acquire.
+		if len(p.allConnections) < p.maxConnections {
+			// Put the new connection into the list.
+			c.poolResetCount = p.resetCount
+			p.allConnections = append(p.allConnections, c)
+			return c, nil
+		}
+		// There is no room for the newly created connection.
+		// Close it and try to re-acquire.
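+		// Another goroutine filled the pool while the lock was released; the
+		// recursive acquire call below still honors the original deadline.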
+		c.Close()
+	} else {
+		// All connections are in use and we cannot create more
+		if p.logLevel >= LogLevelWarn {
+			p.logger.Warn("All connections in pool are busy - waiting...")
+		}
+
+		// Wait until there is an available connection OR room to create a new connection
+		for len(p.availableConnections) == 0 && len(p.allConnections) == p.maxConnections {
+			if p.deadlinePassed(deadline) {
+				return nil, errors.New("Timeout: All connections in pool are busy")
+			}
+			p.cond.Wait()
 		}
-		p.cond.Wait()
 	}
 
+	// Stop the timer before recursing so we do not leave one running per call.
+	if timer != nil {
+		timer.Stop()
+	}
 	return p.acquire(deadline)
 }
 
@@ -173,19 +208,24 @@ func (p *ConnPool) Release(conn *Conn) {
 	if conn.IsAlive() {
 		p.availableConnections = append(p.availableConnections, conn)
 	} else {
-		ac := p.allConnections
-		for i, c := range ac {
-			if conn == c {
-				ac[i] = ac[len(ac)-1]
-				p.allConnections = ac[0 : len(ac)-1]
-				break
-			}
-		}
+		p.removeFromAllConnections(conn)
 	}
 	p.cond.L.Unlock()
 	p.cond.Signal()
 }
 
+// removeFromAllConnections removes the given connection from p.allConnections.
+// It returns true if the connection was found and removed, false otherwise.
+func (p *ConnPool) removeFromAllConnections(conn *Conn) bool {
+	for i, c := range p.allConnections {
+		if conn == c {
+			p.allConnections = append(p.allConnections[:i], p.allConnections[i+1:]...)
+			return true
+		}
+	}
+	return false
+}
+
 // Close ends the use of a connection pool. It prevents any new connections
 // from being acquired, waits until all acquired connections are released,
 // then closes all underlying connections.
@@ -251,13 +291,41 @@ func (p *ConnPool) createConnection() (*Conn, error) {
 	if err != nil {
 		return nil, err
 	}
 
+	return p.afterConnectionCreated(c)
+}
+
+// createConnectionUnlocked releases the pool lock, creates a new connection,
+// and then re-acquires the lock.
+// The point: say the pool dialer's OpenTimeout is 3 seconds and we try to
+// acquire all 20 connections of a pool at startup. If the remote server is not
+// reachable, the first connection holds the pool lock for 3 seconds until its
+// dial times out, then connection #2 holds the lock for the next 3 seconds,
+// and so on; the 20th connection fails only after 3 * 20 = 60 seconds.
+// To avoid this we call Connect(p.config) outside of the lock (it is thread
+// safe), which lets all 20 connections be established in parallel (more or less).
+func (p *ConnPool) createConnectionUnlocked() (*Conn, error) {
+	p.cond.L.Unlock()
+	c, err := Connect(p.config)
+	p.cond.L.Lock()
+
+	if err != nil {
+		return nil, err
+	}
+	return p.afterConnectionCreated(c)
+}
+
+// afterConnectionCreated runs the afterConnect() callback (if set) and prepares
+// all known statements on the new connection.
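+// If the afterConnect callback fails, the new connection is terminated via
+// die() and the error is returned.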
+func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
 	p.pgTypes = c.PgTypes
 	p.pgsql_af_inet = c.pgsql_af_inet
 	p.pgsql_af_inet6 = c.pgsql_af_inet6
 
 	if p.afterConnect != nil {
-		err = p.afterConnect(c)
+		err := p.afterConnect(c)
 		if err != nil {
 			c.die(err)
 			return nil, err
@@ -358,9 +426,16 @@ func (p *ConnPool) PrepareEx(name, sql string, opts *PrepareExOptions) (*Prepare
 		return nil, err
 	}
 
-	ps, err := c.PrepareEx(name, sql, opts)
-	p.availableConnections = append(p.availableConnections, c)
+
+	// Double-check that the statement was not prepared by another caller while
+	// we were acquiring the connection (acquire is no longer fully blocking,
+	// see createConnectionUnlocked()).
+	if ps, ok := p.preparedStatements[name]; ok && ps.SQL == sql {
+		return ps, nil
+	}
+
+	ps, err := c.PrepareEx(name, sql, opts)
 	if err != nil {
 		return nil, err
 	}
diff --git a/conn_pool_private_test.go b/conn_pool_private_test.go
new file mode 100644
index 00000000..ef0ec1de
--- /dev/null
+++ b/conn_pool_private_test.go
@@ -0,0 +1,44 @@
+package pgx
+
+import (
+	"testing"
+)
+
+func compareConnSlices(slice1, slice2 []*Conn) bool {
+	if len(slice1) != len(slice2) {
+		return false
+	}
+	for i, c := range slice1 {
+		if c != slice2[i] {
+			return false
+		}
+	}
+	return true
+}
+
+func TestConnPoolRemoveFromAllConnections(t *testing.T) {
+	t.Parallel()
+	pool := ConnPool{}
+	conn1 := &Conn{}
+	conn2 := &Conn{}
+	conn3 := &Conn{}
+
+	// First element
+	pool.allConnections = []*Conn{conn1, conn2, conn3}
+	pool.removeFromAllConnections(conn1)
+	if !compareConnSlices(pool.allConnections, []*Conn{conn2, conn3}) {
+		t.Fatal("First element test failed")
+	}
+	// Element somewhere in the middle
+	pool.allConnections = []*Conn{conn1, conn2, conn3}
+	pool.removeFromAllConnections(conn2)
+	if !compareConnSlices(pool.allConnections, []*Conn{conn1, conn3}) {
+		t.Fatal("Middle element test failed")
+	}
+	// Last element
+	pool.allConnections = []*Conn{conn1, conn2, conn3}
+	pool.removeFromAllConnections(conn3)
+	if !compareConnSlices(pool.allConnections, []*Conn{conn1, conn2}) {
+		t.Fatal("Last element test failed")
+	}
+}
diff --git a/conn_pool_test.go b/conn_pool_test.go
index 5c8920a2..ac2e36ef 100644
--- a/conn_pool_test.go
+++ b/conn_pool_test.go
@@ -3,6 +3,7 @@ package pgx_test
 import (
 	"errors"
 	"fmt"
+	"net"
 	"sync"
 	"testing"
 	"time"
@@ -153,6 +154,54 @@ func TestPoolAcquireAndReleaseCycle(t *testing.T) {
 	releaseAllConnections(pool, allConnections)
 }
 
+func TestPoolNonBlockingConnections(t *testing.T) {
+	t.Parallel()
+
+	maxConnections := 5
+	openTimeout := 1 * time.Second
+	dialer := net.Dialer{
+		Timeout: openTimeout,
+	}
+	config := pgx.ConnPoolConfig{
+		ConnConfig:     *defaultConnConfig,
+		MaxConnections: maxConnections,
+	}
+	config.ConnConfig.Dial = dialer.Dial
+	// We need a host that silently drops incoming packets so that each dial
+	// blocks until the timeout; there may be a better choice than microsoft.com.
+	config.Host = "microsoft.com"
+
+	pool, err := pgx.NewConnPool(config)
+	if err != nil {
+		t.Fatalf("Expected NewConnPool not to fail, instead it failed with '%v'", err)
+	}
+	defer pool.Close()
+
+	var wg sync.WaitGroup
+	wg.Add(maxConnections)
+
+	startedAt := time.Now()
+	for i := 0; i < maxConnections; i++ {
+		go func() {
+			defer wg.Done()
+			_, err := pool.Acquire()
+			if err == nil {
+				t.Error("Acquire() was expected to fail but did not")
+			}
+		}()
+	}
+	wg.Wait()
+
+	// Prior to using createConnectionUnlocked() this test took about
+	// maxConnections * openTimeout seconds to complete.
+	// With createConnectionUnlocked() it takes about 1 * openTimeout seconds.
+	timeTaken := time.Now().Sub(startedAt)
+	if timeTaken > openTimeout+1*time.Second {
+		t.Fatalf("Expected all Acquire() calls to run in parallel and take about %v, instead they took '%v'", openTimeout, timeTaken)
+	}
+}
+
 func TestAcquireTimeoutSanity(t *testing.T) {
 	t.Parallel()