From 622ff142cabcdcf0ecbf7f868b67dd91917bc6b0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 25 Jul 2013 07:52:05 -0500 Subject: [PATCH] Add error to *ConnectionPool.Acquire return --- bench_test.go | 5 ++++- connection_pool.go | 46 +++++++++++++++++++++++++++++------------ connection_pool_test.go | 19 +++++++++++------ 3 files changed, 50 insertions(+), 20 deletions(-) diff --git a/bench_test.go b/bench_test.go index cd0654ce..01042886 100644 --- a/bench_test.go +++ b/bench_test.go @@ -610,7 +610,10 @@ func BenchmarkConnectionPool(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - conn := pool.Acquire() + var conn *pgx.Connection + if conn, err = pool.Acquire(); err != nil { + b.Fatalf("Unable to acquire connection: %v", err) + } pool.Release(conn) } diff --git a/connection_pool.go b/connection_pool.go index 5e101656..9a040030 100644 --- a/connection_pool.go +++ b/connection_pool.go @@ -8,7 +8,8 @@ type ConnectionPoolOptions struct { type ConnectionPool struct { connectionChannel chan *Connection parameters ConnectionParameters // parameters used when establishing connection - options ConnectionPoolOptions + maxConnections int + afterConnect func(*Connection) error } // NewConnectionPool creates a new ConnectionPool. parameters are passed through to @@ -18,9 +19,10 @@ func NewConnectionPool(parameters ConnectionParameters, options ConnectionPoolOp p.connectionChannel = make(chan *Connection, options.MaxConnections) p.parameters = parameters - p.options = options + p.maxConnections = options.MaxConnections + p.afterConnect = options.AfterConnect - for i := 0; i < p.options.MaxConnections; i++ { + for i := 0; i < p.maxConnections; i++ { var c *Connection c, err = p.createConnection() if err != nil { @@ -33,7 +35,7 @@ func NewConnectionPool(parameters ConnectionParameters, options ConnectionPoolOp } // Acquire takes exclusive use of a connection until it is released. -func (p *ConnectionPool) Acquire() (c *Connection) { +func (p *ConnectionPool) Acquire() (c *Connection, err error) { c = <-p.connectionChannel return } @@ -48,7 +50,7 @@ func (p *ConnectionPool) Release(c *Connection) { // Close ends the use of a connection by closing all underlying connections. func (p *ConnectionPool) Close() { - for i := 0; i < p.options.MaxConnections; i++ { + for i := 0; i < p.maxConnections; i++ { c := <-p.connectionChannel _ = c.Close() } @@ -59,8 +61,8 @@ func (p *ConnectionPool) createConnection() (c *Connection, err error) { if err != nil { return } - if p.options.AfterConnect != nil { - err = p.options.AfterConnect(c) + if p.afterConnect != nil { + err = p.afterConnect(c) if err != nil { return } @@ -70,7 +72,10 @@ func (p *ConnectionPool) createConnection() (c *Connection, err error) { // SelectFunc acquires a connection, delegates the call to that connection, and releases the connection func (p *ConnectionPool) SelectFunc(sql string, onDataRow func(*DataRowReader) error, arguments ...interface{}) (err error) { - c := p.Acquire() + var c *Connection + if c, err = p.Acquire(); err != nil { + return + } defer p.Release(c) return c.SelectFunc(sql, onDataRow, arguments...) @@ -78,7 +83,10 @@ func (p *ConnectionPool) SelectFunc(sql string, onDataRow func(*DataRowReader) e // SelectRows acquires a connection, delegates the call to that connection, and releases the connection func (p *ConnectionPool) SelectRows(sql string, arguments ...interface{}) (rows []map[string]interface{}, err error) { - c := p.Acquire() + var c *Connection + if c, err = p.Acquire(); err != nil { + return + } defer p.Release(c) return c.SelectRows(sql, arguments...) @@ -86,7 +94,10 @@ func (p *ConnectionPool) SelectRows(sql string, arguments ...interface{}) (rows // SelectRow acquires a connection, delegates the call to that connection, and releases the connection func (p *ConnectionPool) SelectRow(sql string, arguments ...interface{}) (row map[string]interface{}, err error) { - c := p.Acquire() + var c *Connection + if c, err = p.Acquire(); err != nil { + return + } defer p.Release(c) return c.SelectRow(sql, arguments...) @@ -94,7 +105,10 @@ func (p *ConnectionPool) SelectRow(sql string, arguments ...interface{}) (row ma // SelectValue acquires a connection, delegates the call to that connection, and releases the connection func (p *ConnectionPool) SelectValue(sql string, arguments ...interface{}) (v interface{}, err error) { - c := p.Acquire() + var c *Connection + if c, err = p.Acquire(); err != nil { + return + } defer p.Release(c) return c.SelectValue(sql, arguments...) @@ -102,7 +116,10 @@ func (p *ConnectionPool) SelectValue(sql string, arguments ...interface{}) (v in // SelectValues acquires a connection, delegates the call to that connection, and releases the connection func (p *ConnectionPool) SelectValues(sql string, arguments ...interface{}) (values []interface{}, err error) { - c := p.Acquire() + var c *Connection + if c, err = p.Acquire(); err != nil { + return + } defer p.Release(c) return c.SelectValues(sql, arguments...) @@ -110,7 +127,10 @@ func (p *ConnectionPool) SelectValues(sql string, arguments ...interface{}) (val // Execute acquires a connection, delegates the call to that connection, and releases the connection func (p *ConnectionPool) Execute(sql string, arguments ...interface{}) (commandTag string, err error) { - c := p.Acquire() + var c *Connection + if c, err = p.Acquire(); err != nil { + return + } defer p.Release(c) return c.Execute(sql, arguments...) diff --git a/connection_pool_test.go b/connection_pool_test.go index 23e6b5a1..597728eb 100644 --- a/connection_pool_test.go +++ b/connection_pool_test.go @@ -57,7 +57,10 @@ func TestPoolAcquireAndReleaseCycle(t *testing.T) { acquireAll := func() (connections []*pgx.Connection) { connections = make([]*pgx.Connection, maxConnections) for i := 0; i < maxConnections; i++ { - connections[i] = pool.Acquire() + var err error + if connections[i], err = pool.Acquire(); err != nil { + t.Fatalf("Unable to acquire connection: %v", err) + } } return } @@ -74,8 +77,7 @@ func TestPoolAcquireAndReleaseCycle(t *testing.T) { } f := func() { - var err error - conn := pool.Acquire() + conn, err := pool.Acquire() if err != nil { t.Fatal("Unable to acquire connection") } @@ -123,8 +125,10 @@ func TestPoolReleaseWithTransactions(t *testing.T) { pool := createConnectionPool(t, 1) defer pool.Close() - var err error - conn := pool.Acquire() + conn, err := pool.Acquire() + if err != nil { + t.Fatalf("Unable to acquire connection: %v", err) + } mustExecute(t, conn, "begin") if _, err = conn.Execute("select"); err == nil { t.Fatal("Did not receive expected error") @@ -139,7 +143,10 @@ func TestPoolReleaseWithTransactions(t *testing.T) { t.Fatalf("Expected release to rollback errored transaction, but it did not: '%c'", conn.TxStatus) } - conn = pool.Acquire() + conn, err = pool.Acquire() + if err != nil { + t.Fatalf("Unable to acquire connection: %v", err) + } mustExecute(t, conn, "begin") if conn.TxStatus != 'T' { t.Fatalf("Expected txStatus to be 'T', instead it was '%c'", conn.TxStatus)