From f004f0802c86c034d195998bfc855334266af2c5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Dec 2015 13:15:14 -0600 Subject: [PATCH] Add ConnPool.Reset method refs #110 --- conn.go | 1 + conn_pool.go | 27 +++++++++++++++++++++++++++ conn_pool_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ stress_test.go | 1 + 4 files changed, 76 insertions(+) diff --git a/conn.go b/conn.go index 5bfb1813..66eb9fc8 100644 --- a/conn.go +++ b/conn.go @@ -65,6 +65,7 @@ type Conn struct { pgsql_af_inet byte pgsql_af_inet6 byte busy bool + poolResetCount int } type PreparedStatement struct { diff --git a/conn_pool.go b/conn_pool.go index 06c6c22a..6eb489e7 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -17,6 +17,7 @@ type ConnPool struct { cond *sync.Cond config ConnConfig // config used when establishing connection maxConnections int + resetCount int afterConnect func(*Conn) error logger Logger logLevel int @@ -83,6 +84,7 @@ func (p *ConnPool) Acquire() (c *Conn, err error) { // A connection is available if len(p.availableConnections) > 0 { c = p.availableConnections[len(p.availableConnections)-1] + c.poolResetCount = p.resetCount p.availableConnections = p.availableConnections[:len(p.availableConnections)-1] return } @@ -93,6 +95,7 @@ func (p *ConnPool) Acquire() (c *Conn, err error) { if err != nil { return } + c.poolResetCount = p.resetCount p.allConnections = append(p.allConnections, c) return } @@ -108,6 +111,7 @@ func (p *ConnPool) Acquire() (c *Conn, err error) { } c = p.availableConnections[len(p.availableConnections)-1] + c.poolResetCount = p.resetCount p.availableConnections = p.availableConnections[:len(p.availableConnections)-1] return @@ -128,6 +132,14 @@ func (p *ConnPool) Release(conn *Conn) { conn.notifications = nil p.cond.L.Lock() + + if conn.poolResetCount != p.resetCount { + conn.Close() + p.cond.L.Unlock() + p.cond.Signal() + return + } + if conn.IsAlive() { p.availableConnections = append(p.availableConnections, conn) } else { @@ -165,6 +177,21 @@ func (p *ConnPool) Close() { } } +// Reset closes all open connections, but leaves the pool open. It is intended +// for use when an error is detected that would disrupt all connections (such as +// a network interruption or a server state change). +// +// It is safe to reset a pool while connections are checked out. Those +// connections will be closed when they are returned to the pool. +func (p *ConnPool) Reset() { + p.cond.L.Lock() + defer p.cond.L.Unlock() + + p.resetCount++ + p.allConnections = make([]*Conn, 0, p.maxConnections) + p.availableConnections = make([]*Conn, 0, p.maxConnections) +} + // Stat returns connection pool statistics func (p *ConnPool) Stat() (s ConnPoolStat) { p.cond.L.Lock() diff --git a/conn_pool_test.go b/conn_pool_test.go index c83945f2..6e1134e4 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -295,6 +295,53 @@ func TestPoolReleaseDiscardsDeadConnections(t *testing.T) { } } +func TestConnPoolReset(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 5) + defer pool.Close() + + inProgressRows := []*pgx.Rows{} + + // Start some queries and reset pool while they are in progress + for i := 0; i < 10; i++ { + rows, err := pool.Query("select generate_series(1,5)::bigint") + if err != nil { + t.Fatal(err) + } + + inProgressRows = append(inProgressRows, rows) + pool.Reset() + } + + // Check that the queries are completed + for _, rows := range inProgressRows { + var expectedN int64 + + for rows.Next() { + expectedN++ + var n int64 + err := rows.Scan(&n) + if err != nil { + t.Fatal(err) + } + if expectedN != n { + t.Fatalf("Expected n to be %d, but it was %d", expectedN, n) + } + } + + if err := rows.Err(); err != nil { + t.Fatal(err) + } + } + + // pool should be in fresh state due to previous reset + stats := pool.Stat() + if stats.CurrentConnections != 0 || stats.AvailableConnections != 0 { + t.Fatalf("Unexpected connection pool stats: %v", stats) + } +} + func TestConnPoolTransaction(t *testing.T) { t.Parallel() diff --git a/stress_test.go b/stress_test.go index 503f151e..b40827fb 100644 --- a/stress_test.go +++ b/stress_test.go @@ -41,6 +41,7 @@ func TestStressConnPool(t *testing.T) { {"txMultipleQueries", txMultipleQueries}, {"notify", notify}, {"listenAndPoolUnlistens", listenAndPoolUnlistens}, + {"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }}, } actionCount := 5000