mirror of https://github.com/jackc/pgx.git
130 lines
2.7 KiB
Go
130 lines
2.7 KiB
Go
package pgtype_test
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"github.com/jackc/pgx"
|
|
"github.com/jackc/pgx/pgtype"
|
|
)
|
|
|
|
// Test for renamed types
|
|
type _string string
|
|
type _bool bool
|
|
type _int8 int8
|
|
type _int16 int16
|
|
type _int16Slice []int16
|
|
type _int32Slice []int32
|
|
type _int64Slice []int64
|
|
type _float32Slice []float32
|
|
type _float64Slice []float64
|
|
|
|
func mustConnectPgx(t testing.TB) *pgx.Conn {
|
|
config, err := pgx.ParseURI(os.Getenv("DATABASE_URL"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
conn, err := pgx.Connect(config)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
return conn
|
|
}
|
|
|
|
func mustClose(t testing.TB, conn interface {
|
|
Close() error
|
|
}) {
|
|
err := conn.Close()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func mustParseCIDR(t testing.TB, s string) *net.IPNet {
|
|
_, ipnet, err := net.ParseCIDR(s)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
return ipnet
|
|
}
|
|
|
|
type forceTextEncoder struct {
|
|
e pgtype.TextEncoder
|
|
}
|
|
|
|
func (f forceTextEncoder) EncodeText(w io.Writer) error {
|
|
return f.e.EncodeText(w)
|
|
}
|
|
|
|
type forceBinaryEncoder struct {
|
|
e pgtype.BinaryEncoder
|
|
}
|
|
|
|
func (f forceBinaryEncoder) EncodeBinary(w io.Writer) error {
|
|
return f.e.EncodeBinary(w)
|
|
}
|
|
|
|
func forceEncoder(e interface{}, formatCode int16) interface{} {
|
|
switch formatCode {
|
|
case pgx.TextFormatCode:
|
|
return forceTextEncoder{e: e.(pgtype.TextEncoder)}
|
|
case pgx.BinaryFormatCode:
|
|
return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)}
|
|
default:
|
|
panic("bad encoder")
|
|
}
|
|
}
|
|
|
|
func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) {
|
|
testSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool {
|
|
return reflect.DeepEqual(a, b)
|
|
})
|
|
}
|
|
|
|
func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
|
conn := mustConnectPgx(t)
|
|
defer mustClose(t, conn)
|
|
|
|
ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
formats := []struct {
|
|
name string
|
|
formatCode int16
|
|
}{
|
|
{name: "TextFormat", formatCode: pgx.TextFormatCode},
|
|
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
|
|
}
|
|
|
|
for _, fc := range formats {
|
|
ps.FieldDescriptions[0].FormatCode = fc.formatCode
|
|
for i, v := range values {
|
|
// Derefence value if it is a pointer
|
|
derefV := v
|
|
refVal := reflect.ValueOf(v)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
err := conn.QueryRow("test", forceEncoder(v, fc.formatCode)).Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("%v %d: %v", fc.name, i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|
|
}
|