Replace t.Fatal with require where possible

pull/745/head
Jack Christensen 2020-05-08 12:47:47 -05:00
parent ec53234e86
commit c3381c6911
1 changed files with 115 additions and 329 deletions

View File

@ -21,18 +21,13 @@ import (
func openDB(t testing.TB) *sql.DB { func openDB(t testing.TB) *sql.DB {
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
if err != nil { require.NoError(t, err)
t.Fatalf("pgx.ParseConnectionString failed: %v", err)
}
return stdlib.OpenDB(*config) return stdlib.OpenDB(*config)
} }
func closeDB(t testing.TB, db *sql.DB) { func closeDB(t testing.TB, db *sql.DB) {
err := db.Close() err := db.Close()
if err != nil { require.NoError(t, err)
t.Fatalf("db.Close unexpectedly failed: %v", err)
}
} }
// Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should // Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should
@ -41,9 +36,7 @@ func ensureDBValid(t testing.TB, db *sql.DB) {
var sum, rowCount int32 var sum, rowCount int32
rows, err := db.Query("select generate_series(1,$1)", 10) rows, err := db.Query("select generate_series(1,$1)", 10)
if err != nil { require.NoError(t, err)
t.Fatalf("db.Query failed: %v", err)
}
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
@ -53,9 +46,7 @@ func ensureDBValid(t testing.TB, db *sql.DB) {
rowCount++ rowCount++
} }
if rows.Err() != nil { require.NoError(t, rows.Err())
t.Fatalf("db.Query failed: %v", err)
}
if rowCount != 10 { if rowCount != 10 {
t.Error("Select called onDataRow wrong number of times") t.Error("Select called onDataRow wrong number of times")
@ -71,25 +62,18 @@ type preparer interface {
func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt { func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt {
stmt, err := p.Prepare(sql) stmt, err := p.Prepare(sql)
if err != nil { require.NoError(t, err)
t.Fatalf("%v Prepare unexpectedly failed: %v", p, err)
}
return stmt return stmt
} }
func closeStmt(t *testing.T, stmt *sql.Stmt) { func closeStmt(t *testing.T, stmt *sql.Stmt) {
err := stmt.Close() err := stmt.Close()
if err != nil { require.NoError(t, err)
t.Fatalf("stmt.Close unexpectedly failed: %v", err)
}
} }
func TestSQLOpen(t *testing.T) { func TestSQLOpen(t *testing.T) {
db, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE")) db, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE"))
if err != nil { require.NoError(t, err)
t.Fatalf("sql.Open failed: %v", err)
}
closeDB(t, db) closeDB(t, db)
} }
@ -101,9 +85,7 @@ func TestNormalLifeCycle(t *testing.T) {
defer closeStmt(t, stmt) defer closeStmt(t, stmt)
rows, err := stmt.Query(int32(1), int32(10)) rows, err := stmt.Query(int32(1), int32(10))
if err != nil { require.NoError(t, err)
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
}
rowCount := int64(0) rowCount := int64(0)
@ -112,9 +94,9 @@ func TestNormalLifeCycle(t *testing.T) {
var s string var s string
var n int64 var n int64
if err := rows.Scan(&s, &n); err != nil { err := rows.Scan(&s, &n)
t.Fatalf("rows.Scan unexpectedly failed: %v", err) require.NoError(t, err)
}
if s != "foo" { if s != "foo" {
t.Errorf(`Expected "foo", received "%v"`, s) t.Errorf(`Expected "foo", received "%v"`, s)
} }
@ -122,18 +104,12 @@ func TestNormalLifeCycle(t *testing.T) {
t.Errorf("Expected %d, received %d", rowCount, n) t.Errorf("Expected %d, received %d", rowCount, n)
} }
} }
err = rows.Err() require.NoError(t, rows.Err())
if err != nil {
t.Fatalf("rows.Err unexpectedly is: %v", err) require.EqualValues(t, 10, rowCount)
}
if rowCount != 10 {
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
}
err = rows.Close() err = rows.Close()
if err != nil { require.NoError(t, err)
t.Fatalf("rows.Close unexpectedly failed: %v", err)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -143,36 +119,22 @@ func TestStmtExec(t *testing.T) {
defer closeDB(t, db) defer closeDB(t, db)
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { require.NoError(t, err)
t.Fatalf("db.Begin unexpectedly failed: %v", err)
}
createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)") createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)")
_, err = createStmt.Exec() _, err = createStmt.Exec()
if err != nil { require.NoError(t, err)
t.Fatalf("stmt.Exec unexpectedly failed: %v", err)
}
closeStmt(t, createStmt) closeStmt(t, createStmt)
insertStmt := prepareStmt(t, tx, "insert into t values($1::text)") insertStmt := prepareStmt(t, tx, "insert into t values($1::text)")
result, err := insertStmt.Exec("foo") result, err := insertStmt.Exec("foo")
if err != nil { require.NoError(t, err)
t.Fatalf("stmt.Exec unexpectedly failed: %v", err)
}
n, err := result.RowsAffected() n, err := result.RowsAffected()
if err != nil { require.NoError(t, err)
t.Fatalf("result.RowsAffected unexpectedly failed: %v", err) require.EqualValues(t, 1, n)
}
if n != 1 {
t.Fatalf("Expected 1, received %d", n)
}
closeStmt(t, insertStmt) closeStmt(t, insertStmt)
if err != nil {
t.Fatalf("tx.Commit unexpectedly failed: %v", err)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -184,21 +146,15 @@ func TestQueryCloseRowsEarly(t *testing.T) {
defer closeStmt(t, stmt) defer closeStmt(t, stmt)
rows, err := stmt.Query(int32(1), int32(10)) rows, err := stmt.Query(int32(1), int32(10))
if err != nil { require.NoError(t, err)
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
}
// Close rows immediately without having read them // Close rows immediately without having read them
err = rows.Close() err = rows.Close()
if err != nil { require.NoError(t, err)
t.Fatalf("rows.Close unexpectedly failed: %v", err)
}
// Run the query again to ensure the connection and statement are still ok // Run the query again to ensure the connection and statement are still ok
rows, err = stmt.Query(int32(1), int32(10)) rows, err = stmt.Query(int32(1), int32(10))
if err != nil { require.NoError(t, err)
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
}
rowCount := int64(0) rowCount := int64(0)
@ -207,9 +163,8 @@ func TestQueryCloseRowsEarly(t *testing.T) {
var s string var s string
var n int64 var n int64
if err := rows.Scan(&s, &n); err != nil { err := rows.Scan(&s, &n)
t.Fatalf("rows.Scan unexpectedly failed: %v", err) require.NoError(t, err)
}
if s != "foo" { if s != "foo" {
t.Errorf(`Expected "foo", received "%v"`, s) t.Errorf(`Expected "foo", received "%v"`, s)
} }
@ -217,18 +172,11 @@ func TestQueryCloseRowsEarly(t *testing.T) {
t.Errorf("Expected %d, received %d", rowCount, n) t.Errorf("Expected %d, received %d", rowCount, n)
} }
} }
err = rows.Err() require.NoError(t, rows.Err())
if err != nil { require.EqualValues(t, 10, rowCount)
t.Fatalf("rows.Err unexpectedly is: %v", err)
}
if rowCount != 10 {
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
}
err = rows.Close() err = rows.Close()
if err != nil { require.NoError(t, err)
t.Fatalf("rows.Close unexpectedly failed: %v", err)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -238,22 +186,14 @@ func TestConnExec(t *testing.T) {
defer closeDB(t, db) defer closeDB(t, db)
_, err := db.Exec("create temporary table t(a varchar not null)") _, err := db.Exec("create temporary table t(a varchar not null)")
if err != nil { require.NoError(t, err)
t.Fatalf("db.Exec unexpectedly failed: %v", err)
}
result, err := db.Exec("insert into t values('hey')") result, err := db.Exec("insert into t values('hey')")
if err != nil { require.NoError(t, err)
t.Fatalf("db.Exec unexpectedly failed: %v", err)
}
n, err := result.RowsAffected() n, err := result.RowsAffected()
if err != nil { require.NoError(t, err)
t.Fatalf("result.RowsAffected unexpectedly failed: %v", err) require.EqualValues(t, 1, n)
}
if n != 1 {
t.Fatalf("Expected 1, received %d", n)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -263,9 +203,7 @@ func TestConnQuery(t *testing.T) {
defer closeDB(t, db) defer closeDB(t, db)
rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10)) rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10))
if err != nil { require.NoError(t, err)
t.Fatalf("db.Query unexpectedly failed: %v", err)
}
rowCount := int64(0) rowCount := int64(0)
@ -274,9 +212,8 @@ func TestConnQuery(t *testing.T) {
var s string var s string
var n int64 var n int64
if err := rows.Scan(&s, &n); err != nil { err := rows.Scan(&s, &n)
t.Fatalf("rows.Scan unexpectedly failed: %v", err) require.NoError(t, err)
}
if s != "foo" { if s != "foo" {
t.Errorf(`Expected "foo", received "%v"`, s) t.Errorf(`Expected "foo", received "%v"`, s)
} }
@ -284,18 +221,11 @@ func TestConnQuery(t *testing.T) {
t.Errorf("Expected %d, received %d", rowCount, n) t.Errorf("Expected %d, received %d", rowCount, n)
} }
} }
err = rows.Err() require.NoError(t, rows.Err())
if err != nil { require.EqualValues(t, 10, rowCount)
t.Fatalf("rows.Err unexpectedly is: %v", err)
}
if rowCount != 10 {
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
}
err = rows.Close() err = rows.Close()
if err != nil { require.NoError(t, err)
t.Fatalf("rows.Close unexpectedly failed: %v", err)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -305,9 +235,7 @@ func TestConnQueryNull(t *testing.T) {
defer closeDB(t, db) defer closeDB(t, db)
rows, err := db.Query("select $1::int", nil) rows, err := db.Query("select $1::int", nil)
if err != nil { require.NoError(t, err)
t.Fatalf("db.Query unexpectedly failed: %v", err)
}
rowCount := int64(0) rowCount := int64(0)
@ -315,25 +243,17 @@ func TestConnQueryNull(t *testing.T) {
rowCount++ rowCount++
var n sql.NullInt64 var n sql.NullInt64
if err := rows.Scan(&n); err != nil { err := rows.Scan(&n)
t.Fatalf("rows.Scan unexpectedly failed: %v", err) require.NoError(t, err)
}
if n.Valid != false { if n.Valid != false {
t.Errorf("Expected n to be null, but it was %v", n) t.Errorf("Expected n to be null, but it was %v", n)
} }
} }
err = rows.Err() require.NoError(t, rows.Err())
if err != nil { require.EqualValues(t, 1, rowCount)
t.Fatalf("rows.Err unexpectedly is: %v", err)
}
if rowCount != 1 {
t.Fatalf("Expected to receive 11 rows, instead received %d", rowCount)
}
err = rows.Close() err = rows.Close()
if err != nil { require.NoError(t, err)
t.Fatalf("rows.Close unexpectedly failed: %v", err)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -346,13 +266,8 @@ func TestConnQueryRowByteSlice(t *testing.T) {
var actual []byte var actual []byte
err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual) err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual)
if err != nil { require.NoError(t, err)
t.Fatalf("db.QueryRow unexpectedly failed: %v", err) require.EqualValues(t, expected, actual)
}
if bytes.Compare(actual, expected) != 0 {
t.Fatalf("Expected %v, but got %v", expected, actual)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -362,9 +277,8 @@ func TestConnQueryFailure(t *testing.T) {
defer closeDB(t, db) defer closeDB(t, db)
_, err := db.Query("select 'foo") _, err := db.Query("select 'foo")
if _, ok := err.(*pgconn.PgError); !ok { require.Error(t, err)
t.Fatalf("Expected db.Query to return pgconn.PgError, but instead received: %v", err) require.IsType(t, new(pgconn.PgError), err)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -380,13 +294,8 @@ func TestConnQueryRowPgxBinary(t *testing.T) {
var actual string var actual string
err := db.QueryRow(sql, expected).Scan(&actual) err := db.QueryRow(sql, expected).Scan(&actual)
if err != nil { require.NoError(t, err)
t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) require.EqualValues(t, expected, actual)
}
if actual != expected {
t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -400,13 +309,8 @@ func TestConnQueryRowUnknownType(t *testing.T) {
var actual string var actual string
err := db.QueryRow(sql, expected).Scan(&actual) err := db.QueryRow(sql, expected).Scan(&actual)
if err != nil { require.NoError(t, err)
t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) require.EqualValues(t, expected, actual)
}
if actual != expected {
t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -422,9 +326,7 @@ func TestConnQueryJSONIntoByteSlice(t *testing.T) {
insert into docs(body) values('{"foo":"bar"}'); insert into docs(body) values('{"foo":"bar"}');
`) `)
if err != nil { require.NoError(t, err)
t.Fatalf("db.Exec unexpectedly failed: %v", err)
}
sql := `select * from docs` sql := `select * from docs`
expected := []byte(`{"foo":"bar"}`) expected := []byte(`{"foo":"bar"}`)
@ -440,9 +342,7 @@ func TestConnQueryJSONIntoByteSlice(t *testing.T) {
} }
_, err = db.Exec(`drop table docs`) _, err = db.Exec(`drop table docs`)
if err != nil { require.NoError(t, err)
t.Fatalf("db.Exec unexpectedly failed: %v", err)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -456,31 +356,23 @@ func TestConnExecInsertByteSliceIntoJSON(t *testing.T) {
body json not null body json not null
); );
`) `)
if err != nil { require.NoError(t, err)
t.Fatalf("db.Exec unexpectedly failed: %v", err)
}
expected := []byte(`{"foo":"bar"}`) expected := []byte(`{"foo":"bar"}`)
_, err = db.Exec(`insert into docs(body) values($1)`, expected) _, err = db.Exec(`insert into docs(body) values($1)`, expected)
if err != nil { require.NoError(t, err)
t.Fatalf("db.Exec unexpectedly failed: %v", err)
}
var actual []byte var actual []byte
err = db.QueryRow(`select body from docs`).Scan(&actual) err = db.QueryRow(`select body from docs`).Scan(&actual)
if err != nil { require.NoError(t, err)
t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
}
if bytes.Compare(actual, expected) != 0 { if bytes.Compare(actual, expected) != 0 {
t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual)) t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual))
} }
_, err = db.Exec(`drop table docs`) _, err = db.Exec(`drop table docs`)
if err != nil { require.NoError(t, err)
t.Fatalf("db.Exec unexpectedly failed: %v", err)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -490,56 +382,34 @@ func TestTransactionLifeCycle(t *testing.T) {
defer closeDB(t, db) defer closeDB(t, db)
_, err := db.Exec("create temporary table t(a varchar not null)") _, err := db.Exec("create temporary table t(a varchar not null)")
if err != nil { require.NoError(t, err)
t.Fatalf("db.Exec unexpectedly failed: %v", err)
}
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { require.NoError(t, err)
t.Fatalf("db.Begin unexpectedly failed: %v", err)
}
_, err = tx.Exec("insert into t values('hi')") _, err = tx.Exec("insert into t values('hi')")
if err != nil { require.NoError(t, err)
t.Fatalf("tx.Exec unexpectedly failed: %v", err)
}
err = tx.Rollback() err = tx.Rollback()
if err != nil { require.NoError(t, err)
t.Fatalf("tx.Rollback unexpectedly failed: %v", err)
}
var n int64 var n int64
err = db.QueryRow("select count(*) from t").Scan(&n) err = db.QueryRow("select count(*) from t").Scan(&n)
if err != nil { require.NoError(t, err)
t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err) require.EqualValues(t, 0, n)
}
if n != 0 {
t.Fatalf("Expected 0 rows due to rollback, instead found %d", n)
}
tx, err = db.Begin() tx, err = db.Begin()
if err != nil { require.NoError(t, err)
t.Fatalf("db.Begin unexpectedly failed: %v", err)
}
_, err = tx.Exec("insert into t values('hi')") _, err = tx.Exec("insert into t values('hi')")
if err != nil { require.NoError(t, err)
t.Fatalf("tx.Exec unexpectedly failed: %v", err)
}
err = tx.Commit() err = tx.Commit()
if err != nil { require.NoError(t, err)
t.Fatalf("tx.Commit unexpectedly failed: %v", err)
}
err = db.QueryRow("select count(*) from t").Scan(&n) err = db.QueryRow("select count(*) from t").Scan(&n)
if err != nil { require.NoError(t, err)
t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err) require.EqualValues(t, 1, n)
}
if n != 1 {
t.Fatalf("Expected 1 rows due to rollback, instead found %d", n)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -550,9 +420,7 @@ func TestConnBeginTxIsolation(t *testing.T) {
var defaultIsoLevel string var defaultIsoLevel string
err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel) err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel)
if err != nil { require.NoError(t, err)
t.Fatalf("QueryRow failed: %v", err)
}
supportedTests := []struct { supportedTests := []struct {
sqlIso sql.IsolationLevel sqlIso sql.IsolationLevel
@ -608,9 +476,7 @@ func TestConnBeginTxReadOnly(t *testing.T) {
defer closeDB(t, db) defer closeDB(t, db)
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
if err != nil { require.NoError(t, err)
t.Fatalf("BeginTx failed: %v", err)
}
defer tx.Rollback() defer tx.Rollback()
var pgReadOnly string var pgReadOnly string
@ -631,21 +497,15 @@ func TestBeginTxContextCancel(t *testing.T) {
defer closeDB(t, db) defer closeDB(t, db)
_, err := db.Exec("drop table if exists t") _, err := db.Exec("drop table if exists t")
if err != nil { require.NoError(t, err)
t.Fatalf("db.Exec failed: %v", err)
}
ctx, cancelFn := context.WithCancel(context.Background()) ctx, cancelFn := context.WithCancel(context.Background())
tx, err := db.BeginTx(ctx, nil) tx, err := db.BeginTx(ctx, nil)
if err != nil { require.NoError(t, err)
t.Fatalf("BeginTx failed: %v", err)
}
_, err = tx.Exec("create table t(id serial)") _, err = tx.Exec("create table t(id serial)")
if err != nil { require.NoError(t, err)
t.Fatalf("tx.Exec failed: %v", err)
}
cancelFn() cancelFn()
@ -739,9 +599,8 @@ func TestConnPingContextSuccess(t *testing.T) {
db := openDB(t) db := openDB(t)
defer closeDB(t, db) defer closeDB(t, db)
if err := db.PingContext(context.Background()); err != nil { err := db.PingContext(context.Background())
t.Fatalf("db.PingContext failed: %v", err) require.NoError(t, err)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -751,9 +610,7 @@ func TestConnPrepareContextSuccess(t *testing.T) {
defer closeDB(t, db) defer closeDB(t, db)
stmt, err := db.PrepareContext(context.Background(), "select now()") stmt, err := db.PrepareContext(context.Background(), "select now()")
if err != nil { require.NoError(t, err)
t.Fatalf("db.PrepareContext failed: %v", err)
}
stmt.Close() stmt.Close()
ensureDBValid(t, db) ensureDBValid(t, db)
@ -764,9 +621,7 @@ func TestConnExecContextSuccess(t *testing.T) {
defer closeDB(t, db) defer closeDB(t, db)
_, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)") _, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)")
if err != nil { require.NoError(t, err)
t.Fatalf("db.ExecContext failed: %v", err)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -778,19 +633,14 @@ func TestConnExecContextFailureRetry(t *testing.T) {
// we get a connection, immediately close it, and then get it back // we get a connection, immediately close it, and then get it back
{ {
conn, err := stdlib.AcquireConn(db) conn, err := stdlib.AcquireConn(db)
if err != nil { require.NoError(t, err)
t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err)
}
conn.Close(context.Background()) conn.Close(context.Background())
stdlib.ReleaseConn(db, conn) stdlib.ReleaseConn(db, conn)
} }
conn, err := db.Conn(context.Background()) conn, err := db.Conn(context.Background())
if err != nil { require.NoError(t, err)
t.Fatalf("db.Conn unexpectedly failed: %v", err) _, err = conn.ExecContext(context.Background(), "select 1")
} require.EqualValues(t, driver.ErrBadConn, err)
if _, err := conn.ExecContext(context.Background(), "select 1"); err != driver.ErrBadConn {
t.Fatalf("Expected conn.ExecContext to return driver.ErrBadConn, but instead received: %v", err)
}
} }
func TestConnQueryContextSuccess(t *testing.T) { func TestConnQueryContextSuccess(t *testing.T) {
@ -798,20 +648,14 @@ func TestConnQueryContextSuccess(t *testing.T) {
defer closeDB(t, db) defer closeDB(t, db)
rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n") rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n")
if err != nil { require.NoError(t, err)
t.Fatalf("db.QueryContext failed: %v", err)
}
for rows.Next() { for rows.Next() {
var n int64 var n int64
if err := rows.Scan(&n); err != nil { err := rows.Scan(&n)
t.Error(err) require.NoError(t, err)
}
}
if rows.Err() != nil {
t.Error(rows.Err())
} }
require.NoError(t, rows.Err())
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -823,19 +667,15 @@ func TestConnQueryContextFailureRetry(t *testing.T) {
// we get a connection, immediately close it, and then get it back // we get a connection, immediately close it, and then get it back
{ {
conn, err := stdlib.AcquireConn(db) conn, err := stdlib.AcquireConn(db)
if err != nil { require.NoError(t, err)
t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err)
}
conn.Close(context.Background()) conn.Close(context.Background())
stdlib.ReleaseConn(db, conn) stdlib.ReleaseConn(db, conn)
} }
conn, err := db.Conn(context.Background()) conn, err := db.Conn(context.Background())
if err != nil { require.NoError(t, err)
t.Fatalf("db.Conn unexpectedly failed: %v", err)
} _, err = conn.QueryContext(context.Background(), "select 1")
if _, err := conn.QueryContext(context.Background(), "select 1"); err != driver.ErrBadConn { require.EqualValues(t, driver.ErrBadConn, err)
t.Fatalf("Expected conn.QueryContext to return driver.ErrBadConn, but instead received: %v", err)
}
} }
func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { func TestRowsColumnTypeDatabaseTypeName(t *testing.T) {
@ -843,18 +683,11 @@ func TestRowsColumnTypeDatabaseTypeName(t *testing.T) {
defer closeDB(t, db) defer closeDB(t, db)
rows, err := db.Query("select * from generate_series(1,10) n") rows, err := db.Query("select * from generate_series(1,10) n")
if err != nil { require.NoError(t, err)
t.Fatalf("db.Query failed: %v", err)
}
columnTypes, err := rows.ColumnTypes() columnTypes, err := rows.ColumnTypes()
if err != nil { require.NoError(t, err)
t.Fatalf("rows.ColumnTypes failed: %v", err) require.Len(t, columnTypes, 1)
}
if len(columnTypes) != 1 {
t.Fatalf("len(columnTypes) => %v, want %v", len(columnTypes), 1)
}
if columnTypes[0].DatabaseTypeName() != "INT4" { if columnTypes[0].DatabaseTypeName() != "INT4" {
t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT4") t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT4")
@ -870,20 +703,14 @@ func TestStmtExecContextSuccess(t *testing.T) {
defer closeDB(t, db) defer closeDB(t, db)
_, err := db.Exec("create temporary table t(id int primary key)") _, err := db.Exec("create temporary table t(id int primary key)")
if err != nil { require.NoError(t, err)
t.Fatalf("db.Exec failed: %v", err)
}
stmt, err := db.Prepare("insert into t(id) values ($1::int4)") stmt, err := db.Prepare("insert into t(id) values ($1::int4)")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer stmt.Close() defer stmt.Close()
_, err = stmt.ExecContext(context.Background(), 42) _, err = stmt.ExecContext(context.Background(), 42)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -893,14 +720,10 @@ func TestStmtExecContextCancel(t *testing.T) {
defer closeDB(t, db) defer closeDB(t, db)
_, err := db.Exec("create temporary table t(id int primary key)") _, err := db.Exec("create temporary table t(id int primary key)")
if err != nil { require.NoError(t, err)
t.Fatalf("db.Exec failed: %v", err)
}
stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)") stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer stmt.Close() defer stmt.Close()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
@ -919,15 +742,11 @@ func TestStmtQueryContextSuccess(t *testing.T) {
defer closeDB(t, db) defer closeDB(t, db)
stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n") stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer stmt.Close() defer stmt.Close()
rows, err := stmt.QueryContext(context.Background(), 5) rows, err := stmt.QueryContext(context.Background(), 5)
if err != nil { require.NoError(t, err)
t.Fatalf("stmt.QueryContext failed: %v", err)
}
for rows.Next() { for rows.Next() {
var n int64 var n int64
@ -1025,14 +844,10 @@ func TestRowsColumnTypes(t *testing.T) {
defer closeDB(t, db) defer closeDB(t, db)
rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec") rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
columns, err := rows.ColumnTypes() columns, err := rows.ColumnTypes()
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
if len(columns) != 3 { if len(columns) != 3 {
t.Errorf("expected 3 columns found %d", len(columns)) t.Errorf("expected 3 columns found %d", len(columns))
} }
@ -1070,18 +885,14 @@ func TestRowsColumnTypes(t *testing.T) {
func TestSimpleQueryLifeCycle(t *testing.T) { func TestSimpleQueryLifeCycle(t *testing.T) {
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
if err != nil { require.NoError(t, err)
t.Fatalf("pgx.ParseConnectionString failed: %v", err)
}
config.PreferSimpleProtocol = true config.PreferSimpleProtocol = true
db := stdlib.OpenDB(*config) db := stdlib.OpenDB(*config)
defer closeDB(t, db) defer closeDB(t, db)
rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3) rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
if err != nil { require.NoError(t, err)
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
}
rowCount := int64(0) rowCount := int64(0)
@ -1092,9 +903,8 @@ func TestSimpleQueryLifeCycle(t *testing.T) {
n int64 n int64
) )
if err := rows.Scan(&s, &n); err != nil { err := rows.Scan(&s, &n)
t.Fatalf("rows.Scan unexpectedly failed: %v", err) require.NoError(t, err)
}
if s != "foo" { if s != "foo" {
t.Errorf(`Expected "foo", received "%v"`, s) t.Errorf(`Expected "foo", received "%v"`, s)
@ -1104,43 +914,24 @@ func TestSimpleQueryLifeCycle(t *testing.T) {
t.Errorf("Expected %d, received %d", rowCount, n) t.Errorf("Expected %d, received %d", rowCount, n)
} }
} }
require.NoError(t, rows.Err())
if err = rows.Err(); err != nil {
t.Fatalf("rows.Err unexpectedly is: %v", err)
}
if rowCount != 10 {
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
}
err = rows.Close() err = rows.Close()
if err != nil { require.NoError(t, err)
t.Fatalf("rows.Close unexpectedly failed: %v", err)
}
rows, err = db.Query("select 1 where false") rows, err = db.Query("select 1 where false")
if err != nil { require.NoError(t, err)
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
}
rowCount = int64(0) rowCount = int64(0)
for rows.Next() { for rows.Next() {
rowCount++ rowCount++
} }
require.NoError(t, rows.Err())
if err = rows.Err(); err != nil { require.EqualValues(t, 0, rowCount)
t.Fatalf("rows.Err unexpectedly is: %v", err)
}
if rowCount != 0 {
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
}
err = rows.Close() err = rows.Close()
if err != nil { require.NoError(t, err)
t.Fatalf("rows.Close unexpectedly failed: %v", err)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }
@ -1153,13 +944,8 @@ func TestScanJSONIntoJSONRawMessage(t *testing.T) {
var msg json.RawMessage var msg json.RawMessage
err := db.QueryRow("select '{}'::json").Scan(&msg) err := db.QueryRow("select '{}'::json").Scan(&msg)
if err != nil { require.NoError(t, err)
t.Fatalf("QueryRow / Scan failed: %v", err) require.EqualValues(t, []byte("{}"), []byte(msg))
}
if bytes.Compare([]byte("{}"), []byte(msg)) != 0 {
t.Fatalf("Expected %v, got %v", []byte("{}"), msg)
}
ensureDBValid(t, db) ensureDBValid(t, db)
} }