pgx/pgtype/pgtype_test.go

642 lines
17 KiB
Go

package pgtype_test
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"net"
"os"
"reflect"
"regexp"
"strconv"
"testing"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxtest"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// defaultConnTestRunner is the package-level connection test runner. It is populated by init.
var defaultConnTestRunner pgxtest.ConnTestRunner

// init starts from pgxtest.DefaultConnTestRunner and replaces its CreateConfig hook with one that parses the
// connection config from the PGX_TEST_DATABASE environment variable, failing the test if parsing returns an error.
func init() {
	runner := pgxtest.DefaultConnTestRunner()
	runner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		connString := os.Getenv("PGX_TEST_DATABASE")
		cfg, parseErr := pgx.ParseConfig(connString)
		require.NoError(t, parseErr)
		return cfg
	}
	defaultConnTestRunner = runner
}
// Renamed types: defined types whose underlying types are Go builtins (or slices of builtins).
type (
	_string string
	_bool   bool
	_uint8  uint8
	_int8   int8
	_int16  int16

	_int16Slice   []int16
	_int32Slice   []int32
	_int64Slice   []int64
	_float32Slice []float32
	_float64Slice []float64
	_byteSlice    []byte
)
// unregisteredOID stands in for a real type that has no registration. 0 cannot be used for this purpose because 0
// means the type is not known (e.g. when using the simple protocol).
const unregisteredOID uint32 = 1
// mustParseInet parses s, which may be in CIDR notation or a bare IP address, into a *net.IPNet.
//
// For CIDR input the returned IPNet keeps the mask produced by net.ParseCIDR and carries the address parsed from s.
// For a bare address the mask covers every bit (/32 for IPv4, /128 otherwise). IPv4 addresses are always stored in
// their 4-byte form. If s cannot be parsed either way the test is failed via t.Fatal.
func mustParseInet(t testing.TB, s string) *net.IPNet {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	if ip, ipnet, err := net.ParseCIDR(s); err == nil {
		// Replace ipnet.IP with the address parsed from s, using the 4-byte form for IPv4.
		ipnet.IP = ip
		if ipv4 := ip.To4(); ipv4 != nil {
			ipnet.IP = ipv4
		}
		return ipnet
	}

	// May be bare IP address.
	ip := net.ParseIP(s)
	if ip == nil {
		t.Fatal(errors.New("unable to parse inet address: " + s))
	}

	if ipv4 := ip.To4(); ipv4 != nil {
		return &net.IPNet{IP: ipv4, Mask: net.CIDRMask(32, 32)}
	}
	return &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)}
}
// mustParseMacaddr parses s with net.ParseMAC and returns the resulting hardware address. The test is failed via
// t.Fatal if parsing returns an error.
func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	addr, err := net.ParseMAC(s)
	if err != nil {
		t.Fatal(err)
	}
	return addr
}
// skipCockroachDB connects to the database named by the PGX_TEST_DATABASE environment variable and skips the test
// with msg when the server reports a non-empty "crdb_version" parameter status. A connection failure fails the test.
func skipCockroachDB(t testing.TB, msg string) {
	// Attribute the skip or failure to the calling test's line rather than to this helper.
	t.Helper()

	ctx := context.Background()
	conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close(ctx)

	if conn.PgConn().ParameterStatus("crdb_version") != "" {
		t.Skip(msg)
	}
}
// serverMajorVersionRE matches the run of decimal digits at the very start of a string. It is compiled once instead
// of on every skipPostgreSQLVersionLessThan call.
var serverMajorVersionRE = regexp.MustCompile(`^[0-9]+`)

// skipPostgreSQLVersionLessThan connects to the database named by the PGX_TEST_DATABASE environment variable and
// skips the test when the leading integer of the "server_version" parameter status is less than minVersion.
//
// If server_version does not start with a digit the server is treated as not being PostgreSQL and nothing happens.
// A connection failure or an unparsable version number fails the test.
func skipPostgreSQLVersionLessThan(t testing.TB, minVersion int64) {
	// Attribute the skip or failure to the calling test's line rather than to this helper.
	t.Helper()

	ctx := context.Background()
	conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close(ctx)

	majorVersionStr := serverMajorVersionRE.FindString(conn.PgConn().ParameterStatus("server_version"))
	// if not PostgreSQL do nothing
	if majorVersionStr == "" {
		return
	}

	serverVersion, err := strconv.ParseInt(majorVersionStr, 10, 64)
	require.NoError(t, err)

	if serverVersion < minVersion {
		t.Skipf("Test requires PostgreSQL v%d+", minVersion)
	}
}
// sqlScannerFunc is a function type with a Scan method, so that any func(src any) error can be passed where a
// sql.Scanner is expected.
type sqlScannerFunc func(src any) error

// Scan calls the underlying function with src and returns whatever it returns.
func (scan sqlScannerFunc) Scan(src any) error {
	return scan(src)
}
// driverValuerFunc is a function type with a Value method, so that any func() (driver.Value, error) can be passed
// where a driver.Valuer is expected.
type driverValuerFunc func() (driver.Value, error)

// Value calls the underlying function and returns its results unchanged.
func (produce driverValuerFunc) Value() (driver.Value, error) {
	return produce()
}
// TestMapScanNilIsNoOp checks that Map.Scan returns no error when the destination is nil.
func TestMapScanNilIsNoOp(t *testing.T) {
	typeMap := pgtype.NewMap()

	assert.NoError(t, typeMap.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), nil))
}
// TestMapScanTextFormatInterfacePtr checks that scanning a text-format text value into a *any stores the Go string.
func TestMapScanTextFormatInterfacePtr(t *testing.T) {
	var dst any

	scanErr := pgtype.NewMap().Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), &dst)
	require.NoError(t, scanErr)
	assert.Equal(t, "foo", dst)
}
// TestMapScanTextFormatNonByteaIntoByteSlice checks that a text-format jsonb value scanned into a []byte yields the
// same bytes that were supplied.
func TestMapScanTextFormatNonByteaIntoByteSlice(t *testing.T) {
	var dst []byte

	scanErr := pgtype.NewMap().Scan(pgtype.JSONBOID, pgx.TextFormatCode, []byte("{}"), &dst)
	require.NoError(t, scanErr)
	assert.Equal(t, []byte("{}"), dst)
}
// TestMapScanBinaryFormatInterfacePtr checks that scanning a binary-format text value into a *any stores the Go
// string.
func TestMapScanBinaryFormatInterfacePtr(t *testing.T) {
	var dst any

	scanErr := pgtype.NewMap().Scan(pgtype.TextOID, pgx.BinaryFormatCode, []byte("foo"), &dst)
	require.NoError(t, scanErr)
	assert.Equal(t, "foo", dst)
}
// TestMapScanUnknownOIDToStringsAndBytes scans a text-format value tagged with OID 999999 into string, a renamed
// string, []byte and a renamed []byte, and checks that each destination receives "foo".
func TestMapScanUnknownOIDToStringsAndBytes(t *testing.T) {
	const unknownOID = uint32(999999)
	payload := []byte("foo")
	typeMap := pgtype.NewMap()

	// string
	var plainString string
	assert.NoError(t, typeMap.Scan(unknownOID, pgx.TextFormatCode, payload, &plainString))
	assert.Equal(t, "foo", plainString)

	// renamed string
	var renamedString _string
	assert.NoError(t, typeMap.Scan(unknownOID, pgx.TextFormatCode, payload, &renamedString))
	assert.Equal(t, "foo", string(renamedString))

	// []byte
	var plainBytes []byte
	assert.NoError(t, typeMap.Scan(unknownOID, pgx.TextFormatCode, payload, &plainBytes))
	assert.Equal(t, []byte("foo"), plainBytes)

	// renamed []byte
	var renamedBytes _byteSlice
	assert.NoError(t, typeMap.Scan(unknownOID, pgx.TextFormatCode, payload, &renamedBytes))
	assert.Equal(t, []byte("foo"), []byte(renamedBytes))
}
// TestMapScanPointerToNilStructDoesNotCrash scans a text-format value with OID 0 into a pointer to a nil struct
// pointer. The scan is expected to return an error; the point of the test is that it returns instead of panicking.
func TestMapScanPointerToNilStructDoesNotCrash(t *testing.T) {
	m := pgtype.NewMap()

	type myStruct struct{}
	var p *myStruct
	err := m.Scan(0, pgx.TextFormatCode, []byte("(foo,bar)"), &p)
	// require.Error states the intent (an error is expected) and gives a clearer failure message than NotNil.
	require.Error(t, err)
}
func TestMapScanUnknownOIDTextFormat(t *testing.T) {
m := pgtype.NewMap()
var n int32
err := m.Scan(0, pgx.TextFormatCode, []byte("123"), &n)
assert.NoError(t, err)
assert.EqualValues(t, 123, n)
}
func TestMapScanUnknownOIDIntoSQLScanner(t *testing.T) {
m := pgtype.NewMap()
var s sql.NullString
err := m.Scan(0, pgx.TextFormatCode, []byte(nil), &s)
assert.NoError(t, err)
assert.Equal(t, "", s.String)
assert.False(t, s.Valid)
}
// scannerString is a renamed string whose Scan method ignores its input and always stores "scanned".
type scannerString string

// Scan discards the source value, sets the receiver to "scanned" and returns nil.
func (dst *scannerString) Scan(_ any) error {
	*dst = "scanned"
	return nil
}
// https://github.com/jackc/pgtype/issues/197
func TestMapScanUnregisteredOIDIntoRenamedStringSQLScanner(t *testing.T) {
m := pgtype.NewMap()
var s scannerString
err := m.Scan(unregisteredOID, pgx.TextFormatCode, []byte(nil), &s)
assert.NoError(t, err)
assert.Equal(t, "scanned", string(s))
}
// pgCustomInt is a renamed int64 with a Scan method that accepts int64 sources.
type pgCustomInt int64

// Scan stores src in the receiver when src is an int64. Any other source type, including nil, yields an error and
// leaves the receiver unchanged instead of panicking on a failed type assertion.
func (ci *pgCustomInt) Scan(src any) error {
	n, ok := src.(int64)
	if !ok {
		return fmt.Errorf("cannot scan %T into *pgCustomInt", src)
	}
	*ci = pgCustomInt(n)
	return nil
}
// TestScanPlanBinaryInt32ScanScanner exercises scan plans for a binary int2 value whose destination is a pgCustomInt
// (a type with a Scan method), reached directly, through an allocated pointer, and through a nil pointer, with both
// a non-nil and a nil source.
//
// Every dereference of ptr is preceded by require.NotNil so that a regression fails the test instead of panicking.
func TestScanPlanBinaryInt32ScanScanner(t *testing.T) {
	m := pgtype.NewMap()
	src := []byte{0, 42}

	// Destination is the value itself.
	var v pgCustomInt
	plan := m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &v)
	err := plan.Scan(src, &v)
	require.NoError(t, err)
	require.EqualValues(t, 42, v)

	// Destination is a pointer to an already allocated pointer.
	ptr := new(pgCustomInt)
	plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
	err = plan.Scan(src, &ptr)
	require.NoError(t, err)
	require.NotNil(t, ptr)
	require.EqualValues(t, 42, *ptr)

	// Same plan, nil source: the allocated pointer is expected to be set to nil.
	ptr = new(pgCustomInt)
	err = plan.Scan(nil, &ptr)
	require.NoError(t, err)
	assert.Nil(t, ptr)

	// Destination is a pointer to a nil pointer, non-nil source.
	ptr = nil
	plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
	err = plan.Scan(src, &ptr)
	require.NoError(t, err)
	require.NotNil(t, ptr)
	require.EqualValues(t, 42, *ptr)

	// Destination is a pointer to a nil pointer, nil source.
	ptr = nil
	plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
	err = plan.Scan(nil, &ptr)
	require.NoError(t, err)
	assert.Nil(t, ptr)
}
// TestScanPlanInterface checks that planning and scanning into a nil interface value (not a pointer to one) returns
// an error.
//
// Test for https://github.com/jackc/pgtype/issues/164
func TestScanPlanInterface(t *testing.T) {
	var dst any // deliberately left nil

	scanPlan := pgtype.NewMap().PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, dst)
	assert.Error(t, scanPlan.Scan([]byte{0, 42}, dst))
}
// TestPointerPointerStructScan registers a one-field composite type and checks that its text representation "(1)"
// can be scanned into a pointer to a nil struct pointer.
func TestPointerPointerStructScan(t *testing.T) {
	m := pgtype.NewMap()
	type composite struct {
		ID int
	}

	// Fail fast if int4 is missing rather than building a composite codec around an unusable field type.
	int4Type, ok := m.TypeForOID(pgtype.Int4OID)
	require.True(t, ok, "int4 type is not registered")

	pgt := &pgtype.Type{
		Codec: &pgtype.CompositeCodec{
			Fields: []pgtype.CompositeCodecField{
				{
					Name: "id",
					Type: int4Type,
				},
			},
		},
		Name: "composite",
		OID:  215333,
	}
	m.RegisterType(pgt)

	var c *composite
	plan := m.PlanScan(pgt.OID, pgtype.TextFormatCode, &c)
	err := plan.Scan([]byte("(1)"), &c)
	require.NoError(t, err)
	// Guard the dereference below so a nil result fails the test instead of panicking.
	require.NotNil(t, c)
	require.Equal(t, 1, c.ID)
}
// TestMapScanPtrToPtrToSlice checks that a text-format text array can be scanned into a pointer to a nil *[]string.
//
// https://github.com/jackc/pgx/issues/1263
func TestMapScanPtrToPtrToSlice(t *testing.T) {
	m := pgtype.NewMap()
	src := []byte("{foo,bar}")
	var v *[]string
	plan := m.PlanScan(pgtype.TextArrayOID, pgtype.TextFormatCode, &v)
	err := plan.Scan(src, &v)
	require.NoError(t, err)
	// Guard the dereference below so a nil result fails the test instead of panicking.
	require.NotNil(t, v)
	require.Equal(t, []string{"foo", "bar"}, *v)
}
// TestMapScanPtrToPtrToSliceOfStruct checks that a binary-format record array can be scanned into a pointer to a nil
// *[]Team, producing two Team values.
func TestMapScanPtrToPtrToSliceOfStruct(t *testing.T) {
	type Team struct {
		TeamID int
		Name   string
	}

	// Have to use binary format because text format doesn't include type information.
	m := pgtype.NewMap()
	src := []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x8, 0xc9, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x1e, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x17, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x19, 0x0, 0x0, 0x0, 0x6, 0x74, 0x65, 0x61, 0x6d, 0x20, 0x31, 0x0, 0x0, 0x0, 0x1e, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x17, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x19, 0x0, 0x0, 0x0, 0x6, 0x74, 0x65, 0x61, 0x6d, 0x20, 0x32}
	var v *[]Team
	plan := m.PlanScan(pgtype.RecordArrayOID, pgtype.BinaryFormatCode, &v)
	err := plan.Scan(src, &v)
	require.NoError(t, err)
	// Guard the dereference below so a nil result fails the test instead of panicking.
	require.NotNil(t, v)
	require.Equal(t, []Team{{1, "team 1"}, {2, "team 2"}}, *v)
}
// databaseValuerString is a renamed string whose Value method reports the string's length (as decimal text) rather
// than its contents.
type databaseValuerString string

// Value returns the byte length of s formatted as a decimal string, and a nil error.
func (s databaseValuerString) Value() (driver.Value, error) {
	length := len(s)
	return fmt.Sprintf("%d", length), nil
}
// https://github.com/jackc/pgx/issues/1319
func TestMapEncodeTextFormatDatabaseValuerThatIsRenamedSimpleType(t *testing.T) {
m := pgtype.NewMap()
src := databaseValuerString("foo")
buf, err := m.Encode(pgtype.TextOID, pgtype.TextFormatCode, src, nil)
require.NoError(t, err)
require.Equal(t, "3", string(buf))
}
// databaseValuerFmtStringer is a renamed string that has both a Value method and a String method, with deliberately
// different results: Value yields nil while String yields "foobar".
type databaseValuerFmtStringer string

// Value always returns a nil value and a nil error.
func (databaseValuerFmtStringer) Value() (driver.Value, error) {
	return nil, nil
}

// String always returns "foobar", regardless of the receiver's contents.
func (databaseValuerFmtStringer) String() string {
	return "foobar"
}
// https://github.com/jackc/pgx/issues/1311
func TestMapEncodeTextFormatDatabaseValuerThatIsFmtStringer(t *testing.T) {
m := pgtype.NewMap()
src := databaseValuerFmtStringer("")
buf, err := m.Encode(pgtype.TextOID, pgtype.TextFormatCode, src, nil)
require.NoError(t, err)
require.Nil(t, buf)
}
// databaseValuerStringFormat wraps an int32 and exposes it through Value as its default text formatting.
type databaseValuerStringFormat struct {
	n int32
}

// Value returns fmt.Sprint of the wrapped number, and a nil error.
func (v databaseValuerStringFormat) Value() (driver.Value, error) {
	text := fmt.Sprint(v.n)
	return text, nil
}
func TestMapEncodeBinaryFormatDatabaseValuerThatReturnsString(t *testing.T) {
m := pgtype.NewMap()
src := databaseValuerStringFormat{n: 42}
buf, err := m.Encode(pgtype.Int4OID, pgtype.BinaryFormatCode, src, nil)
require.NoError(t, err)
require.Equal(t, []byte{0, 0, 0, 42}, buf)
}
// https://github.com/jackc/pgx/issues/1445
func TestMapEncodeDatabaseValuerThatReturnsStringIntoUnregisteredTypeTextFormat(t *testing.T) {
m := pgtype.NewMap()
buf, err := m.Encode(unregisteredOID, pgtype.TextFormatCode, driverValuerFunc(func() (driver.Value, error) { return "foo", nil }), nil)
require.NoError(t, err)
require.Equal(t, []byte("foo"), buf)
}
// https://github.com/jackc/pgx/issues/1445
func TestMapEncodeDatabaseValuerThatReturnsByteSliceIntoUnregisteredTypeTextFormat(t *testing.T) {
m := pgtype.NewMap()
buf, err := m.Encode(unregisteredOID, pgtype.TextFormatCode, driverValuerFunc(func() (driver.Value, error) { return []byte{0, 1, 2, 3}, nil }), nil)
require.NoError(t, err)
require.Equal(t, []byte(`\x00010203`), buf)
}
// TestMapEncodeStringIntoUnregisteredTypeTextFormat checks that the string "foo" is encoded as the bytes "foo" for
// an unregistered OID in text format.
func TestMapEncodeStringIntoUnregisteredTypeTextFormat(t *testing.T) {
	encoded, encodeErr := pgtype.NewMap().Encode(unregisteredOID, pgtype.TextFormatCode, "foo", nil)
	require.NoError(t, encodeErr)
	require.Equal(t, []byte("foo"), encoded)
}
// TestMapEncodeByteSliceIntoUnregisteredTypeTextFormat checks that the bytes 0,1,2,3 are encoded as the text
// `\x00010203` for an unregistered OID in text format.
func TestMapEncodeByteSliceIntoUnregisteredTypeTextFormat(t *testing.T) {
	raw := []byte{0, 1, 2, 3}

	encoded, encodeErr := pgtype.NewMap().Encode(unregisteredOID, pgtype.TextFormatCode, raw, nil)
	require.NoError(t, encodeErr)
	require.Equal(t, []byte(`\x00010203`), encoded)
}
// TestMapEncodeNamedTypeOfByteSliceIntoTextTextFormat checks that a json.RawMessage encoded as text in text format
// comes out byte-for-byte unchanged.
//
// https://github.com/jackc/pgx/issues/1763
func TestMapEncodeNamedTypeOfByteSliceIntoTextTextFormat(t *testing.T) {
	raw := json.RawMessage(`{"foo": "bar"}`)

	encoded, encodeErr := pgtype.NewMap().Encode(pgtype.TextOID, pgtype.TextFormatCode, raw, nil)
	require.NoError(t, encodeErr)
	require.Equal(t, []byte(`{"foo": "bar"}`), encoded)
}
// https://github.com/jackc/pgx/issues/1326
func TestMapScanPointerToRenamedType(t *testing.T) {
srcBuf := []byte("foo")
m := pgtype.NewMap()
var rs *_string
err := m.Scan(pgtype.TextOID, pgx.TextFormatCode, srcBuf, &rs)
assert.NoError(t, err)
require.NotNil(t, rs)
assert.Equal(t, "foo", string(*rs))
}
// TestMapScanNullToWrongType checks that a nil source declared as text can still be scanned into integer
// destinations: a *int32 stays nil and a pgtype.Int4 ends up not Valid.
//
// https://github.com/jackc/pgx/issues/1326
func TestMapScanNullToWrongType(t *testing.T) {
	typeMap := pgtype.NewMap()

	// Go pointer destination.
	var intPtr *int32
	assert.NoError(t, typeMap.Scan(pgtype.TextOID, pgx.TextFormatCode, nil, &intPtr))
	assert.Nil(t, intPtr)

	// pgtype destination.
	var pgInt pgtype.Int4
	assert.NoError(t, typeMap.Scan(pgtype.TextOID, pgx.TextFormatCode, nil, &pgInt))
	assert.False(t, pgInt.Valid)
}
func TestScanToSliceOfRenamedUint8(t *testing.T) {
m := pgtype.NewMap()
var ruint8 []_uint8
err := m.Scan(pgtype.Int2ArrayOID, pgx.TextFormatCode, []byte("{2,4}"), &ruint8)
assert.NoError(t, err)
assert.Equal(t, []_uint8{2, 4}, ruint8)
}
// TestMapScanTextToBool checks each accepted text spelling of a boolean against the bool it must scan to.
func TestMapScanTextToBool(t *testing.T) {
	// Each input doubles as the subtest name.
	cases := []struct {
		input string
		want  bool
	}{
		{"t", true},
		{"f", false},
		{"y", true},
		{"n", false},
		{"1", true},
		{"0", false},
		{"true", true},
		{"false", false},
		{"yes", true},
		{"no", false},
		{"on", true},
		{"off", false},
	}

	for _, tc := range cases {
		t.Run(tc.input, func(t *testing.T) {
			var got bool

			scanErr := pgtype.NewMap().Scan(pgtype.BoolOID, pgx.TextFormatCode, []byte(tc.input), &got)
			require.NoError(t, scanErr)
			assert.Equal(t, tc.want, got)
		})
	}
}
// TestMapScanTextToBoolError checks the error text produced when a nil, empty or unrecognised source is scanned into
// a bool.
func TestMapScanTextToBoolError(t *testing.T) {
	cases := []struct {
		label   string
		input   []byte
		wantMsg string
	}{
		{label: "nil", input: nil, wantMsg: "cannot scan NULL into *bool"},
		{label: "empty", input: []byte{}, wantMsg: "cannot scan empty string into *bool"},
		{label: "foo", input: []byte("foo"), wantMsg: "unknown boolean string representation \"foo\""},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			var dst bool

			scanErr := pgtype.NewMap().Scan(pgtype.BoolOID, pgx.TextFormatCode, tc.input, &dst)
			require.ErrorContains(t, scanErr, tc.wantMsg)
		})
	}
}
// databaseValuerUUID is a 16-byte array whose Value method renders the bytes as lower-case hexadecimal text.
type databaseValuerUUID [16]byte

// Value returns the receiver formatted with the %x verb, and a nil error.
func (v databaseValuerUUID) Value() (driver.Value, error) {
	hexText := fmt.Sprintf("%x", v)
	return hexText, nil
}
// https://github.com/jackc/pgx/issues/1502
func TestMapEncodePlanCacheUUIDTypeConfusion(t *testing.T) {
expected := []byte{
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0xb, 0x86, 0, 0, 0, 2, 0, 0, 0, 1,
0, 0, 0, 16,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
0, 0, 0, 16,
15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
m := pgtype.NewMap()
buf, err := m.Encode(pgtype.UUIDArrayOID, pgtype.BinaryFormatCode,
[]databaseValuerUUID{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}},
nil)
require.NoError(t, err)
require.Equal(t, expected, buf)
// This actually *should* fail. In the actual query path this error is detected and the encoding falls back to the
// text format. In the bug this test is guarding against regression this would panic.
_, err = m.Encode(pgtype.UUIDArrayOID, pgtype.BinaryFormatCode,
[]string{"00010203-0405-0607-0809-0a0b0c0d0e0f", "0f0e0d0c-0b0a-0908-0706-0504-03020100"},
nil)
require.Error(t, err)
}
// TestMapEncodeRawJSONIntoUnknownOID checks that a json.RawMessage encoded in text format with OID 0 comes out
// byte-for-byte unchanged.
//
// https://github.com/jackc/pgx/issues/1763
func TestMapEncodeRawJSONIntoUnknownOID(t *testing.T) {
	raw := json.RawMessage(`{"foo": "bar"}`)

	encoded, encodeErr := pgtype.NewMap().Encode(0, pgtype.TextFormatCode, raw, nil)
	require.NoError(t, encodeErr)
	require.Equal(t, []byte(`{"foo": "bar"}`), encoded)
}
// BenchmarkMapScanInt4IntoBinaryDecoder measures Map.Scan of a binary int4 into a pgtype.Int4, resetting the
// destination and verifying the result on every iteration.
func BenchmarkMapScanInt4IntoBinaryDecoder(b *testing.B) {
	typeMap := pgtype.NewMap()
	wire := []byte{0, 0, 0, 42}

	var dst pgtype.Int4
	for n := 0; n < b.N; n++ {
		dst = pgtype.Int4{}
		if err := typeMap.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, wire, &dst); err != nil {
			b.Fatal(err)
		}
		if dst != (pgtype.Int4{Int32: 42, Valid: true}) {
			b.Fatal("scan failed due to bad value")
		}
	}
}
// BenchmarkMapScanInt4IntoGoInt32 measures Map.Scan of a binary int4 into a Go int32, resetting the destination and
// verifying the result on every iteration.
func BenchmarkMapScanInt4IntoGoInt32(b *testing.B) {
	typeMap := pgtype.NewMap()
	wire := []byte{0, 0, 0, 42}

	var dst int32
	for n := 0; n < b.N; n++ {
		dst = 0
		if err := typeMap.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, wire, &dst); err != nil {
			b.Fatal(err)
		}
		if dst != 42 {
			b.Fatal("scan failed due to bad value")
		}
	}
}
// BenchmarkScanPlanScanInt4IntoBinaryDecoder measures scanning a binary int4 into a pgtype.Int4 through a scan plan
// that is built once, outside the loop.
func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) {
	wire := []byte{0, 0, 0, 42}

	var dst pgtype.Int4
	scanPlan := pgtype.NewMap().PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &dst)

	for n := 0; n < b.N; n++ {
		dst = pgtype.Int4{}
		if err := scanPlan.Scan(wire, &dst); err != nil {
			b.Fatal(err)
		}
		if dst != (pgtype.Int4{Int32: 42, Valid: true}) {
			b.Fatal("scan failed due to bad value")
		}
	}
}
// BenchmarkScanPlanScanInt4IntoGoInt32 measures scanning a binary int4 into a Go int32 through a scan plan that is
// built once, outside the loop.
func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) {
	wire := []byte{0, 0, 0, 42}

	var dst int32
	scanPlan := pgtype.NewMap().PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &dst)

	for n := 0; n < b.N; n++ {
		dst = 0
		if err := scanPlan.Scan(wire, &dst); err != nil {
			b.Fatal(err)
		}
		if dst != 42 {
			b.Fatal("scan failed due to bad value")
		}
	}
}
// isExpectedEq returns a predicate that reports whether its argument compares equal (with ==) to a.
func isExpectedEq(a any) func(any) bool {
	matches := func(actual any) bool {
		return a == actual
	}
	return matches
}
// isPtrExpectedEq returns a predicate that dereferences its argument and reports whether the pointed-to value
// compares equal (with ==) to a.
//
// The predicate returns false, rather than panicking inside reflect, when its argument is nil, is not a pointer, or
// is a nil pointer.
func isPtrExpectedEq(a any) func(any) bool {
	return func(v any) bool {
		val := reflect.ValueOf(v)
		// Elem panics on anything other than a non-nil pointer, so rule those cases out first.
		if val.Kind() != reflect.Pointer || val.IsNil() {
			return false
		}
		return a == val.Elem().Interface()
	}
}