mirror of https://github.com/jackc/pgx.git
Merge pull request #2083 from sodahealth/xml-codec
V1 XMLCodec supports encoding + scanning XML column typepull/2088/head
commit
9530aea47b
|
@ -37,7 +37,7 @@ func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Enco
|
||||||
//
|
//
|
||||||
// https://github.com/jackc/pgx/issues/1430
|
// https://github.com/jackc/pgx/issues/1430
|
||||||
//
|
//
|
||||||
// Check for driver.Valuer must come before json.Marshaler so that it is guaranteed to beused
|
// Check for driver.Valuer must come before json.Marshaler so that it is guaranteed to be used
|
||||||
// when both are implemented https://github.com/jackc/pgx/issues/1805
|
// when both are implemented https://github.com/jackc/pgx/issues/1805
|
||||||
case driver.Valuer:
|
case driver.Valuer:
|
||||||
return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format}
|
return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format}
|
||||||
|
@ -177,13 +177,6 @@ func (scanPlanJSONToByteSlice) Scan(src []byte, dst any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type scanPlanJSONToBytesScanner struct{}
|
|
||||||
|
|
||||||
func (scanPlanJSONToBytesScanner) Scan(src []byte, dst any) error {
|
|
||||||
scanner := (dst).(BytesScanner)
|
|
||||||
return scanner.ScanBytes(src)
|
|
||||||
}
|
|
||||||
|
|
||||||
type scanPlanJSONToJSONUnmarshal struct {
|
type scanPlanJSONToJSONUnmarshal struct {
|
||||||
unmarshal func(data []byte, v any) error
|
unmarshal func(data []byte, v any) error
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,8 @@ const (
|
||||||
XIDOID = 28
|
XIDOID = 28
|
||||||
CIDOID = 29
|
CIDOID = 29
|
||||||
JSONOID = 114
|
JSONOID = 114
|
||||||
|
XMLOID = 142
|
||||||
|
XMLArrayOID = 143
|
||||||
JSONArrayOID = 199
|
JSONArrayOID = 199
|
||||||
PointOID = 600
|
PointOID = 600
|
||||||
LsegOID = 601
|
LsegOID = 601
|
||||||
|
|
|
@ -2,6 +2,7 @@ package pgtype
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"encoding/xml"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
@ -89,6 +90,7 @@ func initDefaultMap() {
|
||||||
defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}})
|
defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}})
|
||||||
defaultMap.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}})
|
defaultMap.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}})
|
||||||
defaultMap.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}})
|
defaultMap.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}})
|
||||||
|
defaultMap.RegisterType(&Type{Name: "xml", OID: XMLOID, Codec: &XMLCodec{Marshal: xml.Marshal, Unmarshal: xml.Unmarshal}})
|
||||||
|
|
||||||
// Range types
|
// Range types
|
||||||
defaultMap.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[DateOID]}})
|
defaultMap.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[DateOID]}})
|
||||||
|
@ -153,6 +155,7 @@ func initDefaultMap() {
|
||||||
defaultMap.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarbitOID]}})
|
defaultMap.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarbitOID]}})
|
||||||
defaultMap.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarcharOID]}})
|
defaultMap.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarcharOID]}})
|
||||||
defaultMap.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XIDOID]}})
|
defaultMap.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XIDOID]}})
|
||||||
|
defaultMap.RegisterType(&Type{Name: "_xml", OID: XMLArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XMLOID]}})
|
||||||
|
|
||||||
// Integer types that directly map to a PostgreSQL type
|
// Integer types that directly map to a PostgreSQL type
|
||||||
registerDefaultPgTypeVariants[int16](defaultMap, "int2")
|
registerDefaultPgTypeVariants[int16](defaultMap, "int2")
|
||||||
|
|
|
@ -0,0 +1,198 @@
|
||||||
|
package pgtype
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/xml"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
type XMLCodec struct {
|
||||||
|
Marshal func(v any) ([]byte, error)
|
||||||
|
Unmarshal func(data []byte, v any) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*XMLCodec) FormatSupported(format int16) bool {
|
||||||
|
return format == TextFormatCode || format == BinaryFormatCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*XMLCodec) PreferredFormat() int16 {
|
||||||
|
return TextFormatCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *XMLCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
||||||
|
switch value.(type) {
|
||||||
|
case string:
|
||||||
|
return encodePlanXMLCodecEitherFormatString{}
|
||||||
|
case []byte:
|
||||||
|
return encodePlanXMLCodecEitherFormatByteSlice{}
|
||||||
|
|
||||||
|
// Cannot rely on driver.Valuer being handled later because anything can be marshalled.
|
||||||
|
//
|
||||||
|
// https://github.com/jackc/pgx/issues/1430
|
||||||
|
//
|
||||||
|
// Check for driver.Valuer must come before xml.Marshaler so that it is guaranteed to be used
|
||||||
|
// when both are implemented https://github.com/jackc/pgx/issues/1805
|
||||||
|
case driver.Valuer:
|
||||||
|
return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format}
|
||||||
|
|
||||||
|
// Must come before trying wrap encode plans because a pointer to a struct may be unwrapped to a struct that can be
|
||||||
|
// marshalled.
|
||||||
|
//
|
||||||
|
// https://github.com/jackc/pgx/issues/1681
|
||||||
|
case xml.Marshaler:
|
||||||
|
return &encodePlanXMLCodecEitherFormatMarshal{
|
||||||
|
marshal: c.Marshal,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the
|
||||||
|
// appropriate wrappers here.
|
||||||
|
for _, f := range []TryWrapEncodePlanFunc{
|
||||||
|
TryWrapDerefPointerEncodePlan,
|
||||||
|
TryWrapFindUnderlyingTypeEncodePlan,
|
||||||
|
} {
|
||||||
|
if wrapperPlan, nextValue, ok := f(value); ok {
|
||||||
|
if nextPlan := c.PlanEncode(m, oid, format, nextValue); nextPlan != nil {
|
||||||
|
wrapperPlan.SetNext(nextPlan)
|
||||||
|
return wrapperPlan
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &encodePlanXMLCodecEitherFormatMarshal{
|
||||||
|
marshal: c.Marshal,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type encodePlanXMLCodecEitherFormatString struct{}
|
||||||
|
|
||||||
|
func (encodePlanXMLCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||||
|
xmlString := value.(string)
|
||||||
|
buf = append(buf, xmlString...)
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type encodePlanXMLCodecEitherFormatByteSlice struct{}
|
||||||
|
|
||||||
|
func (encodePlanXMLCodecEitherFormatByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||||
|
xmlBytes := value.([]byte)
|
||||||
|
if xmlBytes == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = append(buf, xmlBytes...)
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type encodePlanXMLCodecEitherFormatMarshal struct {
|
||||||
|
marshal func(v any) ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encodePlanXMLCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||||
|
xmlBytes, err := e.marshal(value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = append(buf, xmlBytes...)
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *XMLCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
switch target.(type) {
|
||||||
|
case *string:
|
||||||
|
return scanPlanAnyToString{}
|
||||||
|
|
||||||
|
case **string:
|
||||||
|
// This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better
|
||||||
|
// solution would be.
|
||||||
|
//
|
||||||
|
// https://github.com/jackc/pgx/issues/1470 -- **string
|
||||||
|
// https://github.com/jackc/pgx/issues/1691 -- ** anything else
|
||||||
|
|
||||||
|
if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok {
|
||||||
|
if nextPlan := m.planScan(oid, format, nextDst); nextPlan != nil {
|
||||||
|
if _, failed := nextPlan.(*scanPlanFail); !failed {
|
||||||
|
wrapperPlan.SetNext(nextPlan)
|
||||||
|
return wrapperPlan
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case *[]byte:
|
||||||
|
return scanPlanXMLToByteSlice{}
|
||||||
|
case BytesScanner:
|
||||||
|
return scanPlanBinaryBytesToBytesScanner{}
|
||||||
|
|
||||||
|
// Cannot rely on sql.Scanner being handled later because scanPlanXMLToXMLUnmarshal will take precedence.
|
||||||
|
//
|
||||||
|
// https://github.com/jackc/pgx/issues/1418
|
||||||
|
case sql.Scanner:
|
||||||
|
return &scanPlanSQLScanner{formatCode: format}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &scanPlanXMLToXMLUnmarshal{
|
||||||
|
unmarshal: c.Unmarshal,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type scanPlanXMLToByteSlice struct{}
|
||||||
|
|
||||||
|
func (scanPlanXMLToByteSlice) Scan(src []byte, dst any) error {
|
||||||
|
dstBuf := dst.(*[]byte)
|
||||||
|
if src == nil {
|
||||||
|
*dstBuf = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
*dstBuf = make([]byte, len(src))
|
||||||
|
copy(*dstBuf, src)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type scanPlanXMLToXMLUnmarshal struct {
|
||||||
|
unmarshal func(data []byte, v any) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *scanPlanXMLToXMLUnmarshal) Scan(src []byte, dst any) error {
|
||||||
|
if src == nil {
|
||||||
|
dstValue := reflect.ValueOf(dst)
|
||||||
|
if dstValue.Kind() == reflect.Ptr {
|
||||||
|
el := dstValue.Elem()
|
||||||
|
switch el.Kind() {
|
||||||
|
case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Interface, reflect.Struct:
|
||||||
|
el.Set(reflect.Zero(el.Type()))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot scan NULL into %T", dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
elem := reflect.ValueOf(dst).Elem()
|
||||||
|
elem.Set(reflect.Zero(elem.Type()))
|
||||||
|
|
||||||
|
return s.unmarshal(src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *XMLCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
||||||
|
if src == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dstBuf := make([]byte, len(src))
|
||||||
|
copy(dstBuf, src)
|
||||||
|
return dstBuf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *XMLCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
||||||
|
if src == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var dst any
|
||||||
|
err := c.Unmarshal(src, &dst)
|
||||||
|
return dst, err
|
||||||
|
}
|
|
@ -0,0 +1,99 @@
|
||||||
|
package pgtype_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/xml"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
pgx "github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgxtest"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type xmlStruct struct {
|
||||||
|
XMLName xml.Name `xml:"person"`
|
||||||
|
Name string `xml:"name"`
|
||||||
|
Age int `xml:"age,attr"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestXMLCodec(t *testing.T) {
|
||||||
|
skipCockroachDB(t, "CockroachDB does not support XML.")
|
||||||
|
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "xml", []pgxtest.ValueRoundTripTest{
|
||||||
|
{nil, new(*xmlStruct), isExpectedEq((*xmlStruct)(nil))},
|
||||||
|
{map[string]any(nil), new(*string), isExpectedEq((*string)(nil))},
|
||||||
|
{map[string]any(nil), new([]byte), isExpectedEqBytes([]byte(nil))},
|
||||||
|
{[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))},
|
||||||
|
{nil, new([]byte), isExpectedEqBytes([]byte(nil))},
|
||||||
|
|
||||||
|
// Test sql.Scanner.
|
||||||
|
{"", new(sql.NullString), isExpectedEq(sql.NullString{String: "", Valid: true})},
|
||||||
|
|
||||||
|
// Test driver.Valuer.
|
||||||
|
{sql.NullString{String: "", Valid: true}, new(sql.NullString), isExpectedEq(sql.NullString{String: "", Valid: true})},
|
||||||
|
})
|
||||||
|
|
||||||
|
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "xml", []pgxtest.ValueRoundTripTest{
|
||||||
|
{[]byte(`<?xml version="1.0"?><Root></Root>`), new([]byte), isExpectedEqBytes([]byte(`<Root></Root>`))},
|
||||||
|
{[]byte(`<?xml version="1.0"?>`), new([]byte), isExpectedEqBytes([]byte(``))},
|
||||||
|
{[]byte(`<?xml version="1.0"?>`), new(string), isExpectedEq(``)},
|
||||||
|
{[]byte(`<Root></Root>`), new([]byte), isExpectedEqBytes([]byte(`<Root></Root>`))},
|
||||||
|
{[]byte(`<Root></Root>`), new(string), isExpectedEq(`<Root></Root>`)},
|
||||||
|
{[]byte(""), new([]byte), isExpectedEqBytes([]byte(""))},
|
||||||
|
{xmlStruct{Name: "Adam", Age: 10}, new(xmlStruct), isExpectedEq(xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10})},
|
||||||
|
{xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10}, new(xmlStruct), isExpectedEq(xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10})},
|
||||||
|
{[]byte(`<person age="10"><name>Adam</name></person>`), new(xmlStruct), isExpectedEq(xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10})},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648
|
||||||
|
func TestXMLCodecUnmarshalSQLNull(t *testing.T) {
|
||||||
|
skipCockroachDB(t, "CockroachDB does not support XML.")
|
||||||
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
// Byte arrays are nilified
|
||||||
|
slice := []byte{10, 4}
|
||||||
|
err := conn.QueryRow(ctx, "select null::xml").Scan(&slice)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, slice)
|
||||||
|
|
||||||
|
// Non-pointer structs are zeroed
|
||||||
|
m := xmlStruct{Name: "Adam"}
|
||||||
|
err = conn.QueryRow(ctx, "select null::xml").Scan(&m)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, m)
|
||||||
|
|
||||||
|
// Pointers to structs are nilified
|
||||||
|
pm := &xmlStruct{Name: "Adam"}
|
||||||
|
err = conn.QueryRow(ctx, "select null::xml").Scan(&pm)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, pm)
|
||||||
|
|
||||||
|
// Pointer to pointer are nilified
|
||||||
|
n := ""
|
||||||
|
p := &n
|
||||||
|
err = conn.QueryRow(ctx, "select null::xml").Scan(&p)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, p)
|
||||||
|
|
||||||
|
// A string cannot scan a NULL.
|
||||||
|
str := "foobar"
|
||||||
|
err = conn.QueryRow(ctx, "select null::xml").Scan(&str)
|
||||||
|
assert.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *string")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestXMLCodecPointerToPointerToString(t *testing.T) {
|
||||||
|
skipCockroachDB(t, "CockroachDB does not support XML.")
|
||||||
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
var s *string
|
||||||
|
err := conn.QueryRow(ctx, "select ''::xml").Scan(&s)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, s)
|
||||||
|
require.Equal(t, "", *s)
|
||||||
|
|
||||||
|
err = conn.QueryRow(ctx, "select null::xml").Scan(&s)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, s)
|
||||||
|
})
|
||||||
|
}
|
|
@ -795,6 +795,16 @@ func (r *Rows) Next(dest []driver.Value) error {
|
||||||
}
|
}
|
||||||
return d.Value()
|
return d.Value()
|
||||||
}
|
}
|
||||||
|
case pgtype.XMLOID:
|
||||||
|
var d []byte
|
||||||
|
scanPlan := m.PlanScan(dataTypeOID, format, &d)
|
||||||
|
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
||||||
|
err := scanPlan.Scan(src, &d)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
var d string
|
var d string
|
||||||
scanPlan := m.PlanScan(dataTypeOID, format, &d)
|
scanPlan := m.PlanScan(dataTypeOID, format, &d)
|
||||||
|
|
Loading…
Reference in New Issue