pgx/stdlib/sql_test.go

693 lines
15 KiB
Go

package stdlib_test
import (
"bytes"
"database/sql"
"sync"
"testing"
"github.com/jackc/pgx"
"github.com/jackc/pgx/stdlib"
)
// openDB opens a database/sql handle through the "pgx" driver against the
// local pgx_test database. The test is aborted if sql.Open returns an error.
func openDB(t *testing.T) *sql.DB {
	const dsn = "postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"

	handle, openErr := sql.Open("pgx", dsn)
	if openErr != nil {
		t.Fatalf("sql.Open failed: %v", openErr)
	}
	return handle
}
// closeDB closes db and aborts the test if Close returns an error.
func closeDB(t *testing.T, db *sql.DB) {
	if closeErr := db.Close(); closeErr != nil {
		t.Fatalf("db.Close unexpectedly failed: %v", closeErr)
	}
}
// ensureConnValid runs a simple query through db to ensure the connection is
// still usable. It selects generate_series(1, 10) and fails the test when the
// query, a row scan, or the row iteration reports an error, or when the rows
// read do not number 10 and sum to 55.
func ensureConnValid(t *testing.T, db *sql.DB) {
	var sum, rowCount int32

	rows, err := db.Query("select generate_series(1,$1)", 10)
	if err != nil {
		t.Fatalf("db.Query failed: %v", err)
	}
	defer rows.Close()

	for rows.Next() {
		var n int32
		// Surface scan failures directly rather than letting them show up
		// later as a confusing "Wrong values returned".
		if err := rows.Scan(&n); err != nil {
			t.Fatalf("rows.Scan failed: %v", err)
		}
		sum += n
		rowCount++
	}

	// Report the iteration error itself; the err returned by db.Query is
	// known to be nil by the time execution reaches this point.
	if err := rows.Err(); err != nil {
		t.Fatalf("db.Query failed: %v", err)
	}

	if rowCount != 10 {
		t.Error("Select called onDataRow wrong number of times")
	}
	if sum != 55 {
		t.Error("Wrong values returned")
	}
}
// preparer is satisfied by any value with a Prepare method of this signature.
// prepareStmt accepts it so the same helper can be used with both the *sql.DB
// and *sql.Tx values that the tests in this file pass to it.
type preparer interface {
Prepare(query string) (*sql.Stmt, error)
}
// prepareStmt prepares the given statement text on p and returns the
// resulting statement. The test is aborted if Prepare returns an error.
func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt {
	prepared, prepErr := p.Prepare(sql)
	if prepErr != nil {
		t.Fatalf("%v Prepare unexpectedly failed: %v", p, prepErr)
	}
	return prepared
}
// closeStmt closes stmt and aborts the test if Close returns an error.
func closeStmt(t *testing.T, stmt *sql.Stmt) {
	if closeErr := stmt.Close(); closeErr != nil {
		t.Fatalf("stmt.Close unexpectedly failed: %v", closeErr)
	}
}
// TestNormalLifeCycle prepares a statement, executes it, reads every returned
// row, closes the rows, and finally checks that the connection still works.
func TestNormalLifeCycle(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
	defer closeStmt(t, stmt)

	rows, err := stmt.Query(int32(1), int32(10))
	if err != nil {
		t.Fatalf("stmt.Query unexpectedly failed: %v", err)
	}

	// seen doubles as the expected value of n for the current row.
	var seen int64
	for rows.Next() {
		seen++

		var (
			label string
			value int64
		)
		if scanErr := rows.Scan(&label, &value); scanErr != nil {
			t.Fatalf("rows.Scan unexpectedly failed: %v", scanErr)
		}
		if label != "foo" {
			t.Errorf(`Expected "foo", received "%v"`, label)
		}
		if value != seen {
			t.Errorf("Expected %d, received %d", seen, value)
		}
	}

	if err = rows.Err(); err != nil {
		t.Fatalf("rows.Err unexpectedly is: %v", err)
	}
	if seen != 10 {
		t.Fatalf("Expected to receive 10 rows, instead received %d", seen)
	}
	if err = rows.Close(); err != nil {
		t.Fatalf("rows.Close unexpectedly failed: %v", err)
	}

	ensureConnValid(t, db)
}
// TestSqlOpenDoesNotHavePool checks that the driver behind a handle obtained
// from openDB (plain sql.Open) has a nil Pool field.
func TestSqlOpenDoesNotHavePool(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	if drv := db.Driver().(*stdlib.Driver); drv.Pool != nil {
		t.Fatal("Did not expect driver opened through database/sql to have Pool, but it did")
	}
}
// TestOpenFromConnPool creates a pgx connection pool, wraps it with
// stdlib.OpenFromConnPool, and checks both that the pool is reachable through
// the driver and that ordinary database/sql queries work on the handle.
func TestOpenFromConnPool(t *testing.T) {
	connConfig := pgx.ConnConfig{
		Host:     "127.0.0.1",
		User:     "pgx_md5",
		Password: "secret",
		Database: "pgx_test",
	}

	config := pgx.ConnPoolConfig{ConnConfig: connConfig}
	pool, err := pgx.NewConnPool(config)
	if err != nil {
		t.Fatalf("Unable to create connection pool: %v", err)
	}
	defer pool.Close()

	db, err := stdlib.OpenFromConnPool(pool)
	if err != nil {
		// Name the call that actually failed so it is not mistaken for a
		// pool-creation failure.
		t.Fatalf("stdlib.OpenFromConnPool unexpectedly failed: %v", err)
	}
	defer closeDB(t, db)

	// Can get pgx.ConnPool from driver
	driver := db.Driver().(*stdlib.Driver)
	if driver.Pool == nil {
		t.Fatal("Expected driver opened through OpenFromConnPool to have Pool, but it did not")
	}

	// Normal sql/database still works
	var n int64
	err = db.QueryRow("select 1").Scan(&n)
	if err != nil {
		t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
	}
}
// TestOpenFromConnPoolRace calls stdlib.OpenFromConnPool on one shared pool
// from ten goroutines at once and checks that every resulting driver exposes
// the pool.
func TestOpenFromConnPoolRace(t *testing.T) {
	wg := &sync.WaitGroup{}
	connConfig := pgx.ConnConfig{
		Host:     "127.0.0.1",
		User:     "pgx_md5",
		Password: "secret",
		Database: "pgx_test",
	}

	config := pgx.ConnPoolConfig{ConnConfig: connConfig}
	pool, err := pgx.NewConnPool(config)
	if err != nil {
		t.Fatalf("Unable to create connection pool: %v", err)
	}
	defer pool.Close()

	wg.Add(10)
	for i := 0; i < 10; i++ {
		go func() {
			defer wg.Done()

			// FailNow-style calls (t.Fatal, t.Fatalf) must only be made from
			// the goroutine running the test function, so failures in these
			// workers are reported with t.Error/t.Errorf followed by return.
			db, err := stdlib.OpenFromConnPool(pool)
			if err != nil {
				t.Errorf("stdlib.OpenFromConnPool unexpectedly failed: %v", err)
				return
			}
			// closeDB is not used here because it reports through t.Fatalf.
			defer func() {
				if err := db.Close(); err != nil {
					t.Errorf("db.Close unexpectedly failed: %v", err)
				}
			}()

			// Can get pgx.ConnPool from driver
			driver := db.Driver().(*stdlib.Driver)
			if driver.Pool == nil {
				t.Error("Expected driver opened through OpenFromConnPool to have Pool, but it did not")
			}
		}()
	}

	wg.Wait()
}
// TestOpenWithDriverConfigAfterConnect registers a DriverConfig whose
// AfterConnect hook creates a temporary sequence, opens a handle using that
// config's connection string, and checks that nextval on the sequence
// returns 1.
func TestOpenWithDriverConfigAfterConnect(t *testing.T) {
	driverConfig := stdlib.DriverConfig{
		AfterConnect: func(c *pgx.Conn) error {
			_, err := c.Exec("create temporary sequence pgx")
			return err
		},
	}

	stdlib.RegisterDriverConfig(&driverConfig)
	defer stdlib.UnregisterDriverConfig(&driverConfig)

	db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
	if err != nil {
		t.Fatalf("sql.Open failed: %v", err)
	}
	// Close the handle like every other test in this file does, so its
	// connections do not outlive the test.
	defer closeDB(t, db)

	var n int64
	err = db.QueryRow("select nextval('pgx')").Scan(&n)
	if err != nil {
		t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
	}
	if n != 1 {
		t.Fatalf("n => %d, want %d", n, 1)
	}
}
// TestStmtExec prepares and executes statements inside a transaction: it
// creates a temporary table, inserts one row, checks RowsAffected, commits,
// and then checks that the connection is still usable.
func TestStmtExec(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	tx, err := db.Begin()
	if err != nil {
		t.Fatalf("db.Begin unexpectedly failed: %v", err)
	}

	createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)")
	_, err = createStmt.Exec()
	if err != nil {
		t.Fatalf("stmt.Exec unexpectedly failed: %v", err)
	}
	closeStmt(t, createStmt)

	insertStmt := prepareStmt(t, tx, "insert into t values($1::text)")
	result, err := insertStmt.Exec("foo")
	if err != nil {
		t.Fatalf("stmt.Exec unexpectedly failed: %v", err)
	}

	n, err := result.RowsAffected()
	if err != nil {
		t.Fatalf("result.RowsAffected unexpectedly failed: %v", err)
	}
	if n != 1 {
		t.Fatalf("Expected 1, received %d", n)
	}
	closeStmt(t, insertStmt)

	// A transaction must end with Commit or Rollback; check the error that
	// Commit itself returns rather than a leftover err value.
	if err := tx.Commit(); err != nil {
		t.Fatalf("tx.Commit unexpectedly failed: %v", err)
	}

	ensureConnValid(t, db)
}
// TestQueryCloseRowsEarly closes a result set before reading any rows and then
// re-runs the same prepared statement to confirm that both the statement and
// the connection remain usable.
func TestQueryCloseRowsEarly(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
	defer closeStmt(t, stmt)

	rows, err := stmt.Query(int32(1), int32(10))
	if err != nil {
		t.Fatalf("stmt.Query unexpectedly failed: %v", err)
	}

	// Discard the first result set without consuming a single row.
	if err = rows.Close(); err != nil {
		t.Fatalf("rows.Close unexpectedly failed: %v", err)
	}

	// Second run: this time read the result set to the end.
	rows, err = stmt.Query(int32(1), int32(10))
	if err != nil {
		t.Fatalf("stmt.Query unexpectedly failed: %v", err)
	}

	var seen int64
	for rows.Next() {
		seen++

		var (
			label string
			value int64
		)
		if scanErr := rows.Scan(&label, &value); scanErr != nil {
			t.Fatalf("rows.Scan unexpectedly failed: %v", scanErr)
		}
		if label != "foo" {
			t.Errorf(`Expected "foo", received "%v"`, label)
		}
		if value != seen {
			t.Errorf("Expected %d, received %d", seen, value)
		}
	}

	if err = rows.Err(); err != nil {
		t.Fatalf("rows.Err unexpectedly is: %v", err)
	}
	if seen != 10 {
		t.Fatalf("Expected to receive 10 rows, instead received %d", seen)
	}
	if err = rows.Close(); err != nil {
		t.Fatalf("rows.Close unexpectedly failed: %v", err)
	}

	ensureConnValid(t, db)
}
// TestConnExec runs unprepared statements through db.Exec: it creates a
// temporary table, inserts one row, and checks that RowsAffected reports 1.
func TestConnExec(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	if _, createErr := db.Exec("create temporary table t(a varchar not null)"); createErr != nil {
		t.Fatalf("db.Exec unexpectedly failed: %v", createErr)
	}

	res, err := db.Exec("insert into t values('hey')")
	if err != nil {
		t.Fatalf("db.Exec unexpectedly failed: %v", err)
	}

	affected, err := res.RowsAffected()
	if err != nil {
		t.Fatalf("result.RowsAffected unexpectedly failed: %v", err)
	}
	if affected != 1 {
		t.Fatalf("Expected 1, received %d", affected)
	}

	ensureConnValid(t, db)
}
// TestConnQuery runs an unprepared, parameterized query through db.Query and
// verifies each of the ten rows it returns.
func TestConnQuery(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	const query = "select 'foo', n from generate_series($1::int, $2::int) n"
	rows, err := db.Query(query, int32(1), int32(10))
	if err != nil {
		t.Fatalf("db.Query unexpectedly failed: %v", err)
	}

	// seen doubles as the expected value of n for the current row.
	var seen int64
	for rows.Next() {
		seen++

		var (
			label string
			value int64
		)
		if scanErr := rows.Scan(&label, &value); scanErr != nil {
			t.Fatalf("rows.Scan unexpectedly failed: %v", scanErr)
		}
		if label != "foo" {
			t.Errorf(`Expected "foo", received "%v"`, label)
		}
		if value != seen {
			t.Errorf("Expected %d, received %d", seen, value)
		}
	}

	if err = rows.Err(); err != nil {
		t.Fatalf("rows.Err unexpectedly is: %v", err)
	}
	if seen != 10 {
		t.Fatalf("Expected to receive 10 rows, instead received %d", seen)
	}
	if err = rows.Close(); err != nil {
		t.Fatalf("rows.Close unexpectedly failed: %v", err)
	}

	ensureConnValid(t, db)
}
// testLog is one captured log entry: the level, message, and data map that
// were passed to testLogger.Log.
type testLog struct {
lvl pgx.LogLevel
msg string
data map[string]interface{}
}
// testLogger keeps every entry passed to its Log method in memory so a test
// can inspect what was logged. TestConnQueryLog installs it as the Logger of a
// pgx.ConnConfig.
type testLogger struct {
// logs holds the captured entries in the order Log received them.
logs []testLog
}
// Log captures the call by appending its arguments, as a testLog, to l.logs.
func (l *testLogger) Log(lvl pgx.LogLevel, msg string, data map[string]interface{}) {
	entry := testLog{lvl: lvl, msg: msg, data: data}
	l.logs = append(l.logs, entry)
}
// TestConnQueryLog installs a testLogger on the connection config, runs a
// query through database/sql, and checks that the first entry logged after the
// query is a "Query" message carrying the executed SQL.
func TestConnQueryLog(t *testing.T) {
	logger := &testLogger{}

	connConfig := pgx.ConnConfig{
		Host:     "127.0.0.1",
		User:     "pgx_md5",
		Password: "secret",
		Database: "pgx_test",
		Logger:   logger,
	}

	config := pgx.ConnPoolConfig{ConnConfig: connConfig}
	pool, err := pgx.NewConnPool(config)
	if err != nil {
		t.Fatalf("Unable to create connection pool: %v", err)
	}
	defer pool.Close()

	db, err := stdlib.OpenFromConnPool(pool)
	if err != nil {
		// Name the call that actually failed so it is not mistaken for a
		// pool-creation failure.
		t.Fatalf("stdlib.OpenFromConnPool unexpectedly failed: %v", err)
	}
	defer closeDB(t, db)

	// clear logs from initial connection
	logger.logs = []testLog{}

	var n int64
	err = db.QueryRow("select 1").Scan(&n)
	if err != nil {
		t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
	}

	// Fail with a readable message instead of an index-out-of-range panic
	// when the query produced no log entries at all.
	if len(logger.logs) == 0 {
		t.Fatal("Expected the query to be logged, but no log entries were recorded")
	}

	l := logger.logs[0]
	if l.msg != "Query" {
		t.Errorf("Expected to log Query, but got %v", l)
	}

	if l.data["sql"] != "select 1" {
		t.Errorf("Expected to log Query with sql 'select 1', but got %v", l)
	}
}
// TestConnQueryNull passes nil as a query argument and checks that the single
// returned row scans into a sql.NullInt64 whose Valid field is false.
func TestConnQueryNull(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	rows, err := db.Query("select $1::int", nil)
	if err != nil {
		t.Fatalf("db.Query unexpectedly failed: %v", err)
	}

	rowCount := int64(0)

	for rows.Next() {
		rowCount++

		var n sql.NullInt64
		if err := rows.Scan(&n); err != nil {
			t.Fatalf("rows.Scan unexpectedly failed: %v", err)
		}
		if n.Valid {
			t.Errorf("Expected n to be null, but it was %v", n)
		}
	}
	err = rows.Err()
	if err != nil {
		t.Fatalf("rows.Err unexpectedly is: %v", err)
	}

	// The query selects a single value, so exactly one row is expected; the
	// message states the same count the condition checks.
	if rowCount != 1 {
		t.Fatalf("Expected to receive 1 row, instead received %d", rowCount)
	}

	err = rows.Close()
	if err != nil {
		t.Fatalf("rows.Close unexpectedly failed: %v", err)
	}

	ensureConnValid(t, db)
}
// TestConnQueryRowByteSlice selects a bytea literal and checks that it scans
// into a []byte holding the bytes 0xde 0xad 0xbe 0xef.
func TestConnQueryRowByteSlice(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	want := []byte{222, 173, 190, 239}

	var got []byte
	if err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&got); err != nil {
		t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
	}
	if !bytes.Equal(got, want) {
		t.Fatalf("Expected %v, but got %v", want, got)
	}

	ensureConnValid(t, db)
}
// TestConnQueryFailure sends a query with an unterminated string literal and
// checks that the error from db.Query is a pgx.PgError, then confirms the
// connection is still usable.
func TestConnQueryFailure(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	_, queryErr := db.Query("select 'foo")
	_, isPgError := queryErr.(pgx.PgError)
	if !isPgError {
		t.Fatalf("Expected db.Query to return pgx.PgError, but instead received: %v", queryErr)
	}

	ensureConnValid(t, db)
}
// TestConnQueryRowPgxBinary round-trips an int4[] value as text. pgx would
// handle this type natively in binary, but because it is not a database/sql
// native type it should be passed through as a string.
func TestConnQueryRowPgxBinary(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	const (
		query = "select $1::int4[]"
		want  = "{1,2,3}"
	)

	var got string
	if err := db.QueryRow(query, want).Scan(&got); err != nil {
		t.Errorf("Unexpected failure: %v (sql -> %v)", err, query)
	}
	if got != want {
		t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, want, got, query)
	}

	ensureConnValid(t, db)
}
// TestConnQueryRowUnknownType round-trips a point value given and read back as
// its text form "(1,2)".
func TestConnQueryRowUnknownType(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	const (
		query = "select $1::point"
		want  = "(1,2)"
	)

	var got string
	if err := db.QueryRow(query, want).Scan(&got); err != nil {
		t.Errorf("Unexpected failure: %v (sql -> %v)", err, query)
	}
	if got != want {
		t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, want, got, query)
	}

	ensureConnValid(t, db)
}
// TestConnQueryJSONIntoByteSlice stores a JSON document in a temporary table
// and checks that selecting it scans into a []byte holding the same text.
func TestConnQueryJSONIntoByteSlice(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	// The statement text is kept flush-left so the string sent to the server
	// carries no extra leading whitespace.
	const setup = `
create temporary table docs(
body json not null
);
insert into docs(body) values('{"foo":"bar"}');
`
	if _, setupErr := db.Exec(setup); setupErr != nil {
		t.Fatalf("db.Exec unexpectedly failed: %v", setupErr)
	}

	const query = `select * from docs`
	want := []byte(`{"foo":"bar"}`)

	var got []byte
	if scanErr := db.QueryRow(query).Scan(&got); scanErr != nil {
		t.Errorf("Unexpected failure: %v (sql -> %v)", scanErr, query)
	}
	if !bytes.Equal(got, want) {
		t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(want), string(got), query)
	}

	if _, dropErr := db.Exec(`drop table docs`); dropErr != nil {
		t.Fatalf("db.Exec unexpectedly failed: %v", dropErr)
	}

	ensureConnValid(t, db)
}
// TestConnExecInsertByteSliceIntoJSON inserts a []byte argument into a json
// column and checks that reading the column back yields the same bytes.
func TestConnExecInsertByteSliceIntoJSON(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	// The statement text is kept flush-left so the string sent to the server
	// carries no extra leading whitespace.
	const createTable = `
create temporary table docs(
body json not null
);
`
	if _, createErr := db.Exec(createTable); createErr != nil {
		t.Fatalf("db.Exec unexpectedly failed: %v", createErr)
	}

	want := []byte(`{"foo":"bar"}`)
	if _, insertErr := db.Exec(`insert into docs(body) values($1)`, want); insertErr != nil {
		t.Fatalf("db.Exec unexpectedly failed: %v", insertErr)
	}

	var got []byte
	if scanErr := db.QueryRow(`select body from docs`).Scan(&got); scanErr != nil {
		t.Fatalf("db.QueryRow unexpectedly failed: %v", scanErr)
	}
	if !bytes.Equal(got, want) {
		t.Errorf(`Expected "%v", got "%v"`, string(want), string(got))
	}

	if _, dropErr := db.Exec(`drop table docs`); dropErr != nil {
		t.Fatalf("db.Exec unexpectedly failed: %v", dropErr)
	}

	ensureConnValid(t, db)
}
// TestTransactionLifeCycle checks both ways a transaction can end: a row
// inserted in a rolled-back transaction must not be visible afterwards, and a
// row inserted in a committed transaction must be.
func TestTransactionLifeCycle(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	_, err := db.Exec("create temporary table t(a varchar not null)")
	if err != nil {
		t.Fatalf("db.Exec unexpectedly failed: %v", err)
	}

	// First transaction: insert, then roll back.
	tx, err := db.Begin()
	if err != nil {
		t.Fatalf("db.Begin unexpectedly failed: %v", err)
	}

	_, err = tx.Exec("insert into t values('hi')")
	if err != nil {
		t.Fatalf("tx.Exec unexpectedly failed: %v", err)
	}

	err = tx.Rollback()
	if err != nil {
		t.Fatalf("tx.Rollback unexpectedly failed: %v", err)
	}

	var n int64
	err = db.QueryRow("select count(*) from t").Scan(&n)
	if err != nil {
		t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err)
	}
	if n != 0 {
		t.Fatalf("Expected 0 rows due to rollback, instead found %d", n)
	}

	// Second transaction: insert, then commit.
	tx, err = db.Begin()
	if err != nil {
		t.Fatalf("db.Begin unexpectedly failed: %v", err)
	}

	_, err = tx.Exec("insert into t values('hi')")
	if err != nil {
		t.Fatalf("tx.Exec unexpectedly failed: %v", err)
	}

	err = tx.Commit()
	if err != nil {
		t.Fatalf("tx.Commit unexpectedly failed: %v", err)
	}

	err = db.QueryRow("select count(*) from t").Scan(&n)
	if err != nil {
		t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err)
	}
	// This count follows a commit, so the message names commit as the reason
	// one row is expected.
	if n != 1 {
		t.Fatalf("Expected 1 row due to commit, instead found %d", n)
	}

	ensureConnValid(t, db)
}