diff --git a/conn_pool.go b/conn_pool.go index e6b44934..612b2ce5 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -89,6 +89,11 @@ func (p *ConnPool) Acquire() (*Conn, error) { return c, err } +// deadlinePassed reports whether a deadline is set (non-nil) and has already passed. +func (p *ConnPool) deadlinePassed(deadline *time.Time) bool { + return deadline != nil && time.Now().After(*deadline) +} + // acquire performs acquision assuming pool is already locked func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { if p.closed { @@ -103,12 +108,50 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { return c, nil } + // Set the initial deadline from acquireTimeout. If acquire happens to + // call itself recursively, the deadline passed in retains its value. + if deadline == nil && p.acquireTimeout > 0 { + tmp := time.Now().Add(p.acquireTimeout) + deadline = &tmp + } + + // Make sure the deadline (if there is one) has not passed yet + if p.deadlinePassed(deadline) { + return nil, errors.New("Timeout: Acquire connection timeout") + } + + // If there is a deadline, start a timer that signals p.cond when it fires + if deadline != nil { + timer := time.AfterFunc(deadline.Sub(time.Now()), func() { + p.cond.Signal() + }) + defer timer.Stop() + } + // No connections are available, but we can create more if len(p.allConnections) < p.maxConnections { - c, err := p.createConnection() + // Append a placeholder connection (presumably to hold a slot while unlocked). + placeholderConn := &Conn{} + p.allConnections = append(p.allConnections, placeholderConn) + // Create a new connection. + // Careful here: createConnectionUnlocked() releases the pool lock, + // creates a connection and then re-acquires the lock. + c, err := p.createConnectionUnlocked() + // Take the placeholder out of the list of connections. 
+ p.removeFromAllConnections(placeholderConn) + // Bail out if creating the connection failed if err != nil { return nil, err } + // If there is no room left in the list of allConnections + // (invalidateAcquired may remove our placeholder), close the new + // connection and try to re-acquire. NOTE(review): resetCount is not + // compared here; only the length of allConnections is checked. + if len(p.allConnections) >= p.maxConnections { + c.Close() + return p.acquire(deadline) + } + // Add the new connection to the list. c.poolResetCount = p.resetCount p.allConnections = append(p.allConnections, c) return c, nil @@ -119,23 +162,9 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { p.logger.Warn("All connections in pool are busy - waiting...") } - // Set initial timeout/deadline value. If the method (acquire) happens to - // recursively call itself the deadline should retain its value. - if deadline == nil && p.acquireTimeout > 0 { - tmp := time.Now().Add(p.acquireTimeout) - deadline = &tmp - } - // If there is a deadline then start a timeout timer - if deadline != nil { - timer := time.AfterFunc(deadline.Sub(time.Now()), func() { - p.cond.Signal() - }) - defer timer.Stop() - } - // Wait until there is an available connection OR room to create a new connection for len(p.availableConnections) == 0 && len(p.allConnections) == p.maxConnections { - if deadline != nil && time.Now().After(*deadline) { + if p.deadlinePassed(deadline) { return nil, errors.New("Timeout: All connections in pool are busy") } p.cond.Wait() @@ -170,19 +199,24 @@ func (p *ConnPool) Release(conn *Conn) { if conn.IsAlive() { p.availableConnections = append(p.availableConnections, conn) } else { - ac := p.allConnections - for i, c := range ac { - if conn == c { - ac[i] = ac[len(ac)-1] - p.allConnections = ac[0 : len(ac)-1] - break - } - } + p.removeFromAllConnections(conn) } p.cond.L.Unlock() p.cond.Signal() } +// removeFromAllConnections removes the given connection from p.allConnections. 
+// It returns true if the connection was found and removed, or false otherwise. +func (p *ConnPool) removeFromAllConnections(conn *Conn) bool { + for i, c := range p.allConnections { + if conn == c { + p.allConnections = append(p.allConnections[:i], p.allConnections[i+1:]...) + return true + } + } + return false +} + // Close ends the use of a connection pool. It prevents any new connections // from being acquired, waits until all acquired connections are released, // then closes all underlying connections. @@ -248,7 +282,36 @@ func (p *ConnPool) createConnection() (*Conn, error) { if err != nil { return nil, err } + return p.afterConnectionCreated(c) +} +// createConnectionUnlocked releases the pool lock, creates a new connection, and +// then re-acquires the lock. +// Rationale: let's say our pool dialer's OpenTimeout is set to 3 seconds, +// and we have a pool with 20 connections in it, and we try to acquire them all at +// startup. +// If the remote server happens to be inaccessible, then the first connection +// in the pool blocks all the others for 3 secs before it gets the timeout. Then +// connection #2 holds the lock and blocks everything for the next 3 secs until it +// gets its OpenTimeout error, etc. The very last (20th) connection fails only after +// 3 * 20 = 60 secs. +// To avoid this we call Connect(p.config) outside of the lock (it is thread safe), +// which allows all 20 connections to be made in parallel (more or less). +func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { + p.cond.L.Unlock() + c, err := Connect(p.config) + p.cond.L.Lock() + + if err != nil { + return nil, err + } + return p.afterConnectionCreated(c) +} + +// afterConnectionCreated runs the afterConnect() callback (if one is set) and prepares +// all the known statements for the new connection. 
+func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) { + var err error if p.afterConnect != nil { err = p.afterConnect(c) if err != nil { @@ -335,8 +398,16 @@ func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) { if err != nil { return nil, err } - ps, err := c.Prepare(name, sql) p.availableConnections = append(p.availableConnections, c) + + // Double-check that the statement was not already prepared by another caller + // while we were acquiring the connection (acquire may now release the + // pool lock, see createConnectionUnlocked()). + if ps, ok := p.preparedStatements[name]; ok && ps.SQL == sql { + return ps, nil + } + + ps, err := c.Prepare(name, sql) if err != nil { return nil, err } diff --git a/conn_pool_test.go b/conn_pool_test.go index 959dde41..3c4acc1f 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "errors" "fmt" + "net" "sync" "testing" "time" @@ -153,6 +154,57 @@ func TestPoolAcquireAndReleaseCycle(t *testing.T) { releaseAllConnections(pool, allConnections) } +func TestPoolNonBlockingConections(t *testing.T) { + t.Parallel() + + maxConnections := 5 + openTimeout := 1 * time.Second + dialer := net.Dialer{ + Timeout: openTimeout, + } + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: maxConnections, + } + config.ConnConfig.Dial = dialer.Dial + // We need a server that silently DROPs all incoming connection attempts. + // NOTE(review): microsoft.com is an external host, so this test depends on network behavior; a better target may exist. 
+ config.Host = "microsoft.com" + + pool, err := pgx.NewConnPool(config) + if err == nil { + t.Fatalf("Expected NewConnPool not to fail, instead it failed with") + } + + done := make(chan bool) + + startedAt := time.Now() + for i := 0; i < maxConnections; i++ { + go func() { + _, err := pool.Acquire() + done <- true + if err == nil { + t.Fatal("Acquire() expected to fail but it did not") + } + }() + } + + // Wait for every goroutine to report that its Acquire() returned. NOTE(review): t.Fatal above runs outside the test goroutine. + for i := 0; i < maxConnections; i++ { + <-done + } + + // Before createConnectionUnlocked() was used, this test took + // maxConnections * openTimeout seconds to complete. + // With createConnectionUnlocked() it takes ~ 1 * openTimeout seconds. + timeTaken := time.Now().Sub(startedAt) + if timeTaken > openTimeout+1*time.Second { + t.Fatalf("Expected all Aquire() to run in paralles and take about %v, instead it took '%v'", openTimeout, timeTaken) + } + + defer pool.Close() +} + func TestAcquireTimeoutSanity(t *testing.T) { t.Parallel()