diff --git a/conn_pool.go b/conn_pool.go index 23c5c9db..2ac877f0 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -1,6 +1,7 @@ package pgx import ( + "errors" log "gopkg.in/inconshreveable/log15.v2" "io" "sync" @@ -8,7 +9,7 @@ import ( type ConnPoolConfig struct { ConnConfig - MaxConnections int // max simultaneous connections to use + MaxConnections int // max simultaneous connections to use, default 5, must be at least 2 AfterConnect func(*Conn) error } @@ -34,6 +35,13 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { p = new(ConnPool) p.config = config.ConnConfig p.maxConnections = config.MaxConnections + if p.maxConnections == 0 { + p.maxConnections = 5 + } + if p.maxConnections < 2 { + return nil, errors.New("MaxConnections must be at least 2") + } + p.afterConnect = config.AfterConnect if config.Logger != nil { p.logger = config.Logger diff --git a/conn_pool_test.go b/conn_pool_test.go index d0ad6c14..bf33d1a9 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -52,6 +52,35 @@ func TestNewConnPool(t *testing.T) { } } +func TestNewConnPoolDefaultsTo5MaxConnections(t *testing.T) { + t.Parallel() + + config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig} + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatal("Unable to establish connection pool") + } + defer pool.Close() + + if n := pool.MaxConnectionCount(); n != 5 { + t.Fatalf("Expected pool to default to 5 max connections, but it was %d", n) + } +} + +func TestNewConnPoolMaxConnectionsCannotBeLessThan2(t *testing.T) { + t.Parallel() + + config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 1} + pool, err := pgx.NewConnPool(config) + if err == nil { + pool.Close() + t.Fatal(`Expected NewConnPool to fail with "MaxConnections must be at least 2" error, but it succeeded`) + } + if err.Error() != "MaxConnections must be at least 2" { + t.Fatalf(`Expected NewConnPool to fail with "MaxConnections must be at least 2" error, but it failed with %v`, err) + } +} + func TestPoolAcquireAndReleaseCycle(t *testing.T) { t.Parallel() @@ -131,7 +160,7 @@ func TestPoolAcquireAndReleaseCycle(t *testing.T) { func TestPoolReleaseWithTransactions(t *testing.T) { t.Parallel() - pool := createConnPool(t, 1) + pool := createConnPool(t, 2) defer pool.Close() conn, err := pool.Acquire() @@ -274,7 +303,7 @@ func TestPoolReleaseDiscardsDeadConnections(t *testing.T) { func TestPoolTransaction(t *testing.T) { t.Parallel() - pool := createConnPool(t, 1) + pool := createConnPool(t, 2) defer pool.Close() committed, err := pool.Transaction(func(conn *pgx.Conn) bool { @@ -329,7 +358,7 @@ func TestPoolTransaction(t *testing.T) { func TestPoolTransactionIso(t *testing.T) { t.Parallel() - pool := createConnPool(t, 1) + pool := createConnPool(t, 2) defer pool.Close() committed, err := pool.TransactionIso("serializable", func(conn *pgx.Conn) bool {