mirror of
https://github.com/jackc/pgx.git
synced 2025-05-31 11:42:24 +00:00
It is impossible to guarantee that a query executed with the simple protocol will behave the same as with the extended protocol. This is because the normal pgx path relies on knowing the OID of query parameters. Without this, the encoding of a value can only be determined by the value itself instead of the combination of value and PostgreSQL type. For example, how should a []int32 be encoded? It might be encoded into a PostgreSQL int4[] or json. Removing the simple protocol also simplifies the core query path. The primary reason for the simple protocol is for servers like PgBouncer that may not be able to support normal prepared statements. After further research it appears that issuing a "flush" instead of a "sync" after preparing the unnamed statement would allow PgBouncer to work. The one round trip mode can be better handled with prepared statements. As a last resort, all original server functionality can still be accessed by dropping down to PgConn.
1075 lines
23 KiB
Go
1075 lines
23 KiB
Go
package stdlib_test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"encoding/json"
|
|
"math"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jackc/pgconn"
|
|
"github.com/jackc/pgproto3"
|
|
"github.com/jackc/pgx"
|
|
"github.com/jackc/pgx/pgmock"
|
|
"github.com/jackc/pgx/stdlib"
|
|
)
|
|
|
|
// closeDB closes db and aborts the calling test if the close reports an error.
func closeDB(t *testing.T, db *sql.DB) {
	if closeErr := db.Close(); closeErr != nil {
		t.Fatalf("db.Close unexpectedly failed: %v", closeErr)
	}
}
|
|
|
|
// ensureConnValid runs a small generate_series query through db to confirm
// that the connection is still usable. The calling test fails if the query, a
// row scan, or row iteration reports an error, or if the result is not the ten
// rows 1 through 10 (which sum to 55).
func ensureConnValid(t *testing.T, db *sql.DB) {
	var sum, rowCount int32

	rows, err := db.Query("select generate_series(1,$1)", 10)
	if err != nil {
		t.Fatalf("db.Query failed: %v", err)
	}
	defer rows.Close()

	for rows.Next() {
		var n int32
		// Surface scan failures directly; otherwise they would only show up
		// later as a confusing "Wrong values returned" error.
		if err := rows.Scan(&n); err != nil {
			t.Fatalf("rows.Scan failed: %v", err)
		}
		sum += n
		rowCount++
	}

	// Report the iteration error itself: the err returned by db.Query is
	// always nil by the time execution reaches this point.
	if err := rows.Err(); err != nil {
		t.Fatalf("db.Query failed: %v", err)
	}

	if rowCount != 10 {
		t.Error("Select called onDataRow wrong number of times")
	}
	if sum != 55 {
		t.Error("Wrong values returned")
	}
}
|
|
|
|
// preparer is the minimal statement-preparing capability required by
// prepareStmt; any value with a matching Prepare method can be passed to it.
type preparer interface {
	// Prepare creates a prepared statement for query.
	Prepare(query string) (*sql.Stmt, error)
}
|
|
|
|
func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt {
|
|
stmt, err := p.Prepare(sql)
|
|
if err != nil {
|
|
t.Fatalf("%v Prepare unexpectedly failed: %v", p, err)
|
|
}
|
|
|
|
return stmt
|
|
}
|
|
|
|
// closeStmt closes stmt and aborts the calling test if the close reports an
// error.
func closeStmt(t *testing.T, stmt *sql.Stmt) {
	if closeErr := stmt.Close(); closeErr != nil {
		t.Fatalf("stmt.Close unexpectedly failed: %v", closeErr)
	}
}
|
|
|
|
func TestSQLOpen(t *testing.T) {
|
|
db, err := sql.Open("pgx", "postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")
|
|
if err != nil {
|
|
t.Fatalf("sql.Open failed: %v", err)
|
|
}
|
|
closeDB(t, db)
|
|
}
|
|
|
|
func TestNormalLifeCycle(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
|
|
defer closeStmt(t, stmt)
|
|
|
|
rows, err := stmt.Query(int32(1), int32(10))
|
|
if err != nil {
|
|
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
|
}
|
|
|
|
rowCount := int64(0)
|
|
|
|
for rows.Next() {
|
|
rowCount++
|
|
|
|
var s string
|
|
var n int64
|
|
if err := rows.Scan(&s, &n); err != nil {
|
|
t.Fatalf("rows.Scan unexpectedly failed: %v", err)
|
|
}
|
|
if s != "foo" {
|
|
t.Errorf(`Expected "foo", received "%v"`, s)
|
|
}
|
|
if n != rowCount {
|
|
t.Errorf("Expected %d, received %d", rowCount, n)
|
|
}
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
t.Fatalf("rows.Err unexpectedly is: %v", err)
|
|
}
|
|
if rowCount != 10 {
|
|
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
|
}
|
|
|
|
err = rows.Close()
|
|
if err != nil {
|
|
t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestStmtExec(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
t.Fatalf("db.Begin unexpectedly failed: %v", err)
|
|
}
|
|
|
|
createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)")
|
|
_, err = createStmt.Exec()
|
|
if err != nil {
|
|
t.Fatalf("stmt.Exec unexpectedly failed: %v", err)
|
|
}
|
|
closeStmt(t, createStmt)
|
|
|
|
insertStmt := prepareStmt(t, tx, "insert into t values($1::text)")
|
|
result, err := insertStmt.Exec("foo")
|
|
if err != nil {
|
|
t.Fatalf("stmt.Exec unexpectedly failed: %v", err)
|
|
}
|
|
|
|
n, err := result.RowsAffected()
|
|
if err != nil {
|
|
t.Fatalf("result.RowsAffected unexpectedly failed: %v", err)
|
|
}
|
|
if n != 1 {
|
|
t.Fatalf("Expected 1, received %d", n)
|
|
}
|
|
closeStmt(t, insertStmt)
|
|
|
|
if err != nil {
|
|
t.Fatalf("tx.Commit unexpectedly failed: %v", err)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestQueryCloseRowsEarly(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
|
|
defer closeStmt(t, stmt)
|
|
|
|
rows, err := stmt.Query(int32(1), int32(10))
|
|
if err != nil {
|
|
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
|
}
|
|
|
|
// Close rows immediately without having read them
|
|
err = rows.Close()
|
|
if err != nil {
|
|
t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
|
}
|
|
|
|
// Run the query again to ensure the connection and statement are still ok
|
|
rows, err = stmt.Query(int32(1), int32(10))
|
|
if err != nil {
|
|
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
|
}
|
|
|
|
rowCount := int64(0)
|
|
|
|
for rows.Next() {
|
|
rowCount++
|
|
|
|
var s string
|
|
var n int64
|
|
if err := rows.Scan(&s, &n); err != nil {
|
|
t.Fatalf("rows.Scan unexpectedly failed: %v", err)
|
|
}
|
|
if s != "foo" {
|
|
t.Errorf(`Expected "foo", received "%v"`, s)
|
|
}
|
|
if n != rowCount {
|
|
t.Errorf("Expected %d, received %d", rowCount, n)
|
|
}
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
t.Fatalf("rows.Err unexpectedly is: %v", err)
|
|
}
|
|
if rowCount != 10 {
|
|
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
|
}
|
|
|
|
err = rows.Close()
|
|
if err != nil {
|
|
t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestConnExec(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
_, err := db.Exec("create temporary table t(a varchar not null)")
|
|
if err != nil {
|
|
t.Fatalf("db.Exec unexpectedly failed: %v", err)
|
|
}
|
|
|
|
result, err := db.Exec("insert into t values('hey')")
|
|
if err != nil {
|
|
t.Fatalf("db.Exec unexpectedly failed: %v", err)
|
|
}
|
|
|
|
n, err := result.RowsAffected()
|
|
if err != nil {
|
|
t.Fatalf("result.RowsAffected unexpectedly failed: %v", err)
|
|
}
|
|
if n != 1 {
|
|
t.Fatalf("Expected 1, received %d", n)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestConnQuery(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10))
|
|
if err != nil {
|
|
t.Fatalf("db.Query unexpectedly failed: %v", err)
|
|
}
|
|
|
|
rowCount := int64(0)
|
|
|
|
for rows.Next() {
|
|
rowCount++
|
|
|
|
var s string
|
|
var n int64
|
|
if err := rows.Scan(&s, &n); err != nil {
|
|
t.Fatalf("rows.Scan unexpectedly failed: %v", err)
|
|
}
|
|
if s != "foo" {
|
|
t.Errorf(`Expected "foo", received "%v"`, s)
|
|
}
|
|
if n != rowCount {
|
|
t.Errorf("Expected %d, received %d", rowCount, n)
|
|
}
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
t.Fatalf("rows.Err unexpectedly is: %v", err)
|
|
}
|
|
if rowCount != 10 {
|
|
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
|
}
|
|
|
|
err = rows.Close()
|
|
if err != nil {
|
|
t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
type testLog struct {
|
|
lvl pgx.LogLevel
|
|
msg string
|
|
data map[string]interface{}
|
|
}
|
|
|
|
type testLogger struct {
|
|
logs []testLog
|
|
}
|
|
|
|
func (l *testLogger) Log(lvl pgx.LogLevel, msg string, data map[string]interface{}) {
|
|
l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
|
|
}
|
|
|
|
func TestConnQueryNull(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
rows, err := db.Query("select $1::int", nil)
|
|
if err != nil {
|
|
t.Fatalf("db.Query unexpectedly failed: %v", err)
|
|
}
|
|
|
|
rowCount := int64(0)
|
|
|
|
for rows.Next() {
|
|
rowCount++
|
|
|
|
var n sql.NullInt64
|
|
if err := rows.Scan(&n); err != nil {
|
|
t.Fatalf("rows.Scan unexpectedly failed: %v", err)
|
|
}
|
|
if n.Valid != false {
|
|
t.Errorf("Expected n to be null, but it was %v", n)
|
|
}
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
t.Fatalf("rows.Err unexpectedly is: %v", err)
|
|
}
|
|
if rowCount != 1 {
|
|
t.Fatalf("Expected to receive 11 rows, instead received %d", rowCount)
|
|
}
|
|
|
|
err = rows.Close()
|
|
if err != nil {
|
|
t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestConnQueryRowByteSlice(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
expected := []byte{222, 173, 190, 239}
|
|
var actual []byte
|
|
|
|
err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual)
|
|
if err != nil {
|
|
t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
|
|
}
|
|
|
|
if bytes.Compare(actual, expected) != 0 {
|
|
t.Fatalf("Expected %v, but got %v", expected, actual)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestConnQueryFailure(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
_, err := db.Query("select 'foo")
|
|
if _, ok := err.(*pgconn.PgError); !ok {
|
|
t.Fatalf("Expected db.Query to return pgconn.PgError, but instead received: %v", err)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
// Test type that pgx would handle natively in binary, but since it is not a
|
|
// database/sql native type should be passed through as a string
|
|
func TestConnQueryRowPgxBinary(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
sql := "select $1::int4[]"
|
|
expected := "{1,2,3}"
|
|
var actual string
|
|
|
|
err := db.QueryRow(sql, expected).Scan(&actual)
|
|
if err != nil {
|
|
t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
|
|
}
|
|
|
|
if actual != expected {
|
|
t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestConnQueryRowUnknownType(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
sql := "select $1::point"
|
|
expected := "(1,2)"
|
|
var actual string
|
|
|
|
err := db.QueryRow(sql, expected).Scan(&actual)
|
|
if err != nil {
|
|
t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
|
|
}
|
|
|
|
if actual != expected {
|
|
t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestConnQueryJSONIntoByteSlice(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
_, err := db.Exec(`
|
|
create temporary table docs(
|
|
body json not null
|
|
);
|
|
|
|
insert into docs(body) values('{"foo":"bar"}');
|
|
`)
|
|
if err != nil {
|
|
t.Fatalf("db.Exec unexpectedly failed: %v", err)
|
|
}
|
|
|
|
sql := `select * from docs`
|
|
expected := []byte(`{"foo":"bar"}`)
|
|
var actual []byte
|
|
|
|
err = db.QueryRow(sql).Scan(&actual)
|
|
if err != nil {
|
|
t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
|
|
}
|
|
|
|
if bytes.Compare(actual, expected) != 0 {
|
|
t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), sql)
|
|
}
|
|
|
|
_, err = db.Exec(`drop table docs`)
|
|
if err != nil {
|
|
t.Fatalf("db.Exec unexpectedly failed: %v", err)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestConnExecInsertByteSliceIntoJSON(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
_, err := db.Exec(`
|
|
create temporary table docs(
|
|
body json not null
|
|
);
|
|
`)
|
|
if err != nil {
|
|
t.Fatalf("db.Exec unexpectedly failed: %v", err)
|
|
}
|
|
|
|
expected := []byte(`{"foo":"bar"}`)
|
|
|
|
_, err = db.Exec(`insert into docs(body) values($1)`, expected)
|
|
if err != nil {
|
|
t.Fatalf("db.Exec unexpectedly failed: %v", err)
|
|
}
|
|
|
|
var actual []byte
|
|
err = db.QueryRow(`select body from docs`).Scan(&actual)
|
|
if err != nil {
|
|
t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
|
|
}
|
|
|
|
if bytes.Compare(actual, expected) != 0 {
|
|
t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual))
|
|
}
|
|
|
|
_, err = db.Exec(`drop table docs`)
|
|
if err != nil {
|
|
t.Fatalf("db.Exec unexpectedly failed: %v", err)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestTransactionLifeCycle(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
_, err := db.Exec("create temporary table t(a varchar not null)")
|
|
if err != nil {
|
|
t.Fatalf("db.Exec unexpectedly failed: %v", err)
|
|
}
|
|
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
t.Fatalf("db.Begin unexpectedly failed: %v", err)
|
|
}
|
|
|
|
_, err = tx.Exec("insert into t values('hi')")
|
|
if err != nil {
|
|
t.Fatalf("tx.Exec unexpectedly failed: %v", err)
|
|
}
|
|
|
|
err = tx.Rollback()
|
|
if err != nil {
|
|
t.Fatalf("tx.Rollback unexpectedly failed: %v", err)
|
|
}
|
|
|
|
var n int64
|
|
err = db.QueryRow("select count(*) from t").Scan(&n)
|
|
if err != nil {
|
|
t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err)
|
|
}
|
|
if n != 0 {
|
|
t.Fatalf("Expected 0 rows due to rollback, instead found %d", n)
|
|
}
|
|
|
|
tx, err = db.Begin()
|
|
if err != nil {
|
|
t.Fatalf("db.Begin unexpectedly failed: %v", err)
|
|
}
|
|
|
|
_, err = tx.Exec("insert into t values('hi')")
|
|
if err != nil {
|
|
t.Fatalf("tx.Exec unexpectedly failed: %v", err)
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
t.Fatalf("tx.Commit unexpectedly failed: %v", err)
|
|
}
|
|
|
|
err = db.QueryRow("select count(*) from t").Scan(&n)
|
|
if err != nil {
|
|
t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err)
|
|
}
|
|
if n != 1 {
|
|
t.Fatalf("Expected 1 rows due to rollback, instead found %d", n)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestConnBeginTxIsolation(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
var defaultIsoLevel string
|
|
err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel)
|
|
if err != nil {
|
|
t.Fatalf("QueryRow failed: %v", err)
|
|
}
|
|
|
|
supportedTests := []struct {
|
|
sqlIso sql.IsolationLevel
|
|
pgIso string
|
|
}{
|
|
{sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel},
|
|
{sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"},
|
|
{sqlIso: sql.LevelReadCommitted, pgIso: "read committed"},
|
|
{sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"},
|
|
{sqlIso: sql.LevelSerializable, pgIso: "serializable"},
|
|
}
|
|
for i, tt := range supportedTests {
|
|
func() {
|
|
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
|
|
if err != nil {
|
|
t.Errorf("%d. BeginTx failed: %v", i, err)
|
|
return
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
var pgIso string
|
|
err = tx.QueryRow("show transaction_isolation").Scan(&pgIso)
|
|
if err != nil {
|
|
t.Errorf("%d. QueryRow failed: %v", i, err)
|
|
}
|
|
|
|
if pgIso != tt.pgIso {
|
|
t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso)
|
|
}
|
|
}()
|
|
}
|
|
|
|
unsupportedTests := []struct {
|
|
sqlIso sql.IsolationLevel
|
|
}{
|
|
{sqlIso: sql.LevelWriteCommitted},
|
|
{sqlIso: sql.LevelLinearizable},
|
|
}
|
|
for i, tt := range unsupportedTests {
|
|
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
|
|
if err == nil {
|
|
t.Errorf("%d. BeginTx should have failed", i)
|
|
tx.Rollback()
|
|
}
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestConnBeginTxReadOnly(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
|
|
if err != nil {
|
|
t.Fatalf("BeginTx failed: %v", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
var pgReadOnly string
|
|
err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly)
|
|
if err != nil {
|
|
t.Errorf("QueryRow failed: %v", err)
|
|
}
|
|
|
|
if pgReadOnly != "on" {
|
|
t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on")
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestBeginTxContextCancel(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
_, err := db.Exec("drop table if exists t")
|
|
if err != nil {
|
|
t.Fatalf("db.Exec failed: %v", err)
|
|
}
|
|
|
|
ctx, cancelFn := context.WithCancel(context.Background())
|
|
|
|
tx, err := db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
t.Fatalf("BeginTx failed: %v", err)
|
|
}
|
|
|
|
_, err = tx.Exec("create table t(id serial)")
|
|
if err != nil {
|
|
t.Fatalf("tx.Exec failed: %v", err)
|
|
}
|
|
|
|
cancelFn()
|
|
|
|
err = tx.Commit()
|
|
if err != context.Canceled && err != sql.ErrTxDone {
|
|
t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone)
|
|
}
|
|
|
|
var n int
|
|
err = db.QueryRow("select count(*) from t").Scan(&n)
|
|
if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "42P01" {
|
|
t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func acceptStandardPgxConn(backend *pgproto3.Backend) error {
|
|
script := pgmock.Script{
|
|
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
|
}
|
|
|
|
err := script.Run(backend)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
typeScript := pgmock.Script{
|
|
Steps: pgmock.PgxInitSteps(),
|
|
}
|
|
|
|
return typeScript.Run(backend)
|
|
}
|
|
|
|
func TestAcquireConn(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
var conns []*pgx.Conn
|
|
|
|
for i := 1; i < 6; i++ {
|
|
conn, err := stdlib.AcquireConn(db)
|
|
if err != nil {
|
|
t.Errorf("%d. AcquireConn failed: %v", i, err)
|
|
continue
|
|
}
|
|
|
|
var n int32
|
|
err = conn.QueryRow(context.Background(), "select 1").Scan(&n)
|
|
if err != nil {
|
|
t.Errorf("%d. QueryRow failed: %v", i, err)
|
|
}
|
|
if n != 1 {
|
|
t.Errorf("%d. n => %d, want %d", i, n, 1)
|
|
}
|
|
|
|
stats := db.Stats()
|
|
if stats.OpenConnections != i {
|
|
t.Errorf("%d. stats.OpenConnections => %d, want %d", i, stats.OpenConnections, i)
|
|
}
|
|
|
|
conns = append(conns, conn)
|
|
}
|
|
|
|
for i, conn := range conns {
|
|
if err := stdlib.ReleaseConn(db, conn); err != nil {
|
|
t.Errorf("%d. stdlib.ReleaseConn failed: %v", i, err)
|
|
}
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestConnPingContextSuccess(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
if err := db.PingContext(context.Background()); err != nil {
|
|
t.Fatalf("db.PingContext failed: %v", err)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestConnPrepareContextSuccess(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
stmt, err := db.PrepareContext(context.Background(), "select now()")
|
|
if err != nil {
|
|
t.Fatalf("db.PrepareContext failed: %v", err)
|
|
}
|
|
stmt.Close()
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestConnExecContextSuccess(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
_, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)")
|
|
if err != nil {
|
|
t.Fatalf("db.ExecContext failed: %v", err)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestConnExecContextFailureRetry(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
// we get a connection, immediately close it, and then get it back
|
|
{
|
|
conn, err := stdlib.AcquireConn(db)
|
|
if err != nil {
|
|
t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err)
|
|
}
|
|
conn.Close(context.Background())
|
|
stdlib.ReleaseConn(db, conn)
|
|
}
|
|
conn, err := db.Conn(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("db.Conn unexpectedly failed: %v", err)
|
|
}
|
|
if _, err := conn.ExecContext(context.Background(), "select 1"); err != driver.ErrBadConn {
|
|
t.Fatalf("Expected conn.ExecContext to return driver.ErrBadConn, but instead received: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestConnQueryContextSuccess(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n")
|
|
if err != nil {
|
|
t.Fatalf("db.QueryContext failed: %v", err)
|
|
}
|
|
|
|
for rows.Next() {
|
|
var n int64
|
|
if err := rows.Scan(&n); err != nil {
|
|
t.Error(err)
|
|
}
|
|
}
|
|
|
|
if rows.Err() != nil {
|
|
t.Error(rows.Err())
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestConnQueryContextFailureRetry(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
// we get a connection, immediately close it, and then get it back
|
|
{
|
|
conn, err := stdlib.AcquireConn(db)
|
|
if err != nil {
|
|
t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err)
|
|
}
|
|
conn.Close(context.Background())
|
|
stdlib.ReleaseConn(db, conn)
|
|
}
|
|
conn, err := db.Conn(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("db.Conn unexpectedly failed: %v", err)
|
|
}
|
|
if _, err := conn.QueryContext(context.Background(), "select 1"); err != driver.ErrBadConn {
|
|
t.Fatalf("Expected conn.QueryContext to return driver.ErrBadConn, but instead received: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRowsColumnTypeDatabaseTypeName(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
rows, err := db.Query("select * from generate_series(1,10) n")
|
|
if err != nil {
|
|
t.Fatalf("db.Query failed: %v", err)
|
|
}
|
|
|
|
columnTypes, err := rows.ColumnTypes()
|
|
if err != nil {
|
|
t.Fatalf("rows.ColumnTypes failed: %v", err)
|
|
}
|
|
|
|
if len(columnTypes) != 1 {
|
|
t.Fatalf("len(columnTypes) => %v, want %v", len(columnTypes), 1)
|
|
}
|
|
|
|
if columnTypes[0].DatabaseTypeName() != "INT4" {
|
|
t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT4")
|
|
}
|
|
|
|
rows.Close()
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestStmtExecContextSuccess(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
_, err := db.Exec("create temporary table t(id int primary key)")
|
|
if err != nil {
|
|
t.Fatalf("db.Exec failed: %v", err)
|
|
}
|
|
|
|
stmt, err := db.Prepare("insert into t(id) values ($1::int4)")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer stmt.Close()
|
|
|
|
_, err = stmt.ExecContext(context.Background(), 42)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestStmtExecContextCancel(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
_, err := db.Exec("create temporary table t(id int primary key)")
|
|
if err != nil {
|
|
t.Fatalf("db.Exec failed: %v", err)
|
|
}
|
|
|
|
stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer stmt.Close()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
_, err = stmt.ExecContext(ctx, 42)
|
|
if err != context.DeadlineExceeded {
|
|
t.Errorf("err => %v, want %v", err, context.DeadlineExceeded)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestStmtQueryContextSuccess(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer stmt.Close()
|
|
|
|
rows, err := stmt.QueryContext(context.Background(), 5)
|
|
if err != nil {
|
|
t.Fatalf("stmt.QueryContext failed: %v", err)
|
|
}
|
|
|
|
for rows.Next() {
|
|
var n int64
|
|
if err := rows.Scan(&n); err != nil {
|
|
t.Error(err)
|
|
}
|
|
}
|
|
|
|
if rows.Err() != nil {
|
|
t.Error(rows.Err())
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|
|
|
|
func TestRowsColumnTypes(t *testing.T) {
|
|
columnTypesTests := []struct {
|
|
Name string
|
|
TypeName string
|
|
Length struct {
|
|
Len int64
|
|
OK bool
|
|
}
|
|
DecimalSize struct {
|
|
Precision int64
|
|
Scale int64
|
|
OK bool
|
|
}
|
|
ScanType reflect.Type
|
|
}{
|
|
{
|
|
Name: "a",
|
|
TypeName: "INT4",
|
|
Length: struct {
|
|
Len int64
|
|
OK bool
|
|
}{
|
|
Len: 0,
|
|
OK: false,
|
|
},
|
|
DecimalSize: struct {
|
|
Precision int64
|
|
Scale int64
|
|
OK bool
|
|
}{
|
|
Precision: 0,
|
|
Scale: 0,
|
|
OK: false,
|
|
},
|
|
ScanType: reflect.TypeOf(int32(0)),
|
|
}, {
|
|
Name: "bar",
|
|
TypeName: "TEXT",
|
|
Length: struct {
|
|
Len int64
|
|
OK bool
|
|
}{
|
|
Len: math.MaxInt64,
|
|
OK: true,
|
|
},
|
|
DecimalSize: struct {
|
|
Precision int64
|
|
Scale int64
|
|
OK bool
|
|
}{
|
|
Precision: 0,
|
|
Scale: 0,
|
|
OK: false,
|
|
},
|
|
ScanType: reflect.TypeOf(""),
|
|
}, {
|
|
Name: "dec",
|
|
TypeName: "NUMERIC",
|
|
Length: struct {
|
|
Len int64
|
|
OK bool
|
|
}{
|
|
Len: 0,
|
|
OK: false,
|
|
},
|
|
DecimalSize: struct {
|
|
Precision int64
|
|
Scale int64
|
|
OK bool
|
|
}{
|
|
Precision: 9,
|
|
Scale: 2,
|
|
OK: true,
|
|
},
|
|
ScanType: reflect.TypeOf(float64(0)),
|
|
},
|
|
}
|
|
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
columns, err := rows.ColumnTypes()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(columns) != 3 {
|
|
t.Errorf("expected 3 columns found %d", len(columns))
|
|
}
|
|
|
|
for i, tt := range columnTypesTests {
|
|
c := columns[i]
|
|
if c.Name() != tt.Name {
|
|
t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name)
|
|
}
|
|
if c.DatabaseTypeName() != tt.TypeName {
|
|
t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName)
|
|
}
|
|
l, ok := c.Length()
|
|
if l != tt.Length.Len {
|
|
t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len)
|
|
}
|
|
if ok != tt.Length.OK {
|
|
t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK)
|
|
}
|
|
p, s, ok := c.DecimalSize()
|
|
if p != tt.DecimalSize.Precision {
|
|
t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision)
|
|
}
|
|
if s != tt.DecimalSize.Scale {
|
|
t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale)
|
|
}
|
|
if ok != tt.DecimalSize.OK {
|
|
t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK)
|
|
}
|
|
if c.ScanType() != tt.ScanType {
|
|
t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType)
|
|
}
|
|
}
|
|
}
|
|
|
|
// https://github.com/jackc/pgx/issues/409
|
|
func TestScanJSONIntoJSONRawMessage(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
var msg json.RawMessage
|
|
|
|
err := db.QueryRow("select '{}'::json").Scan(&msg)
|
|
if err != nil {
|
|
t.Fatalf("QueryRow / Scan failed: %v", err)
|
|
}
|
|
|
|
if bytes.Compare([]byte("{}"), []byte(msg)) != 0 {
|
|
t.Fatalf("Expected %v, got %v", []byte("{}"), msg)
|
|
}
|
|
|
|
ensureConnValid(t, db)
|
|
}
|