pgx/stdlib/sql_test.go
Jack Christensen c53c9e6eb5 Remove simple protocol and one round trip query options
It is impossible to guarantee that a query executed with the simple
protocol will behave the same as with the extended protocol. This is
because the normal pgx path relies on knowing the OID of query
parameters. Without it, the encoding of a value can only be determined by the
value instead of the combination of value and PostgreSQL type. For
example, how should a []int32 be encoded? It might be encoded into a
PostgreSQL int4[] or json.

Removal also simplifies the core query path.

The primary reason for the simple protocol is for servers like PgBouncer
that may not be able to support normal prepared statements. After
further research it appears that issuing a "flush" instead of a "sync" after
preparing the unnamed statement would allow PgBouncer to work.

The one round trip mode can be better handled with prepared statements.

As a last resort, all original server functionality can still be accessed by
dropping down to PgConn.
2019-04-13 11:39:01 -05:00

1075 lines
23 KiB
Go

package stdlib_test
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"math"
"reflect"
"testing"
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgproto3"
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgmock"
"github.com/jackc/pgx/stdlib"
)
// closeDB closes db and aborts the calling test immediately if Close reports
// an error.
func closeDB(t *testing.T, db *sql.DB) {
	if closeErr := db.Close(); closeErr != nil {
		t.Fatalf("db.Close unexpectedly failed: %v", closeErr)
	}
}
// ensureConnValid runs a small parameterized query against db and fails the
// test unless it yields exactly 10 rows whose values sum to 55. Tests call it
// last to confirm the connection pool is still usable after whatever they did.
func ensureConnValid(t *testing.T, db *sql.DB) {
	var sum, rowCount int32

	rows, err := db.Query("select generate_series(1,$1)", 10)
	if err != nil {
		t.Fatalf("db.Query failed: %v", err)
	}
	defer rows.Close()

	for rows.Next() {
		var n int32
		// A scan failure would otherwise silently add 0 and only surface as a
		// confusing sum mismatch.
		if err := rows.Scan(&n); err != nil {
			t.Fatalf("rows.Scan failed: %v", err)
		}
		sum += n
		rowCount++
	}

	// Report the iteration error itself; the err from db.Query is nil here.
	if err := rows.Err(); err != nil {
		t.Fatalf("db.Query failed: %v", err)
	}

	if rowCount != 10 {
		t.Errorf("Expected 10 rows, received %d", rowCount)
	}
	if sum != 55 {
		t.Errorf("Wrong values returned: expected sum 55, received %d", sum)
	}
}
// preparer is the subset of behavior prepareStmt needs. In this file it is
// satisfied by both *sql.DB and *sql.Tx, so statements can be prepared on a
// pool or inside a transaction with the same helper.
type preparer interface {
Prepare(query string) (*sql.Stmt, error)
}
// prepareStmt prepares the given SQL text on p and returns the statement,
// aborting the test if preparation fails.
func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt {
	prepared, prepErr := p.Prepare(sql)
	if prepErr == nil {
		return prepared
	}
	t.Fatalf("%v Prepare unexpectedly failed: %v", p, prepErr)
	return prepared // unreachable: Fatalf stops the test goroutine
}
// closeStmt releases stmt and aborts the test if Close reports an error.
func closeStmt(t *testing.T, stmt *sql.Stmt) {
	if closeErr := stmt.Close(); closeErr != nil {
		t.Fatalf("stmt.Close unexpectedly failed: %v", closeErr)
	}
}
// TestSQLOpen checks that sql.Open succeeds for the "pgx" driver name with a
// URL-style connection string, then closes the resulting pool.
func TestSQLOpen(t *testing.T) {
	const connString = "postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"

	db, openErr := sql.Open("pgx", connString)
	if openErr != nil {
		t.Fatalf("sql.Open failed: %v", openErr)
	}
	closeDB(t, db)
}
// TestNormalLifeCycle walks a prepared statement through query, scan, and
// close, checking every row of a 10-row generate_series result, and then
// confirms the pool is still healthy.
func TestNormalLifeCycle(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
	defer closeStmt(t, stmt)

	rows, err := stmt.Query(int32(1), int32(10))
	if err != nil {
		t.Fatalf("stmt.Query unexpectedly failed: %v", err)
	}

	var rowCount int64
	for rows.Next() {
		rowCount++

		var (
			label string
			value int64
		)
		if scanErr := rows.Scan(&label, &value); scanErr != nil {
			t.Fatalf("rows.Scan unexpectedly failed: %v", scanErr)
		}
		if label != "foo" {
			t.Errorf(`Expected "foo", received "%v"`, label)
		}
		// generate_series starts at 1, so the n-th row must carry n.
		if value != rowCount {
			t.Errorf("Expected %d, received %d", rowCount, value)
		}
	}

	if err = rows.Err(); err != nil {
		t.Fatalf("rows.Err unexpectedly is: %v", err)
	}
	if rowCount != 10 {
		t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
	}
	if err = rows.Close(); err != nil {
		t.Fatalf("rows.Close unexpectedly failed: %v", err)
	}

	ensureConnValid(t, db)
}
// TestStmtExec prepares and executes DDL and an INSERT inside a transaction,
// checks RowsAffected, commits, and confirms the pool is still healthy.
func TestStmtExec(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	tx, err := db.Begin()
	if err != nil {
		t.Fatalf("db.Begin unexpectedly failed: %v", err)
	}
	// Return the connection if the test aborts before Commit. After a
	// successful Commit this is a no-op that returns sql.ErrTxDone.
	defer tx.Rollback()

	createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)")
	if _, err := createStmt.Exec(); err != nil {
		t.Fatalf("stmt.Exec unexpectedly failed: %v", err)
	}
	closeStmt(t, createStmt)

	insertStmt := prepareStmt(t, tx, "insert into t values($1::text)")
	result, err := insertStmt.Exec("foo")
	if err != nil {
		t.Fatalf("stmt.Exec unexpectedly failed: %v", err)
	}
	n, err := result.RowsAffected()
	if err != nil {
		t.Fatalf("result.RowsAffected unexpectedly failed: %v", err)
	}
	if n != 1 {
		t.Fatalf("Expected 1, received %d", n)
	}
	closeStmt(t, insertStmt)

	// The original code checked a stale err here without ever committing.
	if err := tx.Commit(); err != nil {
		t.Fatalf("tx.Commit unexpectedly failed: %v", err)
	}

	ensureConnValid(t, db)
}
// TestQueryCloseRowsEarly closes a result set without reading it. It then
// re-runs the same prepared statement and verifies that the second run
// returns all 10 expected rows, proving the early close left both the
// statement and the connection usable.
func TestQueryCloseRowsEarly(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
	defer closeStmt(t, stmt)

	firstRows, err := stmt.Query(int32(1), int32(10))
	if err != nil {
		t.Fatalf("stmt.Query unexpectedly failed: %v", err)
	}
	// Abandon the first result set before reading any row.
	if err = firstRows.Close(); err != nil {
		t.Fatalf("rows.Close unexpectedly failed: %v", err)
	}

	rows, err := stmt.Query(int32(1), int32(10))
	if err != nil {
		t.Fatalf("stmt.Query unexpectedly failed: %v", err)
	}

	var rowCount int64
	for rows.Next() {
		rowCount++

		var (
			label string
			value int64
		)
		if scanErr := rows.Scan(&label, &value); scanErr != nil {
			t.Fatalf("rows.Scan unexpectedly failed: %v", scanErr)
		}
		if label != "foo" {
			t.Errorf(`Expected "foo", received "%v"`, label)
		}
		if value != rowCount {
			t.Errorf("Expected %d, received %d", rowCount, value)
		}
	}

	if err = rows.Err(); err != nil {
		t.Fatalf("rows.Err unexpectedly is: %v", err)
	}
	if rowCount != 10 {
		t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
	}
	if err = rows.Close(); err != nil {
		t.Fatalf("rows.Close unexpectedly failed: %v", err)
	}

	ensureConnValid(t, db)
}
// TestConnExec runs unprepared DDL and an INSERT directly on the pool and
// checks that the INSERT reports exactly one affected row.
func TestConnExec(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	if _, err := db.Exec("create temporary table t(a varchar not null)"); err != nil {
		t.Fatalf("db.Exec unexpectedly failed: %v", err)
	}

	result, err := db.Exec("insert into t values('hey')")
	if err != nil {
		t.Fatalf("db.Exec unexpectedly failed: %v", err)
	}

	affected, err := result.RowsAffected()
	switch {
	case err != nil:
		t.Fatalf("result.RowsAffected unexpectedly failed: %v", err)
	case affected != 1:
		t.Fatalf("Expected 1, received %d", affected)
	}

	ensureConnValid(t, db)
}
// TestConnQuery runs a parameterized query directly on the pool (no explicit
// Prepare) and validates each of the 10 returned rows.
func TestConnQuery(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10))
	if err != nil {
		t.Fatalf("db.Query unexpectedly failed: %v", err)
	}

	var rowCount int64
	for rows.Next() {
		rowCount++

		var (
			label string
			value int64
		)
		if scanErr := rows.Scan(&label, &value); scanErr != nil {
			t.Fatalf("rows.Scan unexpectedly failed: %v", scanErr)
		}
		if label != "foo" {
			t.Errorf(`Expected "foo", received "%v"`, label)
		}
		if value != rowCount {
			t.Errorf("Expected %d, received %d", rowCount, value)
		}
	}

	if err = rows.Err(); err != nil {
		t.Fatalf("rows.Err unexpectedly is: %v", err)
	}
	if rowCount != 10 {
		t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
	}
	if err = rows.Close(); err != nil {
		t.Fatalf("rows.Close unexpectedly failed: %v", err)
	}

	ensureConnValid(t, db)
}
// testLog is one entry recorded by testLogger.Log: the level, message, and
// data arguments exactly as they were passed in.
type testLog struct {
lvl pgx.LogLevel
msg string
data map[string]interface{}
}
// testLogger accumulates every entry passed to its Log method, in call order.
// Log appends without any locking, so concurrent use would need external
// synchronization.
type testLogger struct {
logs []testLog
}
// Log records the call as a testLog entry appended to l.logs. The data map is
// stored as-is, not copied.
func (l *testLogger) Log(lvl pgx.LogLevel, msg string, data map[string]interface{}) {
	entry := testLog{
		lvl:  lvl,
		msg:  msg,
		data: data,
	}
	l.logs = append(l.logs, entry)
}
// TestConnQueryNull binds a nil parameter, selects it back as int, and checks
// that the single resulting row scans into an invalid (NULL) sql.NullInt64.
func TestConnQueryNull(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	rows, err := db.Query("select $1::int", nil)
	if err != nil {
		t.Fatalf("db.Query unexpectedly failed: %v", err)
	}

	rowCount := int64(0)
	for rows.Next() {
		rowCount++

		var n sql.NullInt64
		if err := rows.Scan(&n); err != nil {
			t.Fatalf("rows.Scan unexpectedly failed: %v", err)
		}
		if n.Valid {
			t.Errorf("Expected n to be null, but it was %v", n)
		}
	}

	if err := rows.Err(); err != nil {
		t.Fatalf("rows.Err unexpectedly is: %v", err)
	}
	// The message previously claimed 11 rows were expected.
	if rowCount != 1 {
		t.Fatalf("Expected to receive 1 row, instead received %d", rowCount)
	}
	if err := rows.Close(); err != nil {
		t.Fatalf("rows.Close unexpectedly failed: %v", err)
	}

	ensureConnValid(t, db)
}
// TestConnQueryRowByteSlice selects a bytea literal and checks that it scans
// into []byte with the expected contents.
func TestConnQueryRowByteSlice(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	expected := []byte{222, 173, 190, 239}
	var actual []byte
	err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual)
	if err != nil {
		t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
	}
	if !bytes.Equal(actual, expected) {
		t.Fatalf("Expected %v, but got %v", expected, actual)
	}

	ensureConnValid(t, db)
}
// TestConnQueryFailure sends SQL with an unterminated string literal and
// expects the error to surface as a *pgconn.PgError. It then checks that the
// pool survived the failure.
func TestConnQueryFailure(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	_, queryErr := db.Query("select 'foo")
	switch queryErr.(type) {
	case *pgconn.PgError:
		// Expected: the server rejected the malformed statement.
	default:
		t.Fatalf("Expected db.Query to return pgconn.PgError, but instead received: %v", queryErr)
	}

	ensureConnValid(t, db)
}
// TestConnQueryRowPgxBinary round-trips an int4[] value through a string
// parameter and a string scan target. int4[] is a type pgx would handle
// natively in binary, but because it is not a database/sql native type it
// should be passed through as text.
func TestConnQueryRowPgxBinary(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	const (
		query    = "select $1::int4[]"
		expected = "{1,2,3}"
	)

	var actual string
	if err := db.QueryRow(query, expected).Scan(&actual); err != nil {
		t.Errorf("Unexpected failure: %v (sql -> %v)", err, query)
	}
	if actual != expected {
		t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, query)
	}

	ensureConnValid(t, db)
}
// TestConnQueryRowUnknownType round-trips a point value through a string
// parameter and a string scan target.
func TestConnQueryRowUnknownType(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	const (
		query    = "select $1::point"
		expected = "(1,2)"
	)

	var actual string
	if err := db.QueryRow(query, expected).Scan(&actual); err != nil {
		t.Errorf("Unexpected failure: %v (sql -> %v)", err, query)
	}
	if actual != expected {
		t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, query)
	}

	ensureConnValid(t, db)
}
// TestConnQueryJSONIntoByteSlice stores a json document and checks that it
// scans back into []byte byte-for-byte.
func TestConnQueryJSONIntoByteSlice(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	_, err := db.Exec(`
		create temporary table docs(
			body json not null
		);

		insert into docs(body) values('{"foo":"bar"}');
`)
	if err != nil {
		t.Fatalf("db.Exec unexpectedly failed: %v", err)
	}

	// Named query (not sql) so the database/sql package is not shadowed.
	query := `select * from docs`
	expected := []byte(`{"foo":"bar"}`)
	var actual []byte

	err = db.QueryRow(query).Scan(&actual)
	if err != nil {
		t.Errorf("Unexpected failure: %v (sql -> %v)", err, query)
	}
	if !bytes.Equal(actual, expected) {
		t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), query)
	}

	_, err = db.Exec(`drop table docs`)
	if err != nil {
		t.Fatalf("db.Exec unexpectedly failed: %v", err)
	}

	ensureConnValid(t, db)
}
// TestConnExecInsertByteSliceIntoJSON inserts a []byte parameter into a json
// column and checks that it reads back unchanged.
func TestConnExecInsertByteSliceIntoJSON(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	_, err := db.Exec(`
		create temporary table docs(
			body json not null
		);
`)
	if err != nil {
		t.Fatalf("db.Exec unexpectedly failed: %v", err)
	}

	expected := []byte(`{"foo":"bar"}`)

	_, err = db.Exec(`insert into docs(body) values($1)`, expected)
	if err != nil {
		t.Fatalf("db.Exec unexpectedly failed: %v", err)
	}

	var actual []byte
	err = db.QueryRow(`select body from docs`).Scan(&actual)
	if err != nil {
		t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
	}

	if !bytes.Equal(actual, expected) {
		t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual))
	}

	_, err = db.Exec(`drop table docs`)
	if err != nil {
		t.Fatalf("db.Exec unexpectedly failed: %v", err)
	}

	ensureConnValid(t, db)
}
// TestTransactionLifeCycle inserts a row and rolls back, expecting the table
// to stay empty. It then inserts again and commits, expecting exactly one row.
func TestTransactionLifeCycle(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	_, err := db.Exec("create temporary table t(a varchar not null)")
	if err != nil {
		t.Fatalf("db.Exec unexpectedly failed: %v", err)
	}

	// Phase 1: a rolled-back insert must leave no rows behind.
	tx, err := db.Begin()
	if err != nil {
		t.Fatalf("db.Begin unexpectedly failed: %v", err)
	}

	_, err = tx.Exec("insert into t values('hi')")
	if err != nil {
		t.Fatalf("tx.Exec unexpectedly failed: %v", err)
	}

	err = tx.Rollback()
	if err != nil {
		t.Fatalf("tx.Rollback unexpectedly failed: %v", err)
	}

	var n int64
	err = db.QueryRow("select count(*) from t").Scan(&n)
	if err != nil {
		t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err)
	}
	if n != 0 {
		t.Fatalf("Expected 0 rows due to rollback, instead found %d", n)
	}

	// Phase 2: a committed insert must be visible afterwards.
	tx, err = db.Begin()
	if err != nil {
		t.Fatalf("db.Begin unexpectedly failed: %v", err)
	}

	_, err = tx.Exec("insert into t values('hi')")
	if err != nil {
		t.Fatalf("tx.Exec unexpectedly failed: %v", err)
	}

	err = tx.Commit()
	if err != nil {
		t.Fatalf("tx.Commit unexpectedly failed: %v", err)
	}

	err = db.QueryRow("select count(*) from t").Scan(&n)
	if err != nil {
		t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err)
	}
	if n != 1 {
		t.Fatalf("Expected 1 row due to commit, instead found %d", n)
	}

	ensureConnValid(t, db)
}
// TestConnBeginTxIsolation checks that each supported database/sql isolation
// level shows up in the server's transaction_isolation setting inside the
// transaction. It also checks that BeginTx rejects LevelWriteCommitted and
// LevelLinearizable.
func TestConnBeginTxIsolation(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	// LevelDefault is expected to leave the server's configured level in place.
	var defaultIsoLevel string
	if err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel); err != nil {
		t.Fatalf("QueryRow failed: %v", err)
	}

	supportedTests := []struct {
		sqlIso sql.IsolationLevel
		pgIso  string
	}{
		{sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel},
		{sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"},
		{sqlIso: sql.LevelReadCommitted, pgIso: "read committed"},
		{sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"},
		{sqlIso: sql.LevelSerializable, pgIso: "serializable"},
	}

	for i := range supportedTests {
		want := supportedTests[i]
		// The closure scopes the deferred Rollback to a single iteration.
		func() {
			tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: want.sqlIso})
			if err != nil {
				t.Errorf("%d. BeginTx failed: %v", i, err)
				return
			}
			defer tx.Rollback()

			var got string
			if err := tx.QueryRow("show transaction_isolation").Scan(&got); err != nil {
				t.Errorf("%d. QueryRow failed: %v", i, err)
			}
			if got != want.pgIso {
				t.Errorf("%d. pgIso => %s, want %s", i, got, want.pgIso)
			}
		}()
	}

	unsupportedLevels := []sql.IsolationLevel{
		sql.LevelWriteCommitted,
		sql.LevelLinearizable,
	}
	for i, level := range unsupportedLevels {
		tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: level})
		if err != nil {
			continue
		}
		t.Errorf("%d. BeginTx should have failed", i)
		tx.Rollback()
	}

	ensureConnValid(t, db)
}
// TestConnBeginTxReadOnly opens a transaction with ReadOnly set and checks
// that the server reports transaction_read_only = on inside it.
func TestConnBeginTxReadOnly(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
	if err != nil {
		t.Fatalf("BeginTx failed: %v", err)
	}
	defer tx.Rollback()

	const want = "on"

	var pgReadOnly string
	if scanErr := tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly); scanErr != nil {
		t.Errorf("QueryRow failed: %v", scanErr)
	}
	if pgReadOnly != want {
		t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, want)
	}

	ensureConnValid(t, db)
}
// TestBeginTxContextCancel creates a table inside a transaction and cancels
// the transaction's context before committing. Commit must then fail with
// context.Canceled or sql.ErrTxDone. A follow-up query must fail with SQLSTATE
// 42P01 (undefined_table), showing the CREATE TABLE never took effect.
func TestBeginTxContextCancel(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	if _, err := db.Exec("drop table if exists t"); err != nil {
		t.Fatalf("db.Exec failed: %v", err)
	}

	ctx, cancelFn := context.WithCancel(context.Background())

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		t.Fatalf("BeginTx failed: %v", err)
	}
	if _, err := tx.Exec("create table t(id serial)"); err != nil {
		t.Fatalf("tx.Exec failed: %v", err)
	}

	cancelFn()

	switch err = tx.Commit(); err {
	case context.Canceled, sql.ErrTxDone:
		// Either outcome means the transaction did not commit.
	default:
		t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone)
	}

	var n int
	err = db.QueryRow("select count(*) from t").Scan(&n)
	pgErr, isPgErr := err.(*pgconn.PgError)
	if !isPgErr || pgErr.Code != "42P01" {
		t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err)
	}

	ensureConnValid(t, db)
}
// acceptStandardPgxConn first runs the pgmock steps for accepting an
// unauthenticated connection request against backend. It then runs the steps
// returned by pgmock.PgxInitSteps. The first error encountered is returned.
func acceptStandardPgxConn(backend *pgproto3.Backend) error {
	connectScript := pgmock.Script{Steps: pgmock.AcceptUnauthenticatedConnRequestSteps()}
	if err := connectScript.Run(backend); err != nil {
		return err
	}

	initScript := pgmock.Script{Steps: pgmock.PgxInitSteps()}
	return initScript.Run(backend)
}
// TestAcquireConn checks out five *pgx.Conn values with stdlib.AcquireConn,
// all held at once. For each, it verifies a simple query works and that the
// pool's open-connection count grows by one. All of them are then released
// with stdlib.ReleaseConn.
func TestAcquireConn(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	var conns []*pgx.Conn

	for i := 1; i <= 5; i++ {
		conn, err := stdlib.AcquireConn(db)
		if err != nil {
			t.Errorf("%d. AcquireConn failed: %v", i, err)
			continue
		}

		var n int32
		if err := conn.QueryRow(context.Background(), "select 1").Scan(&n); err != nil {
			t.Errorf("%d. QueryRow failed: %v", i, err)
		}
		if n != 1 {
			t.Errorf("%d. n => %d, want %d", i, n, 1)
		}

		// Every conn acquired so far is still held, so the count should equal i.
		if open := db.Stats().OpenConnections; open != i {
			t.Errorf("%d. stats.OpenConnections => %d, want %d", i, open, i)
		}

		conns = append(conns, conn)
	}

	for idx, held := range conns {
		if err := stdlib.ReleaseConn(db, held); err != nil {
			t.Errorf("%d. stdlib.ReleaseConn failed: %v", idx, err)
		}
	}

	ensureConnValid(t, db)
}
// TestConnPingContextSuccess checks that PingContext succeeds with a
// background context.
func TestConnPingContextSuccess(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	ctx := context.Background()
	pingErr := db.PingContext(ctx)
	if pingErr != nil {
		t.Fatalf("db.PingContext failed: %v", pingErr)
	}

	ensureConnValid(t, db)
}
// TestConnPrepareContextSuccess prepares a statement with PrepareContext and
// closes it, failing on an error from either step.
func TestConnPrepareContextSuccess(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	stmt, err := db.PrepareContext(context.Background(), "select now()")
	if err != nil {
		t.Fatalf("db.PrepareContext failed: %v", err)
	}
	// Previously the Close error was silently dropped.
	closeStmt(t, stmt)

	ensureConnValid(t, db)
}
// TestConnExecContextSuccess checks that ExecContext can run DDL on the pool.
func TestConnExecContextSuccess(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	const ddl = "create temporary table exec_context_test(id serial primary key)"
	if _, execErr := db.ExecContext(context.Background(), ddl); execErr != nil {
		t.Fatalf("db.ExecContext failed: %v", execErr)
	}

	ensureConnValid(t, db)
}
// TestConnExecContextFailureRetry closes an underlying pgx connection behind
// database/sql's back and hands it back to the pool. It then expects
// ExecContext on the connection obtained next (assumed to be that same one)
// to report driver.ErrBadConn.
func TestConnExecContextFailureRetry(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	// Get a connection, immediately close it, and then hand it back.
	{
		conn, err := stdlib.AcquireConn(db)
		if err != nil {
			t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err)
		}
		// Errors are deliberately ignored: the goal is to leave a dead
		// connection in the pool, not to verify a clean shutdown.
		conn.Close(context.Background())
		stdlib.ReleaseConn(db, conn)
	}

	conn, err := db.Conn(context.Background())
	if err != nil {
		t.Fatalf("db.Conn unexpectedly failed: %v", err)
	}
	// A *sql.Conn must be closed to give it back to the pool. Its error is
	// ignored because this connection is expected to be bad.
	defer conn.Close()

	if _, err := conn.ExecContext(context.Background(), "select 1"); err != driver.ErrBadConn {
		t.Fatalf("Expected conn.ExecContext to return driver.ErrBadConn, but instead received: %v", err)
	}
}
// TestConnQueryContextSuccess runs QueryContext and scans every row, failing
// on scan or iteration errors.
func TestConnQueryContextSuccess(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	ctx := context.Background()
	rows, queryErr := db.QueryContext(ctx, "select * from generate_series(1,10) n")
	if queryErr != nil {
		t.Fatalf("db.QueryContext failed: %v", queryErr)
	}

	for rows.Next() {
		var value int64
		if scanErr := rows.Scan(&value); scanErr != nil {
			t.Error(scanErr)
		}
	}
	if iterErr := rows.Err(); iterErr != nil {
		t.Error(iterErr)
	}

	ensureConnValid(t, db)
}
// TestConnQueryContextFailureRetry closes an underlying pgx connection behind
// database/sql's back and hands it back to the pool. It then expects
// QueryContext on the connection obtained next (assumed to be that same one)
// to report driver.ErrBadConn.
func TestConnQueryContextFailureRetry(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	// Get a connection, immediately close it, and then hand it back.
	{
		conn, err := stdlib.AcquireConn(db)
		if err != nil {
			t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err)
		}
		// Errors are deliberately ignored: the goal is to leave a dead
		// connection in the pool, not to verify a clean shutdown.
		conn.Close(context.Background())
		stdlib.ReleaseConn(db, conn)
	}

	conn, err := db.Conn(context.Background())
	if err != nil {
		t.Fatalf("db.Conn unexpectedly failed: %v", err)
	}
	// A *sql.Conn must be closed to give it back to the pool. Its error is
	// ignored because this connection is expected to be bad.
	defer conn.Close()

	if _, err := conn.QueryContext(context.Background(), "select 1"); err != driver.ErrBadConn {
		t.Fatalf("Expected conn.QueryContext to return driver.ErrBadConn, but instead received: %v", err)
	}
}
// TestRowsColumnTypeDatabaseTypeName checks that the single column produced
// by generate_series over integers reports database type name "INT4".
func TestRowsColumnTypeDatabaseTypeName(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	rows, err := db.Query("select * from generate_series(1,10) n")
	if err != nil {
		t.Fatalf("db.Query failed: %v", err)
	}

	columnTypes, err := rows.ColumnTypes()
	if err != nil {
		t.Fatalf("rows.ColumnTypes failed: %v", err)
	}
	if got := len(columnTypes); got != 1 {
		t.Fatalf("len(columnTypes) => %v, want %v", got, 1)
	}

	const wantName = "INT4"
	if gotName := columnTypes[0].DatabaseTypeName(); gotName != wantName {
		t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", gotName, wantName)
	}

	rows.Close()

	ensureConnValid(t, db)
}
// TestStmtExecContextSuccess executes a prepared INSERT via ExecContext with
// a background context.
func TestStmtExecContextSuccess(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	if _, err := db.Exec("create temporary table t(id int primary key)"); err != nil {
		t.Fatalf("db.Exec failed: %v", err)
	}

	stmt, prepErr := db.Prepare("insert into t(id) values ($1::int4)")
	if prepErr != nil {
		t.Fatal(prepErr)
	}
	defer stmt.Close()

	if _, execErr := stmt.ExecContext(context.Background(), 42); execErr != nil {
		t.Fatal(execErr)
	}

	ensureConnValid(t, db)
}
// TestStmtExecContextCancel executes a prepared INSERT whose SELECT sleeps
// for 5 seconds under a 100ms deadline. ExecContext must return
// context.DeadlineExceeded, and the pool must remain usable afterwards.
func TestStmtExecContextCancel(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	if _, err := db.Exec("create temporary table t(id int primary key)"); err != nil {
		t.Fatalf("db.Exec failed: %v", err)
	}

	stmt, prepErr := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)")
	if prepErr != nil {
		t.Fatal(prepErr)
	}
	defer stmt.Close()

	const deadline = 100 * time.Millisecond
	ctx, cancel := context.WithTimeout(context.Background(), deadline)
	defer cancel()

	if _, execErr := stmt.ExecContext(ctx, 42); execErr != context.DeadlineExceeded {
		t.Errorf("err => %v, want %v", execErr, context.DeadlineExceeded)
	}

	ensureConnValid(t, db)
}
// TestStmtQueryContextSuccess runs a prepared query via QueryContext and
// scans every row, failing on scan or iteration errors.
func TestStmtQueryContextSuccess(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	stmt, prepErr := db.Prepare("select * from generate_series(1,$1::int4) n")
	if prepErr != nil {
		t.Fatal(prepErr)
	}
	defer stmt.Close()

	rows, queryErr := stmt.QueryContext(context.Background(), 5)
	if queryErr != nil {
		t.Fatalf("stmt.QueryContext failed: %v", queryErr)
	}

	for rows.Next() {
		var value int64
		if scanErr := rows.Scan(&value); scanErr != nil {
			t.Error(scanErr)
		}
	}
	if iterErr := rows.Err(); iterErr != nil {
		t.Error(iterErr)
	}

	ensureConnValid(t, db)
}
// TestRowsColumnTypes checks the metadata database/sql exposes through
// *sql.ColumnType for an int, a text, and a numeric(9,2) column. The fields
// checked are name, database type name, Length, DecimalSize, and ScanType.
func TestRowsColumnTypes(t *testing.T) {
	type lengthInfo struct {
		Len int64
		OK  bool
	}
	type decimalInfo struct {
		Precision int64
		Scale     int64
		OK        bool
	}

	// Omitted Length/DecimalSize fields mean "zero values, OK == false".
	columnTypesTests := []struct {
		Name        string
		TypeName    string
		Length      lengthInfo
		DecimalSize decimalInfo
		ScanType    reflect.Type
	}{
		{
			Name:     "a",
			TypeName: "INT4",
			ScanType: reflect.TypeOf(int32(0)),
		},
		{
			Name:     "bar",
			TypeName: "TEXT",
			Length:   lengthInfo{Len: math.MaxInt64, OK: true},
			ScanType: reflect.TypeOf(""),
		},
		{
			Name:        "dec",
			TypeName:    "NUMERIC",
			DecimalSize: decimalInfo{Precision: 9, Scale: 2, OK: true},
			ScanType:    reflect.TypeOf(float64(0)),
		},
	}

	db := openDB(t)
	defer closeDB(t, db)

	rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec")
	if err != nil {
		t.Fatal(err)
	}
	// The rows were previously never closed, leaking the connection.
	defer rows.Close()

	columns, err := rows.ColumnTypes()
	if err != nil {
		t.Fatal(err)
	}
	// Fatal rather than Error: the loop below indexes columns[i] and would
	// panic on a short slice.
	if len(columns) != len(columnTypesTests) {
		t.Fatalf("expected %d columns found %d", len(columnTypesTests), len(columns))
	}

	for i, tt := range columnTypesTests {
		c := columns[i]
		if c.Name() != tt.Name {
			t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name)
		}
		if c.DatabaseTypeName() != tt.TypeName {
			t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName)
		}

		l, ok := c.Length()
		if l != tt.Length.Len {
			t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len)
		}
		if ok != tt.Length.OK {
			t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK)
		}

		p, s, ok := c.DecimalSize()
		if p != tt.DecimalSize.Precision {
			t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision)
		}
		if s != tt.DecimalSize.Scale {
			t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale)
		}
		if ok != tt.DecimalSize.OK {
			t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK)
		}

		if c.ScanType() != tt.ScanType {
			t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType)
		}
	}
}
// TestScanJSONIntoJSONRawMessage checks that a json value scans into a
// json.RawMessage with its text intact.
// https://github.com/jackc/pgx/issues/409
func TestScanJSONIntoJSONRawMessage(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	var msg json.RawMessage

	err := db.QueryRow("select '{}'::json").Scan(&msg)
	if err != nil {
		t.Fatalf("QueryRow / Scan failed: %v", err)
	}

	if !bytes.Equal([]byte("{}"), []byte(msg)) {
		t.Fatalf("Expected %v, got %v", []byte("{}"), msg)
	}

	ensureConnValid(t, db)
}