pgx/pgtype/pgtype_test.go
Jack Christensen 38e09bda4c Fix *wrapSliceEncodePlan[T].Encode
It should pass a FlatArray[T] to the next step instead of an
anySliceArrayReflect. By using an anySliceArrayReflect, an encode of
[]github.com/google/uuid.UUID followed by []string into a PostgreSQL
uuid[] would crash. This was caused by an EncodePlan cache collision
where the second encoding used part of the cached plan of the first.

In proper usage a cache collision shouldn't be able to occur. If this
assertion proves incorrect it will be necessary to add an optional
interface to ScanPlan and EncodePlan that marks the plan as ineligible
for caching. But I have been unable to construct a failing case, and
given that ScanPlans have been cached for quite some time now without
incident I do not think it is possible. This issue only occurred due to
the bug in *wrapSliceEncodePlan[T].Encode.

https://github.com/jackc/pgx/issues/1502
2023-02-21 21:04:30 -06:00

547 lines
14 KiB
Go

package pgtype_test
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"net"
"os"
"regexp"
"strconv"
"testing"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxtest"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// defaultConnTestRunner is a package-level pgxtest.ConnTestRunner that is configured in init.
var defaultConnTestRunner pgxtest.ConnTestRunner

// init starts from pgxtest.DefaultConnTestRunner and replaces its CreateConfig hook with one that parses the
// connection string found in the PGX_TEST_DATABASE environment variable.
func init() {
	runner := pgxtest.DefaultConnTestRunner()
	runner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		connString := os.Getenv("PGX_TEST_DATABASE")
		cfg, err := pgx.ParseConfig(connString)
		require.NoError(t, err)
		return cfg
	}
	defaultConnTestRunner = runner
}
// Defined types whose underlying types are Go built-ins. They are used to test
// how renamed types are handled.
type (
	_string       string
	_bool         bool
	_int8         int8
	_int16        int16
	_int16Slice   []int16
	_int32Slice   []int32
	_int64Slice   []int64
	_float32Slice []float32
	_float64Slice []float64
	_byteSlice    []byte
)
// unregisteredOID represents an actual type that is not registered. Cannot use 0 because that represents that the type
// is not known (e.g. when using the simple protocol).
const unregisteredOID = uint32(1)
// mustParseCIDR parses s with net.ParseCIDR and returns the resulting network. Any parse error fails the test
// through t.Fatal.
func mustParseCIDR(t testing.TB, s string) *net.IPNet {
	// Mark as a helper so a failure is attributed to the calling test line.
	t.Helper()

	_, ipnet, err := net.ParseCIDR(s)
	if err != nil {
		t.Fatal(err)
	}
	return ipnet
}
// mustParseInet parses s either as CIDR notation or as a bare IP address and returns it as a *net.IPNet.
//
// For CIDR input the returned IP is the host address from s (not the masked network address) together with the
// parsed mask. For a bare address the mask covers the full address: /32 for IPv4 and /128 otherwise. IPv4 addresses
// are stored in their 4-byte form. If s cannot be parsed either way the test is failed through t.Fatal.
func mustParseInet(t testing.TB, s string) *net.IPNet {
	// Mark as a helper so a failure is attributed to the calling test line.
	t.Helper()

	if ip, ipnet, err := net.ParseCIDR(s); err == nil {
		// Keep the host address rather than the network address net.ParseCIDR stores in ipnet.IP.
		ipnet.IP = ip
		if ipv4 := ip.To4(); ipv4 != nil {
			ipnet.IP = ipv4
		}
		return ipnet
	}

	// May be bare IP address.
	ip := net.ParseIP(s)
	if ip == nil {
		t.Fatal(errors.New("unable to parse inet address"))
	}

	if ipv4 := ip.To4(); ipv4 != nil {
		return &net.IPNet{IP: ipv4, Mask: net.CIDRMask(32, 32)}
	}
	return &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)}
}
// mustParseMacaddr parses s with net.ParseMAC and returns the hardware address. Any parse error fails the test
// through t.Fatal.
func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr {
	// Mark as a helper so a failure is attributed to the calling test line.
	t.Helper()

	addr, err := net.ParseMAC(s)
	if err != nil {
		t.Fatal(err)
	}
	return addr
}
// skipCockroachDB connects to the database named by the PGX_TEST_DATABASE environment variable and skips the test
// with msg when the server reports a non-empty "crdb_version" parameter status. A connection failure fails the test.
func skipCockroachDB(t testing.TB, msg string) {
	// Mark as a helper so the skip or failure is attributed to the calling test line.
	t.Helper()

	ctx := context.Background()
	conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close(ctx)

	if conn.PgConn().ParameterStatus("crdb_version") != "" {
		t.Skip(msg)
	}
}
// skipPostgreSQLVersionLessThan connects to the database named by the PGX_TEST_DATABASE environment variable and
// skips the test when the leading integer of the server's "server_version" parameter status is below minVersion.
// When "server_version" does not start with digits the function returns without skipping. A connection failure
// fails the test.
func skipPostgreSQLVersionLessThan(t testing.TB, minVersion int64) {
	// Mark as a helper so the skip or failure is attributed to the calling test line.
	t.Helper()

	ctx := context.Background()
	conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close(ctx)

	// Only the leading run of digits (the major version) is compared.
	serverVersionStr := conn.PgConn().ParameterStatus("server_version")
	serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
	// if not PostgreSQL do nothing
	if serverVersionStr == "" {
		return
	}

	serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64)
	require.NoError(t, err)

	if serverVersion < minVersion {
		t.Skipf("Test requires PostgreSQL v%d+", minVersion)
	}
}
// sqlScannerFunc adapts a plain function so it can be used as a sql.Scanner.
type sqlScannerFunc func(src any) error

// Scan invokes the wrapped function with src and returns its error.
func (fn sqlScannerFunc) Scan(src any) error {
	err := fn(src)
	return err
}
// driverValuerFunc adapts a plain function so it can be used as a driver.Valuer.
type driverValuerFunc func() (driver.Value, error)

// Value invokes the wrapped function and returns its results unchanged.
func (fn driverValuerFunc) Value() (driver.Value, error) {
	value, err := fn()
	return value, err
}
// TestMapScanNilIsNoOp asserts that Map.Scan reports no error when the destination is nil.
func TestMapScanNilIsNoOp(t *testing.T) {
	typeMap := pgtype.NewMap()
	payload := []byte("foo")

	scanErr := typeMap.Scan(pgtype.TextOID, pgx.TextFormatCode, payload, nil)
	assert.NoError(t, scanErr)
}
// TestMapScanTextFormatInterfacePtr asserts that scanning a text-format text value into a *any yields the Go
// string "foo".
func TestMapScanTextFormatInterfacePtr(t *testing.T) {
	var dst any
	typeMap := pgtype.NewMap()

	scanErr := typeMap.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), &dst)
	require.NoError(t, scanErr)
	assert.Equal(t, "foo", dst)
}
// TestMapScanTextFormatNonByteaIntoByteSlice asserts that a text-format jsonb value scanned into a []byte comes
// back as the same bytes.
func TestMapScanTextFormatNonByteaIntoByteSlice(t *testing.T) {
	var dst []byte
	typeMap := pgtype.NewMap()

	scanErr := typeMap.Scan(pgtype.JSONBOID, pgx.TextFormatCode, []byte("{}"), &dst)
	require.NoError(t, scanErr)
	assert.Equal(t, []byte("{}"), dst)
}
// TestMapScanBinaryFormatInterfacePtr asserts that scanning a binary-format text value into a *any yields the Go
// string "foo".
func TestMapScanBinaryFormatInterfacePtr(t *testing.T) {
	var dst any
	typeMap := pgtype.NewMap()

	scanErr := typeMap.Scan(pgtype.TextOID, pgx.BinaryFormatCode, []byte("foo"), &dst)
	require.NoError(t, scanErr)
	assert.Equal(t, "foo", dst)
}
// TestMapScanUnknownOIDToStringsAndBytes scans the text-format value "foo" for OID 999999 into a string, a renamed
// string, a []byte and a renamed []byte, asserting each scan succeeds and yields "foo".
func TestMapScanUnknownOIDToStringsAndBytes(t *testing.T) {
	const oid = uint32(999999)
	payload := []byte("foo")
	typeMap := pgtype.NewMap()

	var plainString string
	err := typeMap.Scan(oid, pgx.TextFormatCode, payload, &plainString)
	assert.NoError(t, err)
	assert.Equal(t, "foo", plainString)

	var renamedString _string
	err = typeMap.Scan(oid, pgx.TextFormatCode, payload, &renamedString)
	assert.NoError(t, err)
	assert.Equal(t, "foo", string(renamedString))

	var plainBytes []byte
	err = typeMap.Scan(oid, pgx.TextFormatCode, payload, &plainBytes)
	assert.NoError(t, err)
	assert.Equal(t, []byte("foo"), plainBytes)

	var renamedBytes _byteSlice
	err = typeMap.Scan(oid, pgx.TextFormatCode, payload, &renamedBytes)
	assert.NoError(t, err)
	assert.Equal(t, []byte("foo"), []byte(renamedBytes))
}
// TestMapScanPointerToNilStructDoesNotCrash scans a composite-looking text value with OID 0 into a pointer to a
// nil struct pointer and requires that an error is returned.
func TestMapScanPointerToNilStructDoesNotCrash(t *testing.T) {
	type myStruct struct{}

	var dst *myStruct
	typeMap := pgtype.NewMap()

	scanErr := typeMap.Scan(0, pgx.TextFormatCode, []byte("(foo,bar)"), &dst)
	require.NotNil(t, scanErr)
}
// TestMapScanUnknownOIDTextFormat asserts that the text value "123" with OID 0 scans into an int32 as 123.
func TestMapScanUnknownOIDTextFormat(t *testing.T) {
	var dst int32
	typeMap := pgtype.NewMap()

	scanErr := typeMap.Scan(0, pgx.TextFormatCode, []byte("123"), &dst)
	assert.NoError(t, scanErr)
	assert.EqualValues(t, 123, dst)
}
// TestMapScanUnknownOIDIntoSQLScanner scans a nil source with OID 0 into a sql.NullString and asserts the result
// is the empty, invalid value.
func TestMapScanUnknownOIDIntoSQLScanner(t *testing.T) {
	var dst sql.NullString
	typeMap := pgtype.NewMap()

	scanErr := typeMap.Scan(0, pgx.TextFormatCode, []byte(nil), &dst)
	assert.NoError(t, scanErr)
	assert.Equal(t, "", dst.String)
	assert.False(t, dst.Valid)
}
// scannerString is a renamed string with a Scan method that ignores its input.
type scannerString string

// Scan always stores the fixed value "scanned" in the receiver, regardless of v, and returns nil.
func (ss *scannerString) Scan(v any) error {
	const marker scannerString = "scanned"
	*ss = marker
	return nil
}
// https://github.com/jackc/pgtype/issues/197
func TestMapScanUnregisteredOIDIntoRenamedStringSQLScanner(t *testing.T) {
m := pgtype.NewMap()
var s scannerString
err := m.Scan(unregisteredOID, pgx.TextFormatCode, []byte(nil), &s)
assert.NoError(t, err)
assert.Equal(t, "scanned", string(s))
}
// pgCustomInt is a renamed int64 that implements Scan for int64 sources.
type pgCustomInt int64

// Scan stores src in the receiver when src is an int64. Any other source, including nil, yields an error and
// leaves the receiver unchanged instead of panicking on a failed type assertion.
func (ci *pgCustomInt) Scan(src any) error {
	n, ok := src.(int64)
	if !ok {
		return fmt.Errorf("cannot scan %T into pgCustomInt", src)
	}
	*ci = pgCustomInt(n)
	return nil
}
// TestScanPlanBinaryInt32ScanScanner scans the two-byte binary value {0, 42} with pgtype.Int2OID into a
// pgCustomInt (a type with its own Scan method) through plans built for a value pointer and for a pointer to a
// pointer. It checks that a non-nil source yields 42 and that a nil source leaves the inner pointer nil.
//
// Each dereference of ptr is guarded by require.NotNil so a regression fails the test rather than panicking.
func TestScanPlanBinaryInt32ScanScanner(t *testing.T) {
	m := pgtype.NewMap()
	src := []byte{0, 42}

	// Destination is a plain value.
	var v pgCustomInt
	plan := m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &v)
	err := plan.Scan(src, &v)
	require.NoError(t, err)
	require.EqualValues(t, 42, v)

	// Destination is a pointer that already points at a value.
	ptr := new(pgCustomInt)
	plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
	err = plan.Scan(src, &ptr)
	require.NoError(t, err)
	require.NotNil(t, ptr)
	require.EqualValues(t, 42, *ptr)

	// The plan from the previous step is reused with a nil source.
	ptr = new(pgCustomInt)
	err = plan.Scan(nil, &ptr)
	require.NoError(t, err)
	assert.Nil(t, ptr)

	// Destination is a nil pointer and the source is non-nil.
	ptr = nil
	plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
	err = plan.Scan(src, &ptr)
	require.NoError(t, err)
	require.NotNil(t, ptr)
	require.EqualValues(t, 42, *ptr)

	// Destination is a nil pointer and the source is nil.
	ptr = nil
	plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
	err = plan.Scan(nil, &ptr)
	require.NoError(t, err)
	assert.Nil(t, ptr)
}
// TestScanPlanInterface asserts that planning and scanning into a nil interface value (not a pointer) returns an
// error.
//
// Test for https://github.com/jackc/pgtype/issues/164
func TestScanPlanInterface(t *testing.T) {
	var dst interface{}
	typeMap := pgtype.NewMap()
	encoded := []byte{0, 42}

	scanPlan := typeMap.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, dst)
	scanErr := scanPlan.Scan(encoded, dst)
	assert.Error(t, scanErr)
}
// TestPointerPointerStructScan registers a single-field composite type (field "id" of the int4 type) and scans the
// text value "(1)" into a pointer to a nil *composite, requiring that the struct is allocated and ID is 1.
func TestPointerPointerStructScan(t *testing.T) {
	m := pgtype.NewMap()

	type composite struct {
		ID int
	}

	// Fail here rather than registering a composite with a nil field type.
	int4Type, ok := m.TypeForOID(pgtype.Int4OID)
	require.True(t, ok, "int4 type not found in map")

	pgt := &pgtype.Type{
		Codec: &pgtype.CompositeCodec{
			Fields: []pgtype.CompositeCodecField{
				{
					Name: "id",
					Type: int4Type,
				},
			},
		},
		Name: "composite",
		OID:  215333,
	}
	m.RegisterType(pgt)

	var c *composite
	plan := m.PlanScan(pgt.OID, pgtype.TextFormatCode, &c)
	err := plan.Scan([]byte("(1)"), &c)
	require.NoError(t, err)
	// Guard the dereference so a regression fails cleanly rather than panicking.
	require.NotNil(t, c)
	// Expected value comes first so failure messages label the values correctly.
	require.Equal(t, 1, c.ID)
}
// TestMapScanPtrToPtrToSlice scans the text array "{foo,bar}" into a pointer to a nil *[]string and requires that
// the slice is allocated and holds both elements.
//
// https://github.com/jackc/pgx/issues/1263
func TestMapScanPtrToPtrToSlice(t *testing.T) {
	m := pgtype.NewMap()
	src := []byte("{foo,bar}")

	var v *[]string
	plan := m.PlanScan(pgtype.TextArrayOID, pgtype.TextFormatCode, &v)
	err := plan.Scan(src, &v)
	require.NoError(t, err)
	// Guard the dereference so a regression fails cleanly rather than panicking.
	require.NotNil(t, v)
	require.Equal(t, []string{"foo", "bar"}, *v)
}
// databaseValuerString is a renamed string whose Value method returns something other than the string itself.
type databaseValuerString string

// Value returns the byte length of s formatted as a decimal string.
func (s databaseValuerString) Value() (driver.Value, error) {
	length := len(s)
	return fmt.Sprint(length), nil
}
// TestMapEncodeTextFormatDatabaseValuerThatIsRenamedSimpleType requires that encoding a renamed string with a
// Value method produces the result of Value ("3") rather than the underlying string.
//
// https://github.com/jackc/pgx/issues/1319
func TestMapEncodeTextFormatDatabaseValuerThatIsRenamedSimpleType(t *testing.T) {
	typeMap := pgtype.NewMap()
	input := databaseValuerString("foo")

	encoded, encodeErr := typeMap.Encode(pgtype.TextOID, pgtype.TextFormatCode, input, nil)
	require.NoError(t, encodeErr)
	require.Equal(t, "3", string(encoded))
}
// databaseValuerFmtStringer is a renamed string that has both a Value method and a String method.
type databaseValuerFmtStringer string

// Value always returns a nil value and a nil error.
func (s databaseValuerFmtStringer) Value() (driver.Value, error) {
	var none driver.Value
	return none, nil
}

// String always returns "foobar".
func (s databaseValuerFmtStringer) String() string {
	const text = "foobar"
	return text
}
// TestMapEncodeTextFormatDatabaseValuerThatIsFmtStringer requires that encoding a value whose Value method returns
// nil produces a nil buffer, even though the type also has a String method.
//
// https://github.com/jackc/pgx/issues/1311
func TestMapEncodeTextFormatDatabaseValuerThatIsFmtStringer(t *testing.T) {
	typeMap := pgtype.NewMap()
	input := databaseValuerFmtStringer("")

	encoded, encodeErr := typeMap.Encode(pgtype.TextOID, pgtype.TextFormatCode, input, nil)
	require.NoError(t, encodeErr)
	require.Nil(t, encoded)
}
// databaseValuerStringFormat wraps an int32 and exposes it through Value as a string.
type databaseValuerStringFormat struct {
	n int32
}

// Value returns n formatted as a decimal string.
func (v databaseValuerStringFormat) Value() (driver.Value, error) {
	text := fmt.Sprintf("%d", v.n)
	return text, nil
}
// TestMapEncodeBinaryFormatDatabaseValuerThatReturnsString requires that a value whose Value method returns the
// string "42" encodes into the four-byte binary int4 {0, 0, 0, 42}.
func TestMapEncodeBinaryFormatDatabaseValuerThatReturnsString(t *testing.T) {
	typeMap := pgtype.NewMap()
	input := databaseValuerStringFormat{n: 42}
	want := []byte{0, 0, 0, 42}

	encoded, encodeErr := typeMap.Encode(pgtype.Int4OID, pgtype.BinaryFormatCode, input, nil)
	require.NoError(t, encodeErr)
	require.Equal(t, want, encoded)
}
// TestMapEncodeDatabaseValuerThatReturnsStringIntoUnregisteredTypeTextFormat requires that a driver.Valuer
// returning the string "foo" encodes to the bytes "foo" in text format for unregisteredOID.
//
// https://github.com/jackc/pgx/issues/1445
func TestMapEncodeDatabaseValuerThatReturnsStringIntoUnregisteredTypeTextFormat(t *testing.T) {
	typeMap := pgtype.NewMap()
	valuer := driverValuerFunc(func() (driver.Value, error) { return "foo", nil })

	encoded, encodeErr := typeMap.Encode(unregisteredOID, pgtype.TextFormatCode, valuer, nil)
	require.NoError(t, encodeErr)
	require.Equal(t, []byte("foo"), encoded)
}
// TestMapEncodeDatabaseValuerThatReturnsByteSliceIntoUnregisteredTypeTextFormat requires that a driver.Valuer
// returning the bytes {0, 1, 2, 3} encodes to the hex text `\x00010203` in text format for unregisteredOID.
//
// https://github.com/jackc/pgx/issues/1445
func TestMapEncodeDatabaseValuerThatReturnsByteSliceIntoUnregisteredTypeTextFormat(t *testing.T) {
	typeMap := pgtype.NewMap()
	valuer := driverValuerFunc(func() (driver.Value, error) { return []byte{0, 1, 2, 3}, nil })

	encoded, encodeErr := typeMap.Encode(unregisteredOID, pgtype.TextFormatCode, valuer, nil)
	require.NoError(t, encodeErr)
	require.Equal(t, []byte(`\x00010203`), encoded)
}
// TestMapEncodeStringIntoUnregisteredTypeTextFormat requires that the string "foo" encodes to the bytes "foo" in
// text format for unregisteredOID.
func TestMapEncodeStringIntoUnregisteredTypeTextFormat(t *testing.T) {
	typeMap := pgtype.NewMap()
	input := "foo"

	encoded, encodeErr := typeMap.Encode(unregisteredOID, pgtype.TextFormatCode, input, nil)
	require.NoError(t, encodeErr)
	require.Equal(t, []byte("foo"), encoded)
}
// TestMapEncodeByteSliceIntoUnregisteredTypeTextFormat requires that the bytes {0, 1, 2, 3} encode to the hex text
// `\x00010203` in text format for unregisteredOID.
func TestMapEncodeByteSliceIntoUnregisteredTypeTextFormat(t *testing.T) {
	typeMap := pgtype.NewMap()
	input := []byte{0, 1, 2, 3}

	encoded, encodeErr := typeMap.Encode(unregisteredOID, pgtype.TextFormatCode, input, nil)
	require.NoError(t, encodeErr)
	require.Equal(t, []byte(`\x00010203`), encoded)
}
// https://github.com/jackc/pgx/issues/1326
func TestMapScanPointerToRenamedType(t *testing.T) {
srcBuf := []byte("foo")
m := pgtype.NewMap()
var rs *_string
err := m.Scan(pgtype.TextOID, pgx.TextFormatCode, srcBuf, &rs)
assert.NoError(t, err)
require.NotNil(t, rs)
assert.Equal(t, "foo", string(*rs))
}
// TestMapScanNullToWrongType scans a nil source declared as text into integer destinations: a *int32, which must
// stay nil, and a pgtype.Int4, which must end up not Valid. Neither scan may return an error.
//
// https://github.com/jackc/pgx/issues/1326
func TestMapScanNullToWrongType(t *testing.T) {
	typeMap := pgtype.NewMap()

	var intPtr *int32
	err := typeMap.Scan(pgtype.TextOID, pgx.TextFormatCode, nil, &intPtr)
	assert.NoError(t, err)
	assert.Nil(t, intPtr)

	var int4 pgtype.Int4
	err = typeMap.Scan(pgtype.TextOID, pgx.TextFormatCode, nil, &int4)
	assert.NoError(t, err)
	assert.False(t, int4.Valid)
}
// databaseValuerUUID is a 16-byte array that exposes itself through Value as a hex string.
type databaseValuerUUID [16]byte

// Value returns the 16 bytes formatted with %x, i.e. lowercase hexadecimal with no separators.
func (v databaseValuerUUID) Value() (driver.Value, error) {
	raw := [16]byte(v)
	hexText := fmt.Sprintf("%x", raw[:])
	return hexText, nil
}
// TestMapEncodePlanCacheUUIDTypeConfusion encodes a []databaseValuerUUID and then a []string into a binary uuid[]
// using the same Map. The first encode must produce the expected binary array; the second must return an error.
//
// https://github.com/jackc/pgx/issues/1502
func TestMapEncodePlanCacheUUIDTypeConfusion(t *testing.T) {
	want := []byte{
		0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0xb, 0x86, 0, 0, 0, 2, 0, 0, 0, 1,
		0, 0, 0, 16,
		0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
		0, 0, 0, 16,
		15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}

	typeMap := pgtype.NewMap()

	valuerUUIDs := []databaseValuerUUID{
		{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
		{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1},
	}
	encoded, err := typeMap.Encode(pgtype.UUIDArrayOID, pgtype.BinaryFormatCode, valuerUUIDs, nil)
	require.NoError(t, err)
	require.Equal(t, want, encoded)

	// This encode is *supposed* to fail. In the actual query path the error is detected and the encoding falls back
	// to the text format. With the bug this test guards against, this call would panic instead.
	textUUIDs := []string{"00010203-0405-0607-0809-0a0b0c0d0e0f", "0f0e0d0c-0b0a-0908-0706-0504-03020100"}
	_, err = typeMap.Encode(pgtype.UUIDArrayOID, pgtype.BinaryFormatCode, textUUIDs, nil)
	require.Error(t, err)
}
// BenchmarkMapScanInt4IntoBinaryDecoder measures Map.Scan of a binary int4 into a pgtype.Int4. The destination is
// reset and the result verified on every iteration.
func BenchmarkMapScanInt4IntoBinaryDecoder(b *testing.B) {
	typeMap := pgtype.NewMap()
	encoded := []byte{0, 0, 0, 42}
	want := pgtype.Int4{Int32: 42, Valid: true}

	var dst pgtype.Int4
	for i := 0; i < b.N; i++ {
		dst = pgtype.Int4{}
		if err := typeMap.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, encoded, &dst); err != nil {
			b.Fatal(err)
		}
		if dst != want {
			b.Fatal("scan failed due to bad value")
		}
	}
}
// BenchmarkMapScanInt4IntoGoInt32 measures Map.Scan of a binary int4 into a Go int32. The destination is reset and
// the result verified on every iteration.
func BenchmarkMapScanInt4IntoGoInt32(b *testing.B) {
	typeMap := pgtype.NewMap()
	encoded := []byte{0, 0, 0, 42}
	const want int32 = 42

	var dst int32
	for i := 0; i < b.N; i++ {
		dst = 0
		if err := typeMap.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, encoded, &dst); err != nil {
			b.Fatal(err)
		}
		if dst != want {
			b.Fatal("scan failed due to bad value")
		}
	}
}
// BenchmarkScanPlanScanInt4IntoBinaryDecoder measures scanning a binary int4 into a pgtype.Int4 through a plan
// that is built once before the loop. The destination is reset and the result verified on every iteration.
func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) {
	typeMap := pgtype.NewMap()
	encoded := []byte{0, 0, 0, 42}
	want := pgtype.Int4{Int32: 42, Valid: true}

	var dst pgtype.Int4
	scanPlan := typeMap.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &dst)

	for i := 0; i < b.N; i++ {
		dst = pgtype.Int4{}
		if err := scanPlan.Scan(encoded, &dst); err != nil {
			b.Fatal(err)
		}
		if dst != want {
			b.Fatal("scan failed due to bad value")
		}
	}
}
// BenchmarkScanPlanScanInt4IntoGoInt32 measures scanning a binary int4 into a Go int32 through a plan that is
// built once before the loop. The destination is reset and the result verified on every iteration.
func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) {
	typeMap := pgtype.NewMap()
	encoded := []byte{0, 0, 0, 42}
	const want int32 = 42

	var dst int32
	scanPlan := typeMap.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &dst)

	for i := 0; i < b.N; i++ {
		dst = 0
		if err := scanPlan.Scan(encoded, &dst); err != nil {
			b.Fatal(err)
		}
		if dst != want {
			b.Fatal("scan failed due to bad value")
		}
	}
}
// isExpectedEq returns a predicate that reports whether its argument equals a under Go's == for interface values.
func isExpectedEq(a any) func(any) bool {
	matches := func(v any) bool {
		equal := a == v
		return equal
	}
	return matches
}