Add context timeouts for more pgxpool tests

pull/1624/head
Jack Christensen 2023-05-29 11:15:40 -05:00
parent a1a97a7ca8
commit 5f6636d028
4 changed files with 85 additions and 53 deletions

View File

@ -24,8 +24,8 @@ type execer interface {
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
} }
func testExec(t *testing.T, db execer) { func testExec(t *testing.T, ctx context.Context, db execer) {
results, err := db.Exec(context.Background(), "set time zone 'America/Chicago'") results, err := db.Exec(ctx, "set time zone 'America/Chicago'")
require.NoError(t, err) require.NoError(t, err)
assert.EqualValues(t, "SET", results.String()) assert.EqualValues(t, "SET", results.String())
} }
@ -34,10 +34,10 @@ type queryer interface {
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
} }
func testQuery(t *testing.T, db queryer) { func testQuery(t *testing.T, ctx context.Context, db queryer) {
var sum, rowCount int32 var sum, rowCount int32
rows, err := db.Query(context.Background(), "select generate_series(1,$1)", 10) rows, err := db.Query(ctx, "select generate_series(1,$1)", 10)
require.NoError(t, err) require.NoError(t, err)
for rows.Next() { for rows.Next() {
@ -56,9 +56,9 @@ type queryRower interface {
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
} }
func testQueryRow(t *testing.T, db queryRower) { func testQueryRow(t *testing.T, ctx context.Context, db queryRower) {
var what, who string var what, who string
err := db.QueryRow(context.Background(), "select 'hello', $1::text", "world").Scan(&what, &who) err := db.QueryRow(ctx, "select 'hello', $1::text", "world").Scan(&what, &who)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "hello", what) assert.Equal(t, "hello", what)
assert.Equal(t, "world", who) assert.Equal(t, "world", who)
@ -68,12 +68,12 @@ type sendBatcher interface {
SendBatch(context.Context, *pgx.Batch) pgx.BatchResults SendBatch(context.Context, *pgx.Batch) pgx.BatchResults
} }
func testSendBatch(t *testing.T, db sendBatcher) { func testSendBatch(t *testing.T, ctx context.Context, db sendBatcher) {
batch := &pgx.Batch{} batch := &pgx.Batch{}
batch.Queue("select 1") batch.Queue("select 1")
batch.Queue("select 2") batch.Queue("select 2")
br := db.SendBatch(context.Background(), batch) br := db.SendBatch(ctx, batch)
var err error var err error
var n int32 var n int32
@ -93,12 +93,12 @@ type copyFromer interface {
CopyFrom(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error) CopyFrom(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error)
} }
func testCopyFrom(t *testing.T, db interface { func testCopyFrom(t *testing.T, ctx context.Context, db interface {
execer execer
queryer queryer
copyFromer copyFromer
}) { }) {
_, err := db.Exec(context.Background(), `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`) _, err := db.Exec(ctx, `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`)
require.NoError(t, err) require.NoError(t, err)
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
@ -108,11 +108,11 @@ func testCopyFrom(t *testing.T, db interface {
{nil, nil, nil, nil, nil, nil, nil}, {nil, nil, nil, nil, nil, nil, nil},
} }
copyCount, err := db.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) copyCount, err := db.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, len(inputRows), copyCount) assert.EqualValues(t, len(inputRows), copyCount)
rows, err := db.Query(context.Background(), "select * from foo") rows, err := db.Query(ctx, "select * from foo")
assert.NoError(t, err) assert.NoError(t, err)
var outputRows [][]any var outputRows [][]any

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"os" "os"
"testing" "testing"
"time"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -12,69 +13,84 @@ import (
func TestConnExec(t *testing.T) { func TestConnExec(t *testing.T) {
t.Parallel() t.Parallel()
pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err) require.NoError(t, err)
defer pool.Close() defer pool.Close()
c, err := pool.Acquire(context.Background()) c, err := pool.Acquire(ctx)
require.NoError(t, err) require.NoError(t, err)
defer c.Release() defer c.Release()
testExec(t, c) testExec(t, ctx, c)
} }
func TestConnQuery(t *testing.T) { func TestConnQuery(t *testing.T) {
t.Parallel() t.Parallel()
pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err) require.NoError(t, err)
defer pool.Close() defer pool.Close()
c, err := pool.Acquire(context.Background()) c, err := pool.Acquire(ctx)
require.NoError(t, err) require.NoError(t, err)
defer c.Release() defer c.Release()
testQuery(t, c) testQuery(t, ctx, c)
} }
func TestConnQueryRow(t *testing.T) { func TestConnQueryRow(t *testing.T) {
t.Parallel() t.Parallel()
pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err) require.NoError(t, err)
defer pool.Close() defer pool.Close()
c, err := pool.Acquire(context.Background()) c, err := pool.Acquire(ctx)
require.NoError(t, err) require.NoError(t, err)
defer c.Release() defer c.Release()
testQueryRow(t, c) testQueryRow(t, ctx, c)
} }
func TestConnSendBatch(t *testing.T) { func TestConnSendBatch(t *testing.T) {
t.Parallel() t.Parallel()
pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err) require.NoError(t, err)
defer pool.Close() defer pool.Close()
c, err := pool.Acquire(context.Background()) c, err := pool.Acquire(ctx)
require.NoError(t, err) require.NoError(t, err)
defer c.Release() defer c.Release()
testSendBatch(t, c) testSendBatch(t, ctx, c)
} }
func TestConnCopyFrom(t *testing.T) { func TestConnCopyFrom(t *testing.T) {
t.Parallel() t.Parallel()
pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err) require.NoError(t, err)
defer pool.Close() defer pool.Close()
c, err := pool.Acquire(context.Background()) c, err := pool.Acquire(ctx)
require.NoError(t, err) require.NoError(t, err)
defer c.Release() defer c.Release()
testCopyFrom(t, c) testCopyFrom(t, ctx, c)
} }

View File

@ -630,7 +630,7 @@ func TestPoolExec(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer pool.Close() defer pool.Close()
testExec(t, pool) testExec(t, ctx, pool)
} }
func TestPoolQuery(t *testing.T) { func TestPoolQuery(t *testing.T) {
@ -644,7 +644,7 @@ func TestPoolQuery(t *testing.T) {
defer pool.Close() defer pool.Close()
// Test common usage // Test common usage
testQuery(t, pool) testQuery(t, ctx, pool)
waitForReleaseToComplete() waitForReleaseToComplete()
// Test expected pool behavior // Test expected pool behavior
@ -675,7 +675,7 @@ func TestPoolQueryRow(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer pool.Close() defer pool.Close()
testQueryRow(t, pool) testQueryRow(t, ctx, pool)
waitForReleaseToComplete() waitForReleaseToComplete()
stats := pool.Stat() stats := pool.Stat()
@ -708,7 +708,7 @@ func TestPoolSendBatch(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer pool.Close() defer pool.Close()
testSendBatch(t, pool) testSendBatch(t, ctx, pool)
waitForReleaseToComplete() waitForReleaseToComplete()
stats := pool.Stat() stats := pool.Stat()
@ -883,8 +883,8 @@ func TestConnPoolQueryConcurrentLoad(t *testing.T) {
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
go func() { go func() {
defer func() { done <- true }() defer func() { done <- true }()
testQuery(t, pool) testQuery(t, ctx, pool)
testQueryRow(t, pool) testQueryRow(t, ctx, pool)
}() }()
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"os" "os"
"testing" "testing"
"time"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -12,69 +13,84 @@ import (
func TestTxExec(t *testing.T) { func TestTxExec(t *testing.T) {
t.Parallel() t.Parallel()
pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err) require.NoError(t, err)
defer pool.Close() defer pool.Close()
tx, err := pool.Begin(context.Background()) tx, err := pool.Begin(ctx)
require.NoError(t, err) require.NoError(t, err)
defer tx.Rollback(context.Background()) defer tx.Rollback(ctx)
testExec(t, tx) testExec(t, ctx, tx)
} }
func TestTxQuery(t *testing.T) { func TestTxQuery(t *testing.T) {
t.Parallel() t.Parallel()
pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err) require.NoError(t, err)
defer pool.Close() defer pool.Close()
tx, err := pool.Begin(context.Background()) tx, err := pool.Begin(ctx)
require.NoError(t, err) require.NoError(t, err)
defer tx.Rollback(context.Background()) defer tx.Rollback(ctx)
testQuery(t, tx) testQuery(t, ctx, tx)
} }
func TestTxQueryRow(t *testing.T) { func TestTxQueryRow(t *testing.T) {
t.Parallel() t.Parallel()
pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err) require.NoError(t, err)
defer pool.Close() defer pool.Close()
tx, err := pool.Begin(context.Background()) tx, err := pool.Begin(ctx)
require.NoError(t, err) require.NoError(t, err)
defer tx.Rollback(context.Background()) defer tx.Rollback(ctx)
testQueryRow(t, tx) testQueryRow(t, ctx, tx)
} }
func TestTxSendBatch(t *testing.T) { func TestTxSendBatch(t *testing.T) {
t.Parallel() t.Parallel()
pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err) require.NoError(t, err)
defer pool.Close() defer pool.Close()
tx, err := pool.Begin(context.Background()) tx, err := pool.Begin(ctx)
require.NoError(t, err) require.NoError(t, err)
defer tx.Rollback(context.Background()) defer tx.Rollback(ctx)
testSendBatch(t, tx) testSendBatch(t, ctx, tx)
} }
func TestTxCopyFrom(t *testing.T) { func TestTxCopyFrom(t *testing.T) {
t.Parallel() t.Parallel()
pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err) require.NoError(t, err)
defer pool.Close() defer pool.Close()
tx, err := pool.Begin(context.Background()) tx, err := pool.Begin(ctx)
require.NoError(t, err) require.NoError(t, err)
defer tx.Rollback(context.Background()) defer tx.Rollback(ctx)
testCopyFrom(t, tx) testCopyFrom(t, ctx, tx)
} }