From e55a5ebccfcdba5af3f7b3bd2e5c3b52235f1453 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 23 Jul 2013 17:46:38 -0500 Subject: [PATCH] Add AfterConnect callback for ConnectionPool --- connection_pool.go | 30 ++++++++++++++++++++---------- connection_pool_test.go | 29 +++++++++++++++++++++++++---- 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/connection_pool.go b/connection_pool.go index 976494d2..79626b8d 100644 --- a/connection_pool.go +++ b/connection_pool.go @@ -1,27 +1,37 @@ package pgx +type ConnectionPoolOptions struct { + MaxConnections int // max simultaneous connections to use (currently all are immediately connected) + AfterConnect func(*Connection) error +} + type ConnectionPool struct { connectionChannel chan *Connection - parameters ConnectionParameters // options used when establishing connection - MaxConnections int + parameters ConnectionParameters // parameters used when establishing connection + options ConnectionPoolOptions } -// NewConnectionPool creates a new ConnectionPool. options are passed through to -// Connect directly. MaxConnections is max simultaneous connections to use -// (currently all are immediately connected). -func NewConnectionPool(parameters ConnectionParameters, MaxConnections int) (p *ConnectionPool, err error) { +// NewConnectionPool creates a new ConnectionPool. parameters are passed through to +// Connect directly. +func NewConnectionPool(parameters ConnectionParameters, options ConnectionPoolOptions) (p *ConnectionPool, err error) { p = new(ConnectionPool) - p.connectionChannel = make(chan *Connection, MaxConnections) - p.MaxConnections = MaxConnections + p.connectionChannel = make(chan *Connection, options.MaxConnections) p.parameters = parameters + p.options = options - for i := 0; i < p.MaxConnections; i++ { + for i := 0; i < p.options.MaxConnections; i++ { var c *Connection c, err = Connect(p.parameters) if err != nil { return } + if p.options.AfterConnect != nil { + err = p.options.AfterConnect(c) + if err != nil { + return + } + } p.connectionChannel <- c } @@ -44,7 +54,7 @@ func (p *ConnectionPool) Release(c *Connection) { // Close ends the use of a connection by closing all underlying connections. func (p *ConnectionPool) Close() { - for i := 0; i < p.MaxConnections; i++ { + for i := 0; i < p.options.MaxConnections; i++ { c := <-p.connectionChannel _ = c.Close() } diff --git a/connection_pool_test.go b/connection_pool_test.go index ace28259..23e6b5a1 100644 --- a/connection_pool_test.go +++ b/connection_pool_test.go @@ -1,13 +1,15 @@ package pgx_test import ( + "errors" "fmt" "github.com/JackC/pgx" "testing" ) func createConnectionPool(t *testing.T, maxConnections int) *pgx.ConnectionPool { - pool, err := pgx.NewConnectionPool(*defaultConnectionParameters, maxConnections) + options := pgx.ConnectionPoolOptions{MaxConnections: maxConnections} + pool, err := pgx.NewConnectionPool(*defaultConnectionParameters, options) if err != nil { t.Fatalf("Unable to create connection pool: %v", err) } @@ -15,14 +17,33 @@ func createConnectionPool(t *testing.T, maxConnections int) *pgx.ConnectionPool } func TestNewConnectionPool(t *testing.T) { - pool, err := pgx.NewConnectionPool(*defaultConnectionParameters, 5) + var numCallbacks int + afterConnect := func(c *pgx.Connection) error { + numCallbacks++ + return nil + } + + options := pgx.ConnectionPoolOptions{MaxConnections: 2, AfterConnect: afterConnect} + pool, err := pgx.NewConnectionPool(*defaultConnectionParameters, options) if err != nil { t.Fatal("Unable to establish connection pool") } defer pool.Close() - if pool.MaxConnections != 5 { - t.Error("Wrong maxConnections") + if numCallbacks != 2 { + t.Errorf("Expected AfterConnect callback to fire %v times but only fired %v times", numCallbacks, numCallbacks) + } + + // Pool creation returns an error if any AfterConnect callback does + errAfterConnect := errors.New("Some error") + afterConnect = func(c *pgx.Connection) error { + return errAfterConnect + } + + options = pgx.ConnectionPoolOptions{MaxConnections: 2, AfterConnect: afterConnect} + pool, err = pgx.NewConnectionPool(*defaultConnectionParameters, options) + if err != errAfterConnect { + t.Errorf("Expected errAfterConnect but received unexpected: %v", err) } }