diff --git a/enum_array_test.go b/enum_array_test.go index 9cc950af..052a813c 100644 --- a/enum_array_test.go +++ b/enum_array_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "context" "reflect" "testing" @@ -10,12 +11,12 @@ import ( func TestEnumArrayTranscode(t *testing.T) { setupConn := testutil.MustConnectPgx(t) - defer testutil.MustClose(t, setupConn) + defer testutil.MustCloseContext(t, setupConn) - if _, err := setupConn.Exec("drop type if exists color"); err != nil { + if _, err := setupConn.Exec(context.Background(), "drop type if exists color"); err != nil { t.Fatal(err) } - if _, err := setupConn.Exec("create type color as enum ('red', 'green', 'blue')"); err != nil { + if _, err := setupConn.Exec(context.Background(), "create type color as enum ('red', 'green', 'blue')"); err != nil { t.Fatal(err) } diff --git a/hstore_array_test.go b/hstore_array_test.go index d629a04b..c8104d28 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "context" "reflect" "testing" @@ -11,7 +12,7 @@ import ( func TestHstoreArrayTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) - defer testutil.MustClose(t, conn) + defer testutil.MustCloseContext(t, conn) text := func(s string) pgtype.Text { return pgtype.Text{String: s, Status: pgtype.Present} @@ -77,7 +78,7 @@ func TestHstoreArrayTranscode(t *testing.T) { } var result pgtype.HstoreArray - err := conn.QueryRow("test", vEncoder).Scan(&result) + err := conn.QueryRow(context.Background(), "test", vEncoder).Scan(&result) if err != nil { t.Errorf("%v: %v", fc.name, err) continue diff --git a/jsonb_test.go b/jsonb_test.go index ab743151..afc51019 100644 --- a/jsonb_test.go +++ b/jsonb_test.go @@ -11,7 +11,7 @@ import ( func TestJSONBTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) - defer testutil.MustClose(t, conn) + defer testutil.MustCloseContext(t, conn) if _, ok := conn.ConnInfo.DataTypeForName("jsonb"); !ok { t.Skip("Skipping due to no jsonb type") } diff --git a/line_test.go b/line_test.go index 200d1d4c..077afe6b 100644 --- a/line_test.go +++ b/line_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "context" "testing" "github.com/jackc/pgx/pgtype" @@ -15,7 +16,7 @@ func TestLineTranscode(t *testing.T) { // line may exist but not be usable on 9.3 :( var isPG93 bool - err := conn.QueryRow("select version() ~ '9.3'").Scan(&isPG93) + err := conn.QueryRow(context.Background(), "select version() ~ '9.3'").Scan(&isPG93) if err != nil { t.Fatal(err) } diff --git a/record_test.go b/record_test.go index 23ec2cd3..44b0e9d8 100644 --- a/record_test.go +++ b/record_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "context" "fmt" "reflect" "testing" @@ -12,7 +13,7 @@ import ( func TestRecordTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) - defer testutil.MustClose(t, conn) + defer testutil.MustCloseContext(t, conn) tests := []struct { sql string @@ -91,7 +92,7 @@ func TestRecordTranscode(t *testing.T) { ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode var result pgtype.Record - if err := conn.QueryRow(psName).Scan(&result); err != nil { + if err := conn.QueryRow(context.Background(), psName).Scan(&result); err != nil { t.Errorf("%d: %v", i, err) continue } @@ -104,9 +105,9 @@ func TestRecordTranscode(t *testing.T) { func TestRecordWithUnknownOID(t *testing.T) { conn := testutil.MustConnectPgx(t) - defer testutil.MustClose(t, conn) + defer testutil.MustCloseContext(t, conn) - _, err := conn.Exec(`drop type if exists floatrange; + _, err := conn.Exec(context.Background(), `drop type if exists floatrange; create type floatrange as range ( subtype = float8, @@ -115,10 +116,10 @@ create type floatrange as range ( if err != nil { t.Fatal(err) } - defer conn.Exec("drop type floatrange") + defer conn.Exec(context.Background(), "drop type floatrange") var result pgtype.Record - err = conn.QueryRow("select row('foo'::text, floatrange(1, 10), 'bar'::text)").Scan(&result) + err = conn.QueryRow(context.Background(), "select row('foo'::text, floatrange(1, 10), 'bar'::text)").Scan(&result) if err == nil { t.Errorf("expected error but none") } diff --git a/testutil/testutil.go b/testutil/testutil.go index 0effb42d..2cde9961 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -34,12 +34,7 @@ func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { } func MustConnectPgx(t testing.TB) *pgx.Conn { - config, err := pgx.ParseConnectionString(os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - t.Fatal(err) - } - - conn, err := pgx.Connect(config) + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { t.Fatal(err) } @@ -56,6 +51,15 @@ func MustClose(t testing.TB, conn interface { } } +func MustCloseContext(t testing.TB, conn interface { + Close(context.Context) error +}) { + err := conn.Close(context.Background()) + if err != nil { + t.Fatal(err) + } +} + type forceTextEncoder struct { e pgtype.TextEncoder } @@ -102,7 +106,7 @@ func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { conn := MustConnectPgx(t) - defer MustClose(t, conn) + defer MustCloseContext(t, conn) ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) if err != nil { @@ -133,7 +137,7 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] } result := reflect.New(reflect.TypeOf(derefV)) - err := conn.QueryRow("test", ForceEncoder(v, fc.formatCode)).Scan(result.Interface()) + err := conn.QueryRow(context.Background(), "test", ForceEncoder(v, fc.formatCode)).Scan(result.Interface()) if err != nil { t.Errorf("%v %d: %v", fc.name, i, err) } @@ -147,7 +151,7 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { conn := MustConnectPgx(t) - defer MustClose(t, conn) + defer MustCloseContext(t, conn) for i, v := range values { // Derefence value if it is a pointer @@ -158,7 +162,7 @@ func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName str } result := reflect.New(reflect.TypeOf(derefV)) - err := conn.QueryRowEx( + err := conn.QueryRow( context.Background(), fmt.Sprintf("select ($1)::%s", pgTypeName), &pgx.QueryExOptions{SimpleProtocol: true}, @@ -223,7 +227,7 @@ func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc f func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { conn := MustConnectPgx(t) - defer MustClose(t, conn) + defer MustCloseContext(t, conn) formats := []struct { name string @@ -254,7 +258,7 @@ func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFun } result := reflect.New(reflect.TypeOf(derefV)) - err = conn.QueryRow(psName).Scan(result.Interface()) + err = conn.QueryRow(context.Background(), psName).Scan(result.Interface()) if err != nil { t.Errorf("%v %d: %v", fc.name, i, err) }