From ec10fdde8b10029bd686e0b8fe74ef343067edd6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 9 Apr 2019 21:32:01 -0500 Subject: [PATCH] Core of new connection pool --- conn.go | 4 + conn_pool_test.go | 163 ---------------------------------- go.mod | 1 + go.sum | 2 + pool/common_test.go | 106 ++++++++++++++++++++++ pool/conn.go | 85 ++++++++++++++++++ pool/conn_test.go | 70 +++++++++++++++ pool/pool.go | 132 +++++++++++++++++++++++++++ pool/pool_test.go | 211 ++++++++++++++++++++++++++++++++++++++++++++ pool/rows.go | 76 ++++++++++++++++ pool/stat.go | 47 ++++++++++ pool/todo.txt | 8 ++ pool/tx.go | 55 ++++++++++++ pool/tx_test.go | 70 +++++++++++++++ private_test.go | 7 -- 15 files changed, 867 insertions(+), 170 deletions(-) create mode 100644 pool/common_test.go create mode 100644 pool/conn.go create mode 100644 pool/conn_test.go create mode 100644 pool/pool.go create mode 100644 pool/pool_test.go create mode 100644 pool/rows.go create mode 100644 pool/stat.go create mode 100644 pool/todo.txt create mode 100644 pool/tx.go create mode 100644 pool/tx_test.go delete mode 100644 private_test.go diff --git a/conn.go b/conn.go index f1939031..29dc924f 100644 --- a/conn.go +++ b/conn.go @@ -461,6 +461,10 @@ func (c *Conn) Close() error { return err } +func (c *Conn) TxStatus() byte { + return c.pgConn.TxStatus +} + // ParameterStatus returns the value of a parameter reported by the server (e.g. // server_version). Returns an empty string for unknown parameters. func (c *Conn) ParameterStatus(key string) string { diff --git a/conn_pool_test.go b/conn_pool_test.go index 4d1f2aaf..2cd43d3f 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -226,97 +226,6 @@ func TestPoolNonBlockingConnections(t *testing.T) { } -func TestAcquireTimeoutSanity(t *testing.T) { - t.Parallel() - - config := pgx.ConnPoolConfig{ - ConnConfig: mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")), - MaxConnections: 1, - } - - // case 1: default 0 value - pool, err := pgx.NewConnPool(config) - if err != nil { - t.Fatalf("Expected NewConnPool with default config.AcquireTimeout not to fail, instead it failed with '%v'", err) - } - pool.Close() - - // case 2: negative value - config.AcquireTimeout = -1 * time.Second - _, err = pgx.NewConnPool(config) - if err == nil { - t.Fatal("Expected NewConnPool with negative config.AcquireTimeout to fail, instead it did not") - } - - // case 3: positive value - config.AcquireTimeout = 1 * time.Second - pool, err = pgx.NewConnPool(config) - if err != nil { - t.Fatalf("Expected NewConnPool with positive config.AcquireTimeout not to fail, instead it failed with '%v'", err) - } - defer pool.Close() -} - -func TestPoolWithAcquireTimeoutSet(t *testing.T) { - t.Parallel() - - connAllocTimeout := 2 * time.Second - config := pgx.ConnPoolConfig{ - ConnConfig: mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")), - MaxConnections: 1, - AcquireTimeout: connAllocTimeout, - } - - pool, err := pgx.NewConnPool(config) - if err != nil { - t.Fatalf("Unable to create connection pool: %v", err) - } - defer pool.Close() - - // Consume all connections ... - allConnections := acquireAllConnections(t, pool, config.MaxConnections) - defer releaseAllConnections(pool, allConnections) - - // ... then try to consume 1 more. It should fail after a short timeout. - _, timeTaken, err := acquireWithTimeTaken(pool) - - if err == nil || err != pgx.ErrAcquireTimeout { - t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) - } - if timeTaken < connAllocTimeout { - t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", connAllocTimeout, timeTaken) - } -} - -func TestPoolWithoutAcquireTimeoutSet(t *testing.T) { - t.Parallel() - - maxConnections := 1 - pool := createConnPool(t, maxConnections) - defer pool.Close() - - // Consume all connections ... - allConnections := acquireAllConnections(t, pool, maxConnections) - - // ... then try to consume 1 more. It should hang forever. - // To unblock it we release the previously taken connection in a goroutine. - stopDeadWaitTimeout := 5 * time.Second - timer := time.AfterFunc(stopDeadWaitTimeout+100*time.Millisecond, func() { - releaseAllConnections(pool, allConnections) - }) - defer timer.Stop() - - conn, timeTaken, err := acquireWithTimeTaken(pool) - if err == nil { - pool.Release(conn) - } else { - t.Fatalf("Expected error to be nil, instead it was '%v'", err) - } - if timeTaken < stopDeadWaitTimeout { - t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", stopDeadWaitTimeout, timeTaken) - } -} - func TestPoolErrClosedPool(t *testing.T) { t.Parallel() @@ -334,47 +243,6 @@ func TestPoolErrClosedPool(t *testing.T) { } } -func TestPoolReleaseWithTransactions(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 2) - defer pool.Close() - - conn, err := pool.Acquire() - if err != nil { - t.Fatalf("Unable to acquire connection: %v", err) - } - mustExec(t, conn, "begin") - if _, err = conn.Exec(context.Background(), "selct"); err == nil { - t.Fatal("Did not receive expected error") - } - - if conn.TxStatus() != 'E' { - t.Fatalf("Expected TxStatus to be 'E', instead it was '%c'", conn.TxStatus()) - } - - pool.Release(conn) - - if conn.TxStatus() != 'I' { - t.Fatalf("Expected release to rollback errored transaction, but it did not: '%c'", conn.TxStatus()) - } - - conn, err = pool.Acquire() - if err != nil { - t.Fatalf("Unable to acquire connection: %v", err) - } - mustExec(t, conn, "begin") - if conn.TxStatus() != 'T' { - t.Fatalf("Expected txStatus to be 'T', instead it was '%c'", conn.TxStatus()) - } - - pool.Release(conn) - - if conn.TxStatus() != 'I' { - t.Fatalf("Expected release to rollback uncommitted transaction, but it did not: '%c'", conn.TxStatus()) - } -} - func TestPoolAcquireAndReleaseCycleAutoConnect(t *testing.T) { t.Parallel() @@ -835,37 +703,6 @@ func TestConnPoolQueryRow(t *testing.T) { } } -func TestConnPoolExec(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 2) - defer pool.Close() - - results, err := pool.Exec(context.Background(), "create temporary table foo(id integer primary key);") - if err != nil { - t.Fatalf("Unexpected error from pool.Exec: %v", err) - } - if string(results) != "CREATE TABLE" { - t.Errorf("Unexpected results from Exec: %v", results) - } - - results, err = pool.Exec(context.Background(), "insert into foo(id) values($1)", 1) - if err != nil { - t.Fatalf("Unexpected error from pool.Exec: %v", err) - } - if string(results) != "INSERT 0 1" { - t.Errorf("Unexpected results from Exec: %v", results) - } - - results, err = pool.Exec(context.Background(), "drop table foo;") - if err != nil { - t.Fatalf("Unexpected error from pool.Exec: %v", err) - } - if string(results) != "DROP TABLE" { - t.Errorf("Unexpected results from Exec: %v", results) - } -} - func TestConnPoolPrepare(t *testing.T) { t.Parallel() diff --git a/go.mod b/go.mod index eb880db1..a6ffb3f8 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/jackc/pgconn v0.0.0-20190405170659-7ad3625edd3b github.com/jackc/pgio v1.0.0 github.com/jackc/pgproto3 v1.0.0 + github.com/jackc/puddle v0.0.0-20190409004018-0d93e0ec116a github.com/pkg/errors v0.8.1 github.com/satori/go.uuid v1.2.0 github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 diff --git a/go.sum b/go.sum index 8e1d0f49..59ad6856 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.0.0 h1:25tUmlES7eyD96oYaUHc1dLOFbgcJtFzCdnOOoqmA1I= github.com/jackc/pgproto3 v1.0.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/puddle v0.0.0-20190409004018-0d93e0ec116a h1:zx0j45Wa4oRefVk0D3muLxUujnMWN7ZRraF+78DXEwE= +github.com/jackc/puddle v0.0.0-20190409004018-0d93e0ec116a/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/pool/common_test.go b/pool/common_test.go new file mode 100644 index 00000000..a8abf71f --- /dev/null +++ b/pool/common_test.go @@ -0,0 +1,106 @@ +package pool_test + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx" + "github.com/jackc/pgx/pool" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Conn.Release is an asynchronous process that returns immediately. There is no signal when the actual work is +// completed. To test something that relies on the actual work for Conn.Release being completed we must simply wait. +// This function wraps the sleep so there is more meaning for the callers. +func waitForReleaseToComplete() { + time.Sleep(5 * time.Millisecond) +} + +type execer interface { + Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) +} + +func testExec(t *testing.T, db execer) { + results, err := db.Exec(context.Background(), "create table foo(id integer primary key);") + require.NoError(t, err) + assert.Equal(t, "CREATE TABLE", string(results)) + + results, err = db.Exec(context.Background(), "insert into foo(id) values($1)", 1) + require.NoError(t, err) + assert.Equal(t, "INSERT 0 1", string(results)) + + results, err = db.Exec(context.Background(), "drop table foo;") + require.NoError(t, err) + assert.Equal(t, "DROP TABLE", string(results)) +} + +type queryer interface { + Query(sql string, args ...interface{}) (*pool.Rows, error) +} + +func testQuery(t *testing.T, db queryer) { + var sum, rowCount int32 + + rows, err := db.Query("select generate_series(1,$1)", 10) + require.NoError(t, err) + + for rows.Next() { + var n int32 + rows.Scan(&n) + sum += n + rowCount++ + } + + assert.NoError(t, rows.Err()) + assert.Equal(t, int32(10), rowCount) + assert.Equal(t, int32(55), sum) +} + +type queryExer interface { + QueryEx(ctx context.Context, sql string, options *pgx.QueryExOptions, args ...interface{}) (*pool.Rows, error) +} + +func testQueryEx(t *testing.T, db queryExer) { + var sum, rowCount int32 + + rows, err := db.QueryEx(context.Background(), "select generate_series(1,$1)", nil, 10) + require.NoError(t, err) + + for rows.Next() { + var n int32 + rows.Scan(&n) + sum += n + rowCount++ + } + + assert.NoError(t, rows.Err()) + assert.Equal(t, int32(10), rowCount) + assert.Equal(t, int32(55), sum) +} + +type queryRower interface { + QueryRow(sql string, args ...interface{}) *pool.Row +} + +func testQueryRow(t *testing.T, db queryRower) { + var what, who string + err := db.QueryRow("select 'hello', $1", "world").Scan(&what, &who) + assert.NoError(t, err) + assert.Equal(t, "hello", what) + assert.Equal(t, "world", who) +} + +type queryRowExer interface { + QueryRowEx(ctx context.Context, sql string, options *pgx.QueryExOptions, args ...interface{}) *pool.Row +} + +func testQueryRowEx(t *testing.T, db queryRowExer) { + var what, who string + err := db.QueryRowEx(context.Background(), "select 'hello', $1", nil, "world").Scan(&what, &who) + assert.NoError(t, err) + assert.Equal(t, "hello", what) + assert.Equal(t, "world", who) +} diff --git a/pool/conn.go b/pool/conn.go new file mode 100644 index 00000000..54db253e --- /dev/null +++ b/pool/conn.go @@ -0,0 +1,85 @@ +package pool + +import ( + "context" + "time" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx" + "github.com/jackc/puddle" +) + +// Conn is an acquired *pgx.Conn from a Pool. +type Conn struct { + res *puddle.Resource +} + +// Release returns c to the pool it was acquired from. Once Release has been called, other methods must not be called. +// However, it is safe to call Release multiple times. Subsequent calls after the first will be ignored. +func (c *Conn) Release() { + if c.res == nil { + return + } + + conn := c.Conn() + res := c.res + c.res = nil + + go func() { + if !conn.IsAlive() { + res.Destroy() + return + } + + if conn.TxStatus() != 'I' { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + _, err := conn.Exec(ctx, "rollback") + cancel() + if err != nil { + res.Destroy() + return + } + } + + if conn.IsAlive() { + res.Release() + } else { + res.Destroy() + } + }() +} + +func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { + conn := c.res.Value().(*pgx.Conn) + return conn.Exec(ctx, sql, arguments...) +} + +func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { + r, err := c.res.Value().(*pgx.Conn).Query(sql, args...) + rows := &Rows{r: r, err: err} + return rows, err +} + +func (c *Conn) QueryEx(ctx context.Context, sql string, options *pgx.QueryExOptions, args ...interface{}) (*Rows, error) { + r, err := c.res.Value().(*pgx.Conn).QueryEx(ctx, sql, options, args...) + rows := &Rows{r: r, err: err} + return rows, err +} + +func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { + r := c.res.Value().(*pgx.Conn).QueryRow(sql, args...) + return &Row{r: r} +} + +func (c *Conn) QueryRowEx(ctx context.Context, sql string, options *pgx.QueryExOptions, args ...interface{}) *Row { + r := c.res.Value().(*pgx.Conn).QueryRowEx(ctx, sql, options, args...) + return &Row{r: r} +} + +func (c *Conn) Begin() (*pgx.Tx, error) { + return c.res.Value().(*pgx.Conn).Begin() +} + +func (c *Conn) Conn() *pgx.Conn { + return c.res.Value().(*pgx.Conn) +} diff --git a/pool/conn_test.go b/pool/conn_test.go new file mode 100644 index 00000000..e7f39050 --- /dev/null +++ b/pool/conn_test.go @@ -0,0 +1,70 @@ +package pool_test + +import ( + "context" + "os" + "testing" + + "github.com/jackc/pgx/pool" + "github.com/stretchr/testify/require" +) + +func TestConnExec(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(context.Background()) + require.NoError(t, err) + defer c.Release() + + testExec(t, c) +} + +func TestConnQuery(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(context.Background()) + require.NoError(t, err) + defer c.Release() + + testQuery(t, c) +} + +func TestConnQueryEx(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(context.Background()) + require.NoError(t, err) + defer c.Release() + + testQueryEx(t, c) +} + +func TestConnQueryRow(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(context.Background()) + require.NoError(t, err) + defer c.Release() + + testQueryRow(t, c) +} + +func TestConnlQueryRowEx(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(context.Background()) + require.NoError(t, err) + defer c.Release() + + testQueryRowEx(t, c) +} diff --git a/pool/pool.go b/pool/pool.go new file mode 100644 index 00000000..2d2d9be1 --- /dev/null +++ b/pool/pool.go @@ -0,0 +1,132 @@ +package pool + +import ( + "context" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx" + "github.com/jackc/puddle" +) + +type Pool struct { + p *puddle.Pool +} + +// Connect creates a new Pool and immediately establishes one connection. ctx can be used to cancel this initial +// connection. +func Connect(ctx context.Context, connString string) (*Pool, error) { + p := &Pool{} + + maxConnections := 5 // TODO - unhard-code + p.p = puddle.NewPool( + func(ctx context.Context) (interface{}, error) { return pgx.Connect(ctx, connString) }, + func(value interface{}) { value.(*pgx.Conn).Close() }, + maxConnections) + + // Initially establish one connection + res, err := p.p.Acquire(ctx) + if err != nil { + p.p.Close() + return nil, err + } + res.Release() + + return p, nil +} + +// Close closes all connections in the pool and rejects future Acquire calls. Blocks until all connections are returned +// to pool and closed. +func (p *Pool) Close() { + p.p.Close() +} + +func (p *Pool) Acquire(ctx context.Context) (*Conn, error) { + res, err := p.p.Acquire(ctx) + if err != nil { + return nil, err + } + + return &Conn{res: res}, nil +} + +func (p *Pool) Stat() *Stat { + return &Stat{s: p.p.Stat()} +} + +func (p *Pool) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { + c, err := p.Acquire(ctx) + if err != nil { + return "", err + } + defer c.Release() + + return c.Exec(ctx, sql, arguments...) +} + +func (p *Pool) Query(sql string, args ...interface{}) (*Rows, error) { + c, err := p.Acquire(context.Background()) + if err != nil { + return &Rows{err: err}, err + } + + rows, err := c.Query(sql, args...) + if err == nil { + rows.c = c + } else { + c.Release() + } + + return rows, err +} + +func (p *Pool) QueryEx(ctx context.Context, sql string, options *pgx.QueryExOptions, args ...interface{}) (*Rows, error) { + c, err := p.Acquire(context.Background()) + if err != nil { + return &Rows{err: err}, err + } + + rows, err := c.QueryEx(ctx, sql, options, args...) + if err == nil { + rows.c = c + } else { + c.Release() + } + + return rows, err +} + +func (p *Pool) QueryRow(sql string, args ...interface{}) *Row { + c, err := p.Acquire(context.Background()) + if err != nil { + return &Row{err: err} + } + + row := c.QueryRow(sql, args...) + row.c = c + return row +} + +func (p *Pool) QueryRowEx(ctx context.Context, sql string, options *pgx.QueryExOptions, args ...interface{}) *Row { + c, err := p.Acquire(context.Background()) + if err != nil { + return &Row{err: err} + } + + row := c.QueryRowEx(ctx, sql, options, args...) + row.c = c + return row +} + +func (p *Pool) Begin() (*Tx, error) { + c, err := p.Acquire(context.Background()) + if err != nil { + return nil, err + } + + t, err := c.Begin() + if err != nil { + return nil, err + } + + return &Tx{t: t, c: c}, err +} diff --git a/pool/pool_test.go b/pool/pool_test.go new file mode 100644 index 00000000..074e717b --- /dev/null +++ b/pool/pool_test.go @@ -0,0 +1,211 @@ +package pool_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/jackc/pgx/pool" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConnect(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + pool.Close() +} + +func TestConnectCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + pool, err := pool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + assert.Nil(t, pool) + assert.Equal(t, context.Canceled, err) +} + +func TestPoolAcquireAndConnRelease(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(context.Background()) + require.NoError(t, err) + c.Release() +} + +func TestPoolExec(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + testExec(t, pool) +} + +func TestPoolQuery(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + // Test common usage + testQuery(t, pool) + waitForReleaseToComplete() + + // Test expected pool behavior + rows, err := pool.Query("select generate_series(1,$1)", 10) + require.NoError(t, err) + + stats := pool.Stat() + assert.Equal(t, 1, stats.AcquiredConns()) + assert.Equal(t, 1, stats.TotalConns()) + + rows.Close() + assert.NoError(t, rows.Err()) + waitForReleaseToComplete() + + stats = pool.Stat() + assert.Equal(t, 0, stats.AcquiredConns()) + assert.Equal(t, 1, stats.TotalConns()) + +} + +func TestPoolQueryEx(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + // Test common usage + testQueryEx(t, pool) + waitForReleaseToComplete() + + // Test expected pool behavior + + rows, err := pool.QueryEx(context.Background(), "select generate_series(1,$1)", nil, 10) + require.NoError(t, err) + + stats := pool.Stat() + assert.Equal(t, 1, stats.AcquiredConns()) + assert.Equal(t, 1, stats.TotalConns()) + + rows.Close() + assert.NoError(t, rows.Err()) + waitForReleaseToComplete() + + stats = pool.Stat() + assert.Equal(t, 0, stats.AcquiredConns()) + assert.Equal(t, 1, stats.TotalConns()) +} + +func TestPoolQueryRow(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + testQueryRow(t, pool) + waitForReleaseToComplete() + + stats := pool.Stat() + assert.Equal(t, 0, stats.AcquiredConns()) + assert.Equal(t, 1, stats.TotalConns()) +} + +func TestPoolQueryRowEx(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + testQueryRowEx(t, pool) + waitForReleaseToComplete() + + stats := pool.Stat() + assert.Equal(t, 0, stats.AcquiredConns()) + assert.Equal(t, 1, stats.TotalConns()) +} + +func TestConnReleaseRollsBackFailedTransaction(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + pool, err := pool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + + pid := c.Conn().PID() + + assert.Equal(t, byte('I'), c.Conn().TxStatus()) + + _, err = c.Exec(ctx, "begin") + assert.NoError(t, err) + + assert.Equal(t, byte('T'), c.Conn().TxStatus()) + + _, err = c.Exec(ctx, "selct") + assert.Error(t, err) + + assert.Equal(t, byte('E'), c.Conn().TxStatus()) + + c.Release() + waitForReleaseToComplete() + + c, err = pool.Acquire(ctx) + require.NoError(t, err) + + assert.Equal(t, pid, c.Conn().PID()) + assert.Equal(t, byte('I'), c.Conn().TxStatus()) + + c.Release() +} + +func TestConnReleaseRollsBackInTransaction(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + pool, err := pool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + + pid := c.Conn().PID() + + assert.Equal(t, byte('I'), c.Conn().TxStatus()) + + _, err = c.Exec(ctx, "begin") + assert.NoError(t, err) + + assert.Equal(t, byte('T'), c.Conn().TxStatus()) + + c.Release() + waitForReleaseToComplete() + + c, err = pool.Acquire(ctx) + require.NoError(t, err) + + assert.Equal(t, pid, c.Conn().PID()) + assert.Equal(t, byte('I'), c.Conn().TxStatus()) + + c.Release() +} + +func TestConnReleaseDestroysClosedConn(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + pool, err := pool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + + c.Conn().Close() + + assert.Equal(t, 1, pool.Stat().TotalConns()) + + c.Release() + waitForReleaseToComplete() + + assert.Equal(t, 0, pool.Stat().TotalConns()) +} diff --git a/pool/rows.go b/pool/rows.go new file mode 100644 index 00000000..340ea54e --- /dev/null +++ b/pool/rows.go @@ -0,0 +1,76 @@ +package pool + +import ( + "github.com/jackc/pgx" +) + +type Rows struct { + r *pgx.Rows + c *Conn + err error +} + +func (rows *Rows) Close() { + rows.r.Close() + if rows.c != nil { + rows.c.Release() + rows.c = nil + } +} + +func (rows *Rows) Err() error { + if rows.err != nil { + return rows.err + } + return rows.r.Err() +} + +func (rows *Rows) FieldDescriptions() []pgx.FieldDescription { + return rows.r.FieldDescriptions() +} + +func (rows *Rows) Next() bool { + if rows.err != nil { + return false + } + + n := rows.r.Next() + if !n { + rows.Close() + } + return n +} + +func (rows *Rows) Scan(dest ...interface{}) error { + err := rows.r.Scan(dest...) + if err != nil { + rows.Close() + } + return err +} + +func (rows *Rows) Values() ([]interface{}, error) { + values, err := rows.r.Values() + if err != nil { + rows.Close() + } + return values, err +} + +type Row struct { + r *pgx.Row + c *Conn + err error +} + +func (row *Row) Scan(dest ...interface{}) error { + if row.err != nil { + return row.err + } + + err := row.r.Scan(dest...) + if row.c != nil { + row.c.Release() + } + return err +} diff --git a/pool/stat.go b/pool/stat.go new file mode 100644 index 00000000..186eefd5 --- /dev/null +++ b/pool/stat.go @@ -0,0 +1,47 @@ +package pool + +import ( + "time" + + "github.com/jackc/puddle" +) + +type Stat struct { + s *puddle.Stat +} + +func (s *Stat) AcquireCount() int64 { + return s.s.AcquireCount() +} + +func (s *Stat) AcquireDuration() time.Duration { + return s.s.AcquireDuration() +} + +func (s *Stat) AcquiredConns() int { + return s.s.AcquiredResources() +} + +func (s *Stat) CanceledAcquireCount() int64 { + return s.s.CanceledAcquireCount() +} + +func (s *Stat) ConstructingConns() int { + return s.s.ConstructingResources() +} + +func (s *Stat) EmptyAcquireCount() int64 { + return s.s.EmptyAcquireCount() +} + +func (s *Stat) IdleConns() int { + return s.s.IdleResources() +} + +func (s *Stat) MaxConns() int { + return s.s.MaxResources() +} + +func (s *Stat) TotalConns() int { + return s.s.TotalResources() +} diff --git a/pool/todo.txt b/pool/todo.txt new file mode 100644 index 00000000..10dc9667 --- /dev/null +++ b/pool/todo.txt @@ -0,0 +1,8 @@ +func (p *ConnPool) Begin() (*Tx, error) +func (p *ConnPool) BeginBatch() *Batch +func (p *ConnPool) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) +func (p *ConnPool) Close() +func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) +func (p *ConnPool) Deallocate(name string) (err error) +func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) +func (p *ConnPool) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) diff --git a/pool/tx.go b/pool/tx.go new file mode 100644 index 00000000..cef15ea8 --- /dev/null +++ b/pool/tx.go @@ -0,0 +1,55 @@ +package pool + +import ( + "context" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx" +) + +type Tx struct { + t *pgx.Tx + c *Conn +} + +func (tx *Tx) Commit() error { + err := tx.t.Commit() + if tx.c != nil { + tx.c.Release() + tx.c = nil + } + return err +} + +func (tx *Tx) Rollback() error { + err := tx.t.Rollback() + if tx.c != nil { + tx.c.Release() + tx.c = nil + } + return err +} + +func (tx *Tx) Err() error { + return tx.t.Err() +} + +func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { + return tx.c.Exec(ctx, sql, arguments...) +} + +func (tx *Tx) Query(sql string, args ...interface{}) (*Rows, error) { + return tx.c.Query(sql, args...) +} + +func (tx *Tx) QueryEx(ctx context.Context, sql string, options *pgx.QueryExOptions, args ...interface{}) (*Rows, error) { + return tx.c.QueryEx(ctx, sql, options, args...) +} + +func (tx *Tx) QueryRow(sql string, args ...interface{}) *Row { + return tx.c.QueryRow(sql, args...) +} + +func (tx *Tx) QueryRowEx(ctx context.Context, sql string, options *pgx.QueryExOptions, args ...interface{}) *Row { + return tx.c.QueryRowEx(ctx, sql, options, args...) +} diff --git a/pool/tx_test.go b/pool/tx_test.go new file mode 100644 index 00000000..518ba196 --- /dev/null +++ b/pool/tx_test.go @@ -0,0 +1,70 @@ +package pool_test + +import ( + "context" + "os" + "testing" + + "github.com/jackc/pgx/pool" + "github.com/stretchr/testify/require" +) + +func TestTxExec(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + tx, err := pool.Begin() + require.NoError(t, err) + defer tx.Rollback() + + testExec(t, tx) +} + +func TestTxQuery(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + tx, err := pool.Begin() + require.NoError(t, err) + defer tx.Rollback() + + testQuery(t, tx) +} + +func TestTxQueryEx(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + tx, err := pool.Begin() + require.NoError(t, err) + defer tx.Rollback() + + testQueryEx(t, tx) +} + +func TestTxQueryRow(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + tx, err := pool.Begin() + require.NoError(t, err) + defer tx.Rollback() + + testQueryRow(t, tx) +} + +func TestTxQueryRowEx(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + tx, err := pool.Begin() + require.NoError(t, err) + defer tx.Rollback() + + testQueryRowEx(t, tx) +} diff --git a/private_test.go b/private_test.go deleted file mode 100644 index dd76b43e..00000000 --- a/private_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package pgx - -// This file contains methods that expose internal pgx state to tests. - -func (c *Conn) TxStatus() byte { - return c.pgConn.TxStatus -}