From 5f6636d0286baf948c81d175c25ea8810c4f02d7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 29 May 2023 11:15:40 -0500 Subject: [PATCH] Add context timeouts for more pgxpool tests --- pgxpool/common_test.go | 24 +++++++++--------- pgxpool/conn_test.go | 46 +++++++++++++++++++++++----------- pgxpool/pool_test.go | 12 ++++----- pgxpool/tx_test.go | 56 +++++++++++++++++++++++++++--------------- 4 files changed, 85 insertions(+), 53 deletions(-) diff --git a/pgxpool/common_test.go b/pgxpool/common_test.go index 16f4f553..b2797027 100644 --- a/pgxpool/common_test.go +++ b/pgxpool/common_test.go @@ -24,8 +24,8 @@ type execer interface { Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) } -func testExec(t *testing.T, db execer) { - results, err := db.Exec(context.Background(), "set time zone 'America/Chicago'") +func testExec(t *testing.T, ctx context.Context, db execer) { + results, err := db.Exec(ctx, "set time zone 'America/Chicago'") require.NoError(t, err) assert.EqualValues(t, "SET", results.String()) } @@ -34,10 +34,10 @@ type queryer interface { Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) } -func testQuery(t *testing.T, db queryer) { +func testQuery(t *testing.T, ctx context.Context, db queryer) { var sum, rowCount int32 - rows, err := db.Query(context.Background(), "select generate_series(1,$1)", 10) + rows, err := db.Query(ctx, "select generate_series(1,$1)", 10) require.NoError(t, err) for rows.Next() { @@ -56,9 +56,9 @@ type queryRower interface { QueryRow(ctx context.Context, sql string, args ...any) pgx.Row } -func testQueryRow(t *testing.T, db queryRower) { +func testQueryRow(t *testing.T, ctx context.Context, db queryRower) { var what, who string - err := db.QueryRow(context.Background(), "select 'hello', $1::text", "world").Scan(&what, &who) + err := db.QueryRow(ctx, "select 'hello', $1::text", "world").Scan(&what, &who) assert.NoError(t, err) assert.Equal(t, "hello", what) assert.Equal(t, "world", who) @@ -68,12 +68,12 @@ type sendBatcher interface { SendBatch(context.Context, *pgx.Batch) pgx.BatchResults } -func testSendBatch(t *testing.T, db sendBatcher) { +func testSendBatch(t *testing.T, ctx context.Context, db sendBatcher) { batch := &pgx.Batch{} batch.Queue("select 1") batch.Queue("select 2") - br := db.SendBatch(context.Background(), batch) + br := db.SendBatch(ctx, batch) var err error var n int32 @@ -93,12 +93,12 @@ type copyFromer interface { CopyFrom(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error) } -func testCopyFrom(t *testing.T, db interface { +func testCopyFrom(t *testing.T, ctx context.Context, db interface { execer queryer copyFromer }) { - _, err := db.Exec(context.Background(), `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`) + _, err := db.Exec(ctx, `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`) require.NoError(t, err) tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) @@ -108,11 +108,11 @@ func testCopyFrom(t *testing.T, db interface { {nil, nil, nil, nil, nil, nil, nil}, } - copyCount, err := db.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) + copyCount, err := db.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) assert.NoError(t, err) assert.EqualValues(t, len(inputRows), copyCount) - rows, err := db.Query(context.Background(), "select * from foo") + rows, err := db.Query(ctx, "select * from foo") assert.NoError(t, err) var outputRows [][]any diff --git a/pgxpool/conn_test.go b/pgxpool/conn_test.go index 175981b7..b982588c 100644 --- a/pgxpool/conn_test.go +++ b/pgxpool/conn_test.go @@ -4,6 +4,7 @@ import ( "context" "os" "testing" + "time" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/require" @@ -12,69 +13,84 @@ import ( func TestConnExec(t *testing.T) { t.Parallel() - pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() - c, err := pool.Acquire(context.Background()) + c, err := pool.Acquire(ctx) require.NoError(t, err) defer c.Release() - testExec(t, c) + testExec(t, ctx, c) } func TestConnQuery(t *testing.T) { t.Parallel() - pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() - c, err := pool.Acquire(context.Background()) + c, err := pool.Acquire(ctx) require.NoError(t, err) defer c.Release() - testQuery(t, c) + testQuery(t, ctx, c) } func TestConnQueryRow(t *testing.T) { t.Parallel() - pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() - c, err := pool.Acquire(context.Background()) + c, err := pool.Acquire(ctx) require.NoError(t, err) defer c.Release() - testQueryRow(t, c) + testQueryRow(t, ctx, c) } func TestConnSendBatch(t *testing.T) { t.Parallel() - pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() - c, err := pool.Acquire(context.Background()) + c, err := pool.Acquire(ctx) require.NoError(t, err) defer c.Release() - testSendBatch(t, c) + testSendBatch(t, ctx, c) } func TestConnCopyFrom(t *testing.T) { t.Parallel() - pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() - c, err := pool.Acquire(context.Background()) + c, err := pool.Acquire(ctx) require.NoError(t, err) defer c.Release() - testCopyFrom(t, c) + testCopyFrom(t, ctx, c) } diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index ae894d63..30f742cd 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -630,7 +630,7 @@ func TestPoolExec(t *testing.T) { require.NoError(t, err) defer pool.Close() - testExec(t, pool) + testExec(t, ctx, pool) } func TestPoolQuery(t *testing.T) { @@ -644,7 +644,7 @@ func TestPoolQuery(t *testing.T) { defer pool.Close() // Test common usage - testQuery(t, pool) + testQuery(t, ctx, pool) waitForReleaseToComplete() // Test expected pool behavior @@ -675,7 +675,7 @@ func TestPoolQueryRow(t *testing.T) { require.NoError(t, err) defer pool.Close() - testQueryRow(t, pool) + testQueryRow(t, ctx, pool) waitForReleaseToComplete() stats := pool.Stat() @@ -708,7 +708,7 @@ func TestPoolSendBatch(t *testing.T) { require.NoError(t, err) defer pool.Close() - testSendBatch(t, pool) + testSendBatch(t, ctx, pool) waitForReleaseToComplete() stats := pool.Stat() @@ -883,8 +883,8 @@ func TestConnPoolQueryConcurrentLoad(t *testing.T) { for i := 0; i < n; i++ { go func() { defer func() { done <- true }() - testQuery(t, pool) - testQueryRow(t, pool) + testQuery(t, ctx, pool) + testQueryRow(t, ctx, pool) }() } diff --git a/pgxpool/tx_test.go b/pgxpool/tx_test.go index 8e140bf5..a65b30f8 100644 --- a/pgxpool/tx_test.go +++ b/pgxpool/tx_test.go @@ -4,6 +4,7 @@ import ( "context" "os" "testing" + "time" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/require" @@ -12,69 +13,84 @@ import ( func TestTxExec(t *testing.T) { t.Parallel() - pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() - tx, err := pool.Begin(context.Background()) + tx, err := pool.Begin(ctx) require.NoError(t, err) - defer tx.Rollback(context.Background()) + defer tx.Rollback(ctx) - testExec(t, tx) + testExec(t, ctx, tx) } func TestTxQuery(t *testing.T) { t.Parallel() - pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() - tx, err := pool.Begin(context.Background()) + tx, err := pool.Begin(ctx) require.NoError(t, err) - defer tx.Rollback(context.Background()) + defer tx.Rollback(ctx) - testQuery(t, tx) + testQuery(t, ctx, tx) } func TestTxQueryRow(t *testing.T) { t.Parallel() - pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() - tx, err := pool.Begin(context.Background()) + tx, err := pool.Begin(ctx) require.NoError(t, err) - defer tx.Rollback(context.Background()) + defer tx.Rollback(ctx) - testQueryRow(t, tx) + testQueryRow(t, ctx, tx) } func TestTxSendBatch(t *testing.T) { t.Parallel() - pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() - tx, err := pool.Begin(context.Background()) + tx, err := pool.Begin(ctx) require.NoError(t, err) - defer tx.Rollback(context.Background()) + defer tx.Rollback(ctx) - testSendBatch(t, tx) + testSendBatch(t, ctx, tx) } func TestTxCopyFrom(t *testing.T) { t.Parallel() - pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() - tx, err := pool.Begin(context.Background()) + tx, err := pool.Begin(ctx) require.NoError(t, err) - defer tx.Rollback(context.Background()) + defer tx.Rollback(ctx) - testCopyFrom(t, tx) + testCopyFrom(t, ctx, tx) }