mirror of https://github.com/jackc/pgx.git
287 lines
7.8 KiB
Go
287 lines
7.8 KiB
Go
package testutil
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"os"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"github.com/jackc/pgx"
|
|
"github.com/jackc/pgx/pgtype"
|
|
_ "github.com/jackc/pgx/stdlib"
|
|
_ "github.com/lib/pq"
|
|
)
|
|
|
|
func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB {
|
|
var sqlDriverName string
|
|
switch driverName {
|
|
case "github.com/lib/pq":
|
|
sqlDriverName = "postgres"
|
|
case "github.com/jackc/pgx/stdlib":
|
|
sqlDriverName = "pgx"
|
|
default:
|
|
t.Fatalf("Unknown driver %v", driverName)
|
|
}
|
|
|
|
db, err := sql.Open(sqlDriverName, os.Getenv("PGX_TEST_DATABASE"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
return db
|
|
}
|
|
|
|
func MustConnectPgx(t testing.TB) *pgx.Conn {
|
|
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
return conn
|
|
}
|
|
|
|
func MustClose(t testing.TB, conn interface {
|
|
Close() error
|
|
}) {
|
|
err := conn.Close()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func MustCloseContext(t testing.TB, conn interface {
|
|
Close(context.Context) error
|
|
}) {
|
|
err := conn.Close(context.Background())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
type forceTextEncoder struct {
|
|
e pgtype.TextEncoder
|
|
}
|
|
|
|
func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
|
|
return f.e.EncodeText(ci, buf)
|
|
}
|
|
|
|
type forceBinaryEncoder struct {
|
|
e pgtype.BinaryEncoder
|
|
}
|
|
|
|
func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
|
|
return f.e.EncodeBinary(ci, buf)
|
|
}
|
|
|
|
func ForceEncoder(e interface{}, formatCode int16) interface{} {
|
|
switch formatCode {
|
|
case pgx.TextFormatCode:
|
|
if e, ok := e.(pgtype.TextEncoder); ok {
|
|
return forceTextEncoder{e: e}
|
|
}
|
|
case pgx.BinaryFormatCode:
|
|
if e, ok := e.(pgtype.BinaryEncoder); ok {
|
|
return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) {
|
|
TestSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool {
|
|
return reflect.DeepEqual(a, b)
|
|
})
|
|
}
|
|
|
|
func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
|
TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
|
|
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
|
TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
|
|
}
|
|
}
|
|
|
|
func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
|
conn := MustConnectPgx(t)
|
|
defer MustCloseContext(t, conn)
|
|
|
|
_, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
formats := []struct {
|
|
name string
|
|
formatCode int16
|
|
}{
|
|
{name: "TextFormat", formatCode: pgx.TextFormatCode},
|
|
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
|
|
}
|
|
|
|
for i, v := range values {
|
|
for _, paramFormat := range formats {
|
|
for _, resultFormat := range formats {
|
|
vEncoder := ForceEncoder(v, paramFormat.formatCode)
|
|
if vEncoder == nil {
|
|
t.Logf("Skipping Param %s Result %s: %#v does not implement %v for encoding", paramFormat.name, resultFormat.name, v, paramFormat.name)
|
|
continue
|
|
}
|
|
switch resultFormat.formatCode {
|
|
case pgx.TextFormatCode:
|
|
if _, ok := v.(pgtype.TextEncoder); !ok {
|
|
t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name)
|
|
continue
|
|
}
|
|
case pgx.BinaryFormatCode:
|
|
if _, ok := v.(pgtype.BinaryEncoder); !ok {
|
|
t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name)
|
|
continue
|
|
}
|
|
}
|
|
|
|
// Derefence value if it is a pointer
|
|
derefV := v
|
|
refVal := reflect.ValueOf(v)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
|
|
err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{resultFormat.formatCode}, vEncoder).Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("Param %s Result %s %d: %v", paramFormat.name, resultFormat.name, i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("Param %s Result %s %d: expected %v, got %v", paramFormat.name, resultFormat.name, i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
|
conn := MustConnectDatabaseSQL(t, driverName)
|
|
defer MustClose(t, conn)
|
|
|
|
ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
for i, v := range values {
|
|
// Derefence value if it is a pointer
|
|
derefV := v
|
|
refVal := reflect.ValueOf(v)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
err := ps.QueryRow(v).Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("%v %d: %v", driverName, i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|
|
|
|
type NormalizeTest struct {
|
|
SQL string
|
|
Value interface{}
|
|
}
|
|
|
|
func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) {
|
|
TestSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool {
|
|
return reflect.DeepEqual(a, b)
|
|
})
|
|
}
|
|
|
|
func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
|
|
TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc)
|
|
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
|
TestDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc)
|
|
}
|
|
}
|
|
|
|
func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
|
|
conn := MustConnectPgx(t)
|
|
defer MustCloseContext(t, conn)
|
|
|
|
formats := []struct {
|
|
name string
|
|
formatCode int16
|
|
}{
|
|
{name: "TextFormat", formatCode: pgx.TextFormatCode},
|
|
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
for _, fc := range formats {
|
|
psName := fmt.Sprintf("test%d", i)
|
|
ps, err := conn.Prepare(psName, tt.SQL)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
ps.FieldDescriptions[0].FormatCode = fc.formatCode
|
|
if ForceEncoder(tt.Value, fc.formatCode) == nil {
|
|
t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name)
|
|
continue
|
|
}
|
|
// Derefence value if it is a pointer
|
|
derefV := tt.Value
|
|
refVal := reflect.ValueOf(tt.Value)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
err = conn.QueryRow(context.Background(), psName).Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("%v %d: %v", fc.name, i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
|
|
conn := MustConnectDatabaseSQL(t, driverName)
|
|
defer MustClose(t, conn)
|
|
|
|
for i, tt := range tests {
|
|
ps, err := conn.Prepare(tt.SQL)
|
|
if err != nil {
|
|
t.Errorf("%d. %v", i, err)
|
|
continue
|
|
}
|
|
|
|
// Derefence value if it is a pointer
|
|
derefV := tt.Value
|
|
refVal := reflect.ValueOf(tt.Value)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
err = ps.QueryRow().Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("%v %d: %v", driverName, i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|