mirror of https://github.com/jackc/pgx.git
Add CompositeFields encoders
parent
e92ee69901
commit
2186634638
|
@ -1,11 +1,17 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
errors "golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// CompositeFields scans the fields of a composite type into the elements of the CompositeFields value. To scan a
|
||||
// nullable value use a *CompositeFields. It will be set to nil in case of null.
|
||||
//
|
||||
// CompositeFields implements EncodeBinary and EncodeText. However, functionality is limited due to CompositeFields not
|
||||
// knowing the PostgreSQL schema of the composite type. Prefer using a registered CompositeType.
|
||||
type CompositeFields []interface{}
|
||||
|
||||
func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||
|
@ -74,3 +80,109 @@ func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodeText encodes composite fields into the text format. Prefer registering a CompositeType to using
|
||||
// CompositeFields to encode directly.
|
||||
func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
buf = append(buf, '(')
|
||||
|
||||
fieldBuf := make([]byte, 0, 32)
|
||||
|
||||
for _, f := range cf {
|
||||
if f != nil {
|
||||
fieldBuf = fieldBuf[0:0]
|
||||
if textEncoder, ok := f.(TextEncoder); ok {
|
||||
var err error
|
||||
fieldBuf, err = textEncoder.EncodeText(ci, fieldBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fieldBuf != nil {
|
||||
buf = append(buf, QuoteCompositeFieldIfNeeded(string(fieldBuf))...)
|
||||
}
|
||||
} else {
|
||||
dt, ok := ci.DataTypeForValue(f)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("Unknown data type for %#v", f)
|
||||
}
|
||||
|
||||
err := dt.Value.Set(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if textEncoder, ok := dt.Value.(TextEncoder); ok {
|
||||
var err error
|
||||
fieldBuf, err = textEncoder.EncodeText(ci, fieldBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fieldBuf != nil {
|
||||
buf = append(buf, QuoteCompositeFieldIfNeeded(string(fieldBuf))...)
|
||||
}
|
||||
} else {
|
||||
return nil, errors.Errorf("Cannot encode text format for %v", f)
|
||||
}
|
||||
}
|
||||
}
|
||||
buf = append(buf, ',')
|
||||
}
|
||||
|
||||
buf[len(buf)-1] = ')'
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// EncodeBinary encodes composite fields into the binary format. Unlike CompositeType the schema of the destination is
|
||||
// unknown. Prefer registering a CompositeType to using CompositeFields to encode directly. Because the binary
|
||||
// composite format requires the OID of each field to be specified the only types that will work are those known to
|
||||
// ConnInfo.
|
||||
//
|
||||
// In particular:
|
||||
//
|
||||
// * Nil cannot be used because there is no way to determine what type it.
|
||||
// * Integer types must be exact matches. e.g. A Go int32 into a PostgreSQL bigint will fail.
|
||||
// * No dereferencing will be done. e.g. *Text must be used instead of Text.
|
||||
func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
buf = pgio.AppendUint32(buf, uint32(len(cf)))
|
||||
|
||||
for _, f := range cf {
|
||||
dt, ok := ci.DataTypeForValue(f)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("Unknown OID for %#v", f)
|
||||
}
|
||||
|
||||
buf = pgio.AppendUint32(buf, dt.OID)
|
||||
lengthPos := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
|
||||
if binaryEncoder, ok := f.(BinaryEncoder); ok {
|
||||
fieldBuf, err := binaryEncoder.EncodeBinary(ci, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fieldBuf != nil {
|
||||
binary.BigEndian.PutUint32(buf[lengthPos:], uint32(len(fieldBuf)-len(buf)))
|
||||
buf = fieldBuf
|
||||
}
|
||||
} else {
|
||||
err := dt.Value.Set(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if binaryEncoder, ok := dt.Value.(BinaryEncoder); ok {
|
||||
fieldBuf, err := binaryEncoder.EncodeBinary(ci, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fieldBuf != nil {
|
||||
binary.BigEndian.PutUint32(buf[lengthPos:], uint32(len(fieldBuf)-len(buf)))
|
||||
buf = fieldBuf
|
||||
}
|
||||
} else {
|
||||
return nil, errors.Errorf("Cannot encode binary format for %v", f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"github.com/jackc/pgtype/testutil"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCompositeFieldsDecode(t *testing.T) {
|
||||
|
@ -123,4 +124,150 @@ func TestCompositeFieldsDecode(t *testing.T) {
|
|||
assert.EqualValuesf(t, []int64{1, 2, 3}, b, "Format: %v", format)
|
||||
}
|
||||
}
|
||||
|
||||
// Skip nil fields
|
||||
{
|
||||
var a int32
|
||||
var c float64
|
||||
|
||||
for _, format := range formats {
|
||||
err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan(
|
||||
pgtype.CompositeFields{&a, nil, &c},
|
||||
)
|
||||
if !assert.NoErrorf(t, err, "Format: %v", format) {
|
||||
continue
|
||||
}
|
||||
|
||||
assert.EqualValuesf(t, 1, a, "Format: %v", format)
|
||||
assert.EqualValuesf(t, 2.1, c, "Format: %v", format)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositeFieldsEncode(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
||||
_, err := conn.Exec(context.Background(), `drop type if exists cf_encode;
|
||||
|
||||
create type cf_encode as (
|
||||
a text,
|
||||
b int4,
|
||||
c text,
|
||||
d float8,
|
||||
e text
|
||||
);`)
|
||||
require.NoError(t, err)
|
||||
defer conn.Exec(context.Background(), "drop type cf_encode")
|
||||
|
||||
// Use simple protocol to force text or binary encoding
|
||||
simpleProtocols := []bool{true, false}
|
||||
|
||||
// Assorted values
|
||||
{
|
||||
var a string
|
||||
var b int32
|
||||
var c string
|
||||
var d float64
|
||||
var e string
|
||||
|
||||
for _, simpleProtocol := range simpleProtocols {
|
||||
err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol),
|
||||
pgtype.CompositeFields{"hi", int32(1), "ok", float64(2.1), "bye"},
|
||||
).Scan(
|
||||
pgtype.CompositeFields{&a, &b, &c, &d, &e},
|
||||
)
|
||||
if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
|
||||
assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, "ok", c, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, 2.1, d, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, "bye", e, "Simple Protocol: %v", simpleProtocol)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// untyped nil
|
||||
{
|
||||
var a pgtype.Text
|
||||
var b int32
|
||||
var c string
|
||||
var d pgtype.Float8
|
||||
var e pgtype.Text
|
||||
|
||||
simpleProtocol := true
|
||||
err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol),
|
||||
pgtype.CompositeFields{nil, int32(1), "null", nil, nil},
|
||||
).Scan(
|
||||
pgtype.CompositeFields{&a, &b, &c, &d, &e},
|
||||
)
|
||||
if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
|
||||
assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol)
|
||||
assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol)
|
||||
}
|
||||
|
||||
// untyped nil cannot be represented in binary format because CompositeFields does not know the PostgreSQL schema
|
||||
// of the composite type.
|
||||
simpleProtocol = false
|
||||
err = conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol),
|
||||
pgtype.CompositeFields{nil, int32(1), "null", nil, nil},
|
||||
).Scan(
|
||||
pgtype.CompositeFields{&a, &b, &c, &d, &e},
|
||||
)
|
||||
assert.Errorf(t, err, "Simple Protocol: %v", simpleProtocol)
|
||||
}
|
||||
|
||||
// nulls, string "null", and empty string fields
|
||||
{
|
||||
var a pgtype.Text
|
||||
var b int32
|
||||
var c string
|
||||
var d pgtype.Float8
|
||||
var e pgtype.Text
|
||||
|
||||
for _, simpleProtocol := range simpleProtocols {
|
||||
err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol),
|
||||
pgtype.CompositeFields{&pgtype.Text{Status: pgtype.Null}, int32(1), "null", &pgtype.Float8{Status: pgtype.Null}, &pgtype.Text{Status: pgtype.Null}},
|
||||
).Scan(
|
||||
pgtype.CompositeFields{&a, &b, &c, &d, &e},
|
||||
)
|
||||
if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
|
||||
assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol)
|
||||
assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// quotes and special characters
|
||||
{
|
||||
var a string
|
||||
var b int32
|
||||
var c string
|
||||
var d float64
|
||||
var e string
|
||||
|
||||
for _, simpleProtocol := range simpleProtocols {
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
`select $1::cf_encode`,
|
||||
pgx.QuerySimpleProtocol(simpleProtocol),
|
||||
pgtype.CompositeFields{`"`, int32(42), `foo'bar`, float64(1.2), `baz)bar`},
|
||||
).Scan(
|
||||
pgtype.CompositeFields{&a, &b, &c, &d, &e},
|
||||
)
|
||||
if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
|
||||
assert.Equalf(t, `"`, a, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.Equalf(t, int32(42), b, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.Equalf(t, `foo'bar`, c, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.Equalf(t, float64(1.2), d, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.Equalf(t, `baz)bar`, e, "Simple Protocol: %v", simpleProtocol)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package pgtype
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
errors "golang.org/x/xerrors"
|
||||
|
@ -366,3 +367,16 @@ func RecordAdd(buf []byte, oid uint32, fieldBytes []byte) []byte {
|
|||
func RecordAddNull(buf []byte, oid uint32) []byte {
|
||||
return pgio.AppendInt32(buf, int32(-1))
|
||||
}
|
||||
|
||||
var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
|
||||
|
||||
func quoteCompositeField(src string) string {
|
||||
return `"` + quoteCompositeReplacer.Replace(src) + `"`
|
||||
}
|
||||
|
||||
func QuoteCompositeFieldIfNeeded(src string) string {
|
||||
if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) {
|
||||
return quoteCompositeField(src)
|
||||
}
|
||||
return src
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue