mirror of
https://github.com/jackc/pgx.git
synced 2025-05-17 13:01:05 +00:00
PGX_TEST_DATABASE is much less likely to collide with another environment variable. This is especially valuable when using direnv to automatically set environment variables.
299 lines
7.8 KiB
Go
299 lines
7.8 KiB
Go
package testutil
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"github.com/jackc/pgx"
|
|
"github.com/jackc/pgx/pgtype"
|
|
_ "github.com/jackc/pgx/stdlib"
|
|
_ "github.com/lib/pq"
|
|
)
|
|
|
|
// MustConnectDatabaseSQL opens a database/sql handle for the driver identified
// by its import path. Supported values of driverName are "github.com/lib/pq"
// and "github.com/jackc/pgx/stdlib"; any other value is reported through
// t.Fatalf. The data source name is read from the PGX_TEST_DATABASE
// environment variable, and an error from sql.Open is reported through
// t.Fatal.
func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB {
	// Maps a driver's import path to the name it is opened under with
	// database/sql.
	registeredNames := map[string]string{
		"github.com/lib/pq":           "postgres",
		"github.com/jackc/pgx/stdlib": "pgx",
	}

	sqlName, known := registeredNames[driverName]
	if !known {
		t.Fatalf("Unknown driver %v", driverName)
	}

	handle, openErr := sql.Open(sqlName, os.Getenv("PGX_TEST_DATABASE"))
	if openErr != nil {
		t.Fatal(openErr)
	}
	return handle
}
|
|
|
|
func MustConnectPgx(t testing.TB) *pgx.Conn {
|
|
config, err := pgx.ParseURI(os.Getenv("PGX_TEST_DATABASE"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
conn, err := pgx.Connect(config)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
return conn
|
|
}
|
|
|
|
// MustClose calls conn.Close and reports a non-nil error through t.Fatal.
// conn may be any value with a Close() error method.
func MustClose(t testing.TB, conn interface {
	Close() error
}) {
	if closeErr := conn.Close(); closeErr != nil {
		t.Fatal(closeErr)
	}
}
|
|
|
|
type forceTextEncoder struct {
|
|
e pgtype.TextEncoder
|
|
}
|
|
|
|
func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) {
|
|
return f.e.EncodeText(ci, w)
|
|
}
|
|
|
|
type forceBinaryEncoder struct {
|
|
e pgtype.BinaryEncoder
|
|
}
|
|
|
|
func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) {
|
|
return f.e.EncodeBinary(ci, w)
|
|
}
|
|
|
|
func ForceEncoder(e interface{}, formatCode int16) interface{} {
|
|
switch formatCode {
|
|
case pgx.TextFormatCode:
|
|
if e, ok := e.(pgtype.TextEncoder); ok {
|
|
return forceTextEncoder{e: e}
|
|
}
|
|
case pgx.BinaryFormatCode:
|
|
if e, ok := e.(pgtype.BinaryEncoder); ok {
|
|
return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) {
|
|
TestSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool {
|
|
return reflect.DeepEqual(a, b)
|
|
})
|
|
}
|
|
|
|
func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
|
TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
|
|
TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
|
|
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
|
TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
|
|
}
|
|
}
|
|
|
|
func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
|
conn := MustConnectPgx(t)
|
|
defer MustClose(t, conn)
|
|
|
|
ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
formats := []struct {
|
|
name string
|
|
formatCode int16
|
|
}{
|
|
{name: "TextFormat", formatCode: pgx.TextFormatCode},
|
|
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
|
|
}
|
|
|
|
for i, v := range values {
|
|
for _, fc := range formats {
|
|
ps.FieldDescriptions[0].FormatCode = fc.formatCode
|
|
vEncoder := ForceEncoder(v, fc.formatCode)
|
|
if vEncoder == nil {
|
|
t.Logf("Skipping: %#v does not implement %v", v, fc.name)
|
|
continue
|
|
}
|
|
// Derefence value if it is a pointer
|
|
derefV := v
|
|
refVal := reflect.ValueOf(v)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
err := conn.QueryRow("test", ForceEncoder(v, fc.formatCode)).Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("%v %d: %v", fc.name, i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
|
conn := MustConnectPgx(t)
|
|
defer MustClose(t, conn)
|
|
|
|
for i, v := range values {
|
|
// Derefence value if it is a pointer
|
|
derefV := v
|
|
refVal := reflect.ValueOf(v)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
err := conn.QueryRowEx(
|
|
context.Background(),
|
|
fmt.Sprintf("select ($1)::%s", pgTypeName),
|
|
&pgx.QueryExOptions{SimpleProtocol: true},
|
|
v,
|
|
).Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("Simple protocol %d: %v", i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
|
conn := MustConnectDatabaseSQL(t, driverName)
|
|
defer MustClose(t, conn)
|
|
|
|
ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
for i, v := range values {
|
|
// Derefence value if it is a pointer
|
|
derefV := v
|
|
refVal := reflect.ValueOf(v)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
err := ps.QueryRow(v).Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("%v %d: %v", driverName, i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|
|
|
|
// NormalizeTest pairs a SQL query with the value its single result column is
// expected to scan into.
type NormalizeTest struct {
	// SQL is the query text; the normalize helpers in this file prepare it
	// and execute it without arguments.
	SQL string

	// Value is the expected result. If it is a pointer, the helpers compare
	// against the pointed-to value.
	Value interface{}
}
|
|
|
|
func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) {
|
|
TestSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool {
|
|
return reflect.DeepEqual(a, b)
|
|
})
|
|
}
|
|
|
|
func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
|
|
TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc)
|
|
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
|
TestDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc)
|
|
}
|
|
}
|
|
|
|
func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
|
|
conn := MustConnectPgx(t)
|
|
defer MustClose(t, conn)
|
|
|
|
formats := []struct {
|
|
name string
|
|
formatCode int16
|
|
}{
|
|
{name: "TextFormat", formatCode: pgx.TextFormatCode},
|
|
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
for _, fc := range formats {
|
|
psName := fmt.Sprintf("test%d", i)
|
|
ps, err := conn.Prepare(psName, tt.SQL)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
ps.FieldDescriptions[0].FormatCode = fc.formatCode
|
|
if ForceEncoder(tt.Value, fc.formatCode) == nil {
|
|
t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name)
|
|
continue
|
|
}
|
|
// Derefence value if it is a pointer
|
|
derefV := tt.Value
|
|
refVal := reflect.ValueOf(tt.Value)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
err = conn.QueryRow(psName).Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("%v %d: %v", fc.name, i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
|
|
conn := MustConnectDatabaseSQL(t, driverName)
|
|
defer MustClose(t, conn)
|
|
|
|
for i, tt := range tests {
|
|
ps, err := conn.Prepare(tt.SQL)
|
|
if err != nil {
|
|
t.Errorf("%d. %v", i, err)
|
|
continue
|
|
}
|
|
|
|
// Derefence value if it is a pointer
|
|
derefV := tt.Value
|
|
refVal := reflect.ValueOf(tt.Value)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
err = ps.QueryRow().Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("%v %d: %v", driverName, i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|