diff --git a/CHANGELOG.md b/CHANGELOG.md index bb575b78..1365e14a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ standard database/sql package such as * Rows.Scan errors now include which argument caused error * Add Encode() to allow custom Encoders to reuse internal encoding functionality * Add Decode() to allow customer Decoders to reuse internal decoding functionality +* Add ConnPool.Prepare method +* Add ConnPool.Deallocate method ## Performance diff --git a/conn_pool.go b/conn_pool.go index b0aa4278..749f9dfe 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -22,6 +22,7 @@ type ConnPool struct { logger Logger logLevel int closed bool + preparedStatements map[string]*PreparedStatement } type ConnPoolStat struct { @@ -58,6 +59,7 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { p.allConnections = make([]*Conn, 0, p.maxConnections) p.availableConnections = make([]*Conn, 0, p.maxConnections) + p.preparedStatements = make(map[string]*PreparedStatement) p.cond = sync.NewCond(new(sync.Mutex)) // Initially establish one connection @@ -193,6 +195,19 @@ func (p *ConnPool) Reset() { p.availableConnections = make([]*Conn, 0, p.maxConnections) } +// invalidateAcquired causes all acquired connections to be closed when released. +// The pool must already be locked. +func (p *ConnPool) invalidateAcquired() { + p.resetCount++ + + for _, c := range p.availableConnections { + c.poolResetCount = p.resetCount + } + + p.allConnections = p.allConnections[:len(p.availableConnections)] + copy(p.allConnections, p.availableConnections) +} + // Stat returns connection pool statistics func (p *ConnPool) Stat() (s ConnPoolStat) { p.cond.L.Lock() @@ -204,18 +219,28 @@ func (p *ConnPool) Stat() (s ConnPoolStat) { return } -func (p *ConnPool) createConnection() (c *Conn, err error) { - c, err = Connect(p.config) +func (p *ConnPool) createConnection() (*Conn, error) { + c, err := Connect(p.config) if err != nil { - return + return nil, err } + if p.afterConnect != nil { err = p.afterConnect(c) if err != nil { - return + c.die(err) + return nil, err } } - return + + for _, ps := range p.preparedStatements { + if _, err := c.Prepare(ps.Name, ps.SQL); err != nil { + c.die(err) + return nil, err + } + } + + return c, nil } // Exec acquires a connection, delegates the call to that connection, and releases the connection @@ -263,6 +288,64 @@ func (p *ConnPool) Begin() (*Tx, error) { return p.BeginIso("") } +// Prepare creates a prepared statement on a connection in the pool to test the +// statement is valid. If it succeeds all connections accessed through the pool +// will have the statement available. +// +// Prepare creates a prepared statement with name and sql. sql can contain +// placeholders for bound parameters. These placeholders are referenced +// positional as $1, $2, etc. +// +// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with +// the same name and sql arguments. This allows a code path to Prepare and +// Query/Exec without concern for if the statement has already been prepared. +func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) { + p.cond.L.Lock() + defer p.cond.L.Unlock() + + if ps, ok := p.preparedStatements[name]; ok && ps.SQL == sql { + return ps, nil + } + + c, err := p.acquire() + if err != nil { + return nil, err + } + ps, err := c.Prepare(name, sql) + p.availableConnections = append(p.availableConnections, c) + if err != nil { + return nil, err + } + + for _, c := range p.availableConnections { + _, err := c.Prepare(name, sql) + if err != nil { + return nil, err + } + } + + p.invalidateAcquired() + p.preparedStatements[name] = ps + + return ps, err +} + +// Deallocate releases a prepared statement from all connections in the pool. +func (p *ConnPool) Deallocate(name string) (err error) { + p.cond.L.Lock() + defer p.cond.L.Unlock() + + for _, c := range p.availableConnections { + if err := c.Deallocate(name); err != nil { + return err + } + } + + p.invalidateAcquired() + + return nil +} + // BeginIso acquires a connection and begins a transaction in isolation mode iso // on it. When the transaction is closed the connection will be automatically // released. diff --git a/conn_pool_test.go b/conn_pool_test.go index 636569e4..6d3154df 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -599,3 +599,109 @@ func TestConnPoolExec(t *testing.T) { t.Errorf("Unexpected results from Exec: %v", results) } } + +func TestConnPoolPrepare(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 2) + defer pool.Close() + + _, err := pool.Prepare("test", "select $1::varchar") + if err != nil { + t.Fatalf("Unable to prepare statement: %v", err) + } + + var s string + err = pool.QueryRow("test", "hello").Scan(&s) + if err != nil { + t.Errorf("Executing prepared statement failed: %v", err) + } + + if s != "hello" { + t.Errorf("Prepared statement did not return expected value: %v", s) + } + + err = pool.Deallocate("test") + if err != nil { + t.Errorf("Deallocate failed: %v", err) + } + + err = pool.QueryRow("test", "hello").Scan(&s) + if err, ok := err.(pgx.PgError); !(ok && err.Code == "42601") { + t.Errorf("Expected error calling deallocated prepared statement, but got: %v", err) + } +} + +func TestConnPoolPrepareWhenConnIsAlreadyAcquired(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 2) + defer pool.Close() + + testPreparedStatement := func(db queryRower, desc string) { + var s string + err := db.QueryRow("test", "hello").Scan(&s) + if err != nil { + t.Fatalf("%s. Executing prepared statement failed: %v", desc, err) + } + + if s != "hello" { + t.Fatalf("%s. Prepared statement did not return expected value: %v", desc, s) + } + } + + newReleaseOnce := func(c *pgx.Conn) func() { + var once sync.Once + return func() { + once.Do(func() { pool.Release(c) }) + } + } + + c1, err := pool.Acquire() + if err != nil { + t.Fatalf("Unable to acquire connection: %v", err) + } + c1Release := newReleaseOnce(c1) + defer c1Release() + + _, err = pool.Prepare("test", "select $1::varchar") + if err != nil { + t.Fatalf("Unable to prepare statement: %v", err) + } + + testPreparedStatement(pool, "pool") + + c1Release() + + c2, err := pool.Acquire() + if err != nil { + t.Fatalf("Unable to acquire connection: %v", err) + } + c2Release := newReleaseOnce(c2) + defer c2Release() + + // This conn will not be available and will be connection at this point + c3, err := pool.Acquire() + if err != nil { + t.Fatalf("Unable to acquire connection: %v", err) + } + c3Release := newReleaseOnce(c3) + defer c3Release() + + testPreparedStatement(c2, "c2") + testPreparedStatement(c3, "c3") + + c2Release() + c3Release() + + err = pool.Deallocate("test") + if err != nil { + t.Errorf("Deallocate failed: %v", err) + } + + var s string + err = pool.QueryRow("test", "hello").Scan(&s) + if err, ok := err.(pgx.PgError); !(ok && err.Code == "42601") { + t.Errorf("Expected error calling deallocated prepared statement, but got: %v", err) + } +} diff --git a/stress_test.go b/stress_test.go index 67642b3e..150d13c8 100644 --- a/stress_test.go +++ b/stress_test.go @@ -2,6 +2,7 @@ package pgx_test import ( "errors" + "fmt" "math/rand" "testing" "time" @@ -42,6 +43,7 @@ func TestStressConnPool(t *testing.T) { {"notify", notify}, {"listenAndPoolUnlistens", listenAndPoolUnlistens}, {"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }}, + {"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate}, } var timer *time.Timer @@ -246,6 +248,27 @@ func listenAndPoolUnlistens(pool *pgx.ConnPool, actionNum int) error { return err } +func poolPrepareUseAndDeallocate(pool *pgx.ConnPool, actionNum int) error { + psName := fmt.Sprintf("poolPreparedStatement%d", actionNum) + + _, err := pool.Prepare(psName, "select $1::text") + if err != nil { + return err + } + + var s string + err = pool.QueryRow(psName, "hello").Scan(&s) + if err != nil { + return err + } + + if s != "hello" { + return fmt.Errorf("Prepared statement did not return expected value: %v", s) + } + + return pool.Deallocate(psName) +} + func txInsertRollback(pool *pgx.ConnPool, actionNum int) error { tx, err := pool.Begin() if err != nil {