Temporarily remove composite and record support

query-exec-mode
Jack Christensen 2022-01-01 11:41:08 -06:00
parent ffa1fdd66e
commit 40fb889605
9 changed files with 1 additions and 2020 deletions

View File

@ -1,192 +0,0 @@
package pgtype_test
import (
"testing"
"github.com/jackc/pgio"
"github.com/jackc/pgx/v5/pgtype"
"github.com/stretchr/testify/require"
)
type MyCompositeRaw struct {
A int32
B *string
}
func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
buf = pgio.AppendUint32(buf, 2)
buf = pgio.AppendUint32(buf, pgtype.Int4OID)
buf = pgio.AppendInt32(buf, 4)
buf = pgio.AppendInt32(buf, src.A)
buf = pgio.AppendUint32(buf, pgtype.TextOID)
if src.B != nil {
buf = pgio.AppendInt32(buf, int32(len(*src.B)))
buf = append(buf, (*src.B)...)
} else {
buf = pgio.AppendInt32(buf, -1)
}
return buf, nil
}
func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
a := pgtype.Int4{}
b := pgtype.Text{}
scanner := pgtype.NewCompositeBinaryScanner(ci, src)
scanner.ScanDecoder(&a)
scanner.ScanDecoder(&b)
if scanner.Err() != nil {
return scanner.Err()
}
dst.A = a.Int
if b.Valid {
dst.B = &b.String
} else {
dst.B = nil
}
return nil
}
var x []byte
func BenchmarkBinaryEncodingManual(b *testing.B) {
buf := make([]byte, 0, 128)
ci := pgtype.NewConnInfo()
v := MyCompositeRaw{4, ptrS("ABCDEFG")}
b.ResetTimer()
for n := 0; n < b.N; n++ {
buf, _ = v.EncodeBinary(ci, buf[:0])
}
x = buf
}
func BenchmarkBinaryEncodingHelper(b *testing.B) {
buf := make([]byte, 0, 128)
ci := pgtype.NewConnInfo()
v := MyType{4, ptrS("ABCDEFG")}
b.ResetTimer()
for n := 0; n < b.N; n++ {
buf, _ = v.EncodeBinary(ci, buf[:0])
}
x = buf
}
func BenchmarkBinaryEncodingComposite(b *testing.B) {
buf := make([]byte, 0, 128)
ci := pgtype.NewConnInfo()
f1 := 2
f2 := ptrS("bar")
c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
{"a", pgtype.Int4OID},
{"b", pgtype.TextOID},
}, ci)
require.NoError(b, err)
b.ResetTimer()
for n := 0; n < b.N; n++ {
c.Set([]interface{}{f1, f2})
buf, _ = c.EncodeBinary(ci, buf[:0])
}
x = buf
}
func BenchmarkBinaryEncodingJSON(b *testing.B) {
buf := make([]byte, 0, 128)
ci := pgtype.NewConnInfo()
v := MyCompositeRaw{4, ptrS("ABCDEFG")}
j := pgtype.JSON{}
b.ResetTimer()
for n := 0; n < b.N; n++ {
j.Set(v)
buf, _ = j.EncodeBinary(ci, buf[:0])
}
x = buf
}
var dstRaw MyCompositeRaw
func BenchmarkBinaryDecodingManual(b *testing.B) {
ci := pgtype.NewConnInfo()
buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil)
dst := MyCompositeRaw{}
b.ResetTimer()
for n := 0; n < b.N; n++ {
err := dst.DecodeBinary(ci, buf)
E(err)
}
dstRaw = dst
}
var dstMyType MyType
func BenchmarkBinaryDecodingHelpers(b *testing.B) {
ci := pgtype.NewConnInfo()
buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil)
dst := MyType{}
b.ResetTimer()
for n := 0; n < b.N; n++ {
err := dst.DecodeBinary(ci, buf)
E(err)
}
dstMyType = dst
}
var gf1 int
var gf2 *string
func BenchmarkBinaryDecodingCompositeScan(b *testing.B) {
ci := pgtype.NewConnInfo()
buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil)
var f1 int
var f2 *string
c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
{"a", pgtype.Int4OID},
{"b", pgtype.TextOID},
}, ci)
require.NoError(b, err)
b.ResetTimer()
for n := 0; n < b.N; n++ {
err := c.DecodeBinary(ci, buf)
if err != nil {
b.Fatal(err)
}
err = c.AssignTo([]interface{}{&f1, &f2})
if err != nil {
b.Fatal(err)
}
}
gf1 = f1
gf2 = f2
}
func BenchmarkBinaryDecodingJSON(b *testing.B) {
ci := pgtype.NewConnInfo()
j := pgtype.JSON{}
j.Set(MyCompositeRaw{4, ptrS("ABCDEFG")})
buf, _ := j.EncodeBinary(ci, nil)
j = pgtype.JSON{}
dst := MyCompositeRaw{}
b.ResetTimer()
for n := 0; n < b.N; n++ {
err := j.DecodeBinary(ci, buf)
E(err)
err = j.AssignTo(&dst)
E(err)
}
dstRaw = dst
}

View File

@ -1,107 +0,0 @@
package pgtype
import "fmt"
// CompositeFields scans the fields of a composite type into the elements of the CompositeFields value. To scan a
// nullable value use a *CompositeFields. It will be set to nil in case of null.
//
// CompositeFields implements EncodeBinary and EncodeText. However, functionality is limited due to CompositeFields not
// knowing the PostgreSQL schema of the composite type. Prefer using a registered CompositeType.
type CompositeFields []interface{}
func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error {
if len(cf) == 0 {
return fmt.Errorf("cannot decode into empty CompositeFields")
}
if src == nil {
return fmt.Errorf("cannot decode unexpected null into CompositeFields")
}
scanner := NewCompositeBinaryScanner(ci, src)
for _, f := range cf {
scanner.ScanValue(f)
}
if scanner.Err() != nil {
return scanner.Err()
}
return nil
}
func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error {
if len(cf) == 0 {
return fmt.Errorf("cannot decode into empty CompositeFields")
}
if src == nil {
return fmt.Errorf("cannot decode unexpected null into CompositeFields")
}
scanner := NewCompositeTextScanner(ci, src)
for _, f := range cf {
scanner.ScanValue(f)
}
if scanner.Err() != nil {
return scanner.Err()
}
return nil
}
// EncodeText encodes composite fields into the text format. Prefer registering a CompositeType to using
// CompositeFields to encode directly.
func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
b := NewCompositeTextBuilder(ci, buf)
for _, f := range cf {
if paramEncoder, ok := f.(ParamEncoder); ok {
b.AppendEncoder(paramEncoder)
} else {
b.AppendValue(f)
}
}
return b.Finish()
}
// EncodeBinary encodes composite fields into the binary format. Unlike CompositeType the schema of the destination is
// unknown. Prefer registering a CompositeType to using CompositeFields to encode directly. Because the binary
// composite format requires the OID of each field to be specified the only types that will work are those known to
// ConnInfo.
//
// In particular:
//
// * Nil cannot be used because there is no way to determine what type it.
// * Integer types must be exact matches. e.g. A Go int32 into a PostgreSQL bigint will fail.
// * No dereferencing will be done. e.g. *Text must be used instead of Text.
func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
b := NewCompositeBinaryBuilder(ci, buf)
for _, f := range cf {
dt, ok := ci.DataTypeForValue(f)
if !ok {
return nil, fmt.Errorf("Unknown OID for %#v", f)
}
if paramEncoder, ok := f.(ParamEncoder); ok {
b.AppendEncoder(dt.OID, paramEncoder)
} else {
err := dt.Value.Set(f)
if err != nil {
return nil, err
}
if paramEncoder, ok := dt.Value.(ParamEncoder); ok {
b.AppendEncoder(dt.OID, paramEncoder)
} else {
return nil, fmt.Errorf("Cannot encode binary format for %v", f)
}
}
}
return b.Finish()
}

View File

@ -1,273 +0,0 @@
package pgtype_test
import (
"context"
"testing"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgtype/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCompositeFieldsDecode(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
formats := []int16{pgx.TextFormatCode, pgx.BinaryFormatCode}
// Assorted values
{
var a int32
var b string
var c float64
for _, format := range formats {
err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan(
pgtype.CompositeFields{&a, &b, &c},
)
if !assert.NoErrorf(t, err, "Format: %v", format) {
continue
}
assert.EqualValuesf(t, 1, a, "Format: %v", format)
assert.EqualValuesf(t, "hi", b, "Format: %v", format)
assert.EqualValuesf(t, 2.1, c, "Format: %v", format)
}
}
// nulls, string "null", and empty string fields
{
var a pgtype.Text
var b string
var c pgtype.Text
var d string
var e pgtype.Text
for _, format := range formats {
err := conn.QueryRow(context.Background(), "select row(null,'null',null,'',null)", pgx.QueryResultFormats{format}).Scan(
pgtype.CompositeFields{&a, &b, &c, &d, &e},
)
if !assert.NoErrorf(t, err, "Format: %v", format) {
continue
}
assert.Nilf(t, a.Get(), "Format: %v", format)
assert.EqualValuesf(t, "null", b, "Format: %v", format)
assert.Nilf(t, c.Get(), "Format: %v", format)
assert.EqualValuesf(t, "", d, "Format: %v", format)
assert.Nilf(t, e.Get(), "Format: %v", format)
}
}
// null record
{
var a pgtype.Text
var b string
cf := pgtype.CompositeFields{&a, &b}
for _, format := range formats {
// Cannot scan nil into
err := conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan(
cf,
)
if assert.Errorf(t, err, "Format: %v", format) {
continue
}
assert.NotNilf(t, cf, "Format: %v", format)
// But can scan nil into *pgtype.CompositeFields
err = conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan(
&cf,
)
if assert.Errorf(t, err, "Format: %v", format) {
continue
}
assert.Nilf(t, cf, "Format: %v", format)
}
}
// quotes and special characters
{
var a, b, c, d string
for _, format := range formats {
err := conn.QueryRow(context.Background(), `select row('"', 'foo bar', 'foo''bar', 'baz)bar')`, pgx.QueryResultFormats{format}).Scan(
pgtype.CompositeFields{&a, &b, &c, &d},
)
if !assert.NoErrorf(t, err, "Format: %v", format) {
continue
}
assert.Equalf(t, `"`, a, "Format: %v", format)
assert.Equalf(t, `foo bar`, b, "Format: %v", format)
assert.Equalf(t, `foo'bar`, c, "Format: %v", format)
assert.Equalf(t, `baz)bar`, d, "Format: %v", format)
}
}
// arrays
{
var a []string
var b []int64
for _, format := range formats {
err := conn.QueryRow(context.Background(), `select row(array['foo', 'bar', 'baz'], array[1,2,3])`, pgx.QueryResultFormats{format}).Scan(
pgtype.CompositeFields{&a, &b},
)
if !assert.NoErrorf(t, err, "Format: %v", format) {
continue
}
assert.EqualValuesf(t, []string{"foo", "bar", "baz"}, a, "Format: %v", format)
assert.EqualValuesf(t, []int64{1, 2, 3}, b, "Format: %v", format)
}
}
// Skip nil fields
{
var a int32
var c float64
for _, format := range formats {
err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan(
pgtype.CompositeFields{&a, nil, &c},
)
if !assert.NoErrorf(t, err, "Format: %v", format) {
continue
}
assert.EqualValuesf(t, 1, a, "Format: %v", format)
assert.EqualValuesf(t, 2.1, c, "Format: %v", format)
}
}
}
func TestCompositeFieldsEncode(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
_, err := conn.Exec(context.Background(), `drop type if exists cf_encode;
create type cf_encode as (
a text,
b int4,
c text,
d float8,
e text
);`)
require.NoError(t, err)
defer conn.Exec(context.Background(), "drop type cf_encode")
// Use simple protocol to force text or binary encoding
simpleProtocols := []bool{true, false}
// Assorted values
{
var a string
var b int32
var c string
var d float64
var e string
for _, simpleProtocol := range simpleProtocols {
err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol),
pgtype.CompositeFields{"hi", int32(1), "ok", float64(2.1), "bye"},
).Scan(
pgtype.CompositeFields{&a, &b, &c, &d, &e},
)
if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol)
assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol)
assert.EqualValuesf(t, "ok", c, "Simple Protocol: %v", simpleProtocol)
assert.EqualValuesf(t, 2.1, d, "Simple Protocol: %v", simpleProtocol)
assert.EqualValuesf(t, "bye", e, "Simple Protocol: %v", simpleProtocol)
}
}
}
// untyped nil
{
var a pgtype.Text
var b int32
var c string
var d pgtype.Float8
var e pgtype.Text
simpleProtocol := true
err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol),
pgtype.CompositeFields{nil, int32(1), "null", nil, nil},
).Scan(
pgtype.CompositeFields{&a, &b, &c, &d, &e},
)
if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol)
assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol)
assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol)
assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol)
assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol)
}
// untyped nil cannot be represented in binary format because CompositeFields does not know the PostgreSQL schema
// of the composite type.
simpleProtocol = false
err = conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol),
pgtype.CompositeFields{nil, int32(1), "null", nil, nil},
).Scan(
pgtype.CompositeFields{&a, &b, &c, &d, &e},
)
assert.Errorf(t, err, "Simple Protocol: %v", simpleProtocol)
}
// nulls, string "null", and empty string fields
{
var a pgtype.Text
var b int32
var c string
var d pgtype.Float8
var e pgtype.Text
for _, simpleProtocol := range simpleProtocols {
err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol),
pgtype.CompositeFields{&pgtype.Text{}, int32(1), "null", &pgtype.Float8{}, &pgtype.Text{}},
).Scan(
pgtype.CompositeFields{&a, &b, &c, &d, &e},
)
if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol)
assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol)
assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol)
assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol)
assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol)
}
}
}
// quotes and special characters
{
var a string
var b int32
var c string
var d float64
var e string
for _, simpleProtocol := range simpleProtocols {
err := conn.QueryRow(
context.Background(),
`select $1::cf_encode`,
pgx.QuerySimpleProtocol(simpleProtocol),
pgtype.CompositeFields{`"`, int32(42), `foo'bar`, float64(1.2), `baz)bar`},
).Scan(
pgtype.CompositeFields{&a, &b, &c, &d, &e},
)
if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
assert.Equalf(t, `"`, a, "Simple Protocol: %v", simpleProtocol)
assert.Equalf(t, int32(42), b, "Simple Protocol: %v", simpleProtocol)
assert.Equalf(t, `foo'bar`, c, "Simple Protocol: %v", simpleProtocol)
assert.Equalf(t, float64(1.2), d, "Simple Protocol: %v", simpleProtocol)
assert.Equalf(t, `baz)bar`, e, "Simple Protocol: %v", simpleProtocol)
}
}
}
}

View File

@ -1,715 +0,0 @@
package pgtype
import (
"encoding/binary"
"errors"
"fmt"
"reflect"
"strings"
"github.com/jackc/pgio"
)
type CompositeTypeField struct {
Name string
OID uint32
}
type CompositeType struct {
valid bool
typeName string
fields []CompositeTypeField
valueTranscoders []ValueTranscoder
}
// NewCompositeType creates a CompositeType from fields and ci. ci is used to find the ValueTranscoders used
// for fields. All field OIDs must be previously registered in ci.
func NewCompositeType(typeName string, fields []CompositeTypeField, ci *ConnInfo) (*CompositeType, error) {
valueTranscoders := make([]ValueTranscoder, len(fields))
for i := range fields {
dt, ok := ci.DataTypeForOID(fields[i].OID)
if !ok {
return nil, fmt.Errorf("no data type registered for oid: %d", fields[i].OID)
}
value := NewValue(dt.Value)
valueTranscoder, ok := value.(ValueTranscoder)
if !ok {
return nil, fmt.Errorf("data type for oid does not implement ValueTranscoder: %d", fields[i].OID)
}
valueTranscoders[i] = valueTranscoder
}
return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: valueTranscoders}, nil
}
// NewCompositeTypeValues creates a CompositeType from fields and values. fields and values must have the same length.
// Prefer NewCompositeType unless overriding the transcoding of fields is required.
func NewCompositeTypeValues(typeName string, fields []CompositeTypeField, values []ValueTranscoder) (*CompositeType, error) {
if len(fields) != len(values) {
return nil, errors.New("fields and valueTranscoders must have same length")
}
return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: values}, nil
}
func (src CompositeType) Get() interface{} {
if !src.valid {
return nil
}
results := make(map[string]interface{}, len(src.valueTranscoders))
for i := range src.valueTranscoders {
results[src.fields[i].Name] = src.valueTranscoders[i].Get()
}
return results
}
func (ct *CompositeType) NewTypeValue() Value {
a := &CompositeType{
typeName: ct.typeName,
fields: ct.fields,
valueTranscoders: make([]ValueTranscoder, len(ct.valueTranscoders)),
}
for i := range ct.valueTranscoders {
a.valueTranscoders[i] = NewValue(ct.valueTranscoders[i]).(ValueTranscoder)
}
return a
}
func (ct *CompositeType) TypeName() string {
return ct.typeName
}
func (ct *CompositeType) Fields() []CompositeTypeField {
return ct.fields
}
func (dst *CompositeType) setNil() {
dst.valid = false
}
func (dst *CompositeType) Set(src interface{}) error {
if src == nil {
dst.setNil()
return nil
}
switch value := src.(type) {
case []interface{}:
if len(value) != len(dst.valueTranscoders) {
return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.valueTranscoders))
}
for i, v := range value {
if err := dst.valueTranscoders[i].Set(v); err != nil {
return err
}
}
dst.valid = true
case *[]interface{}:
if value == nil {
dst.setNil()
return nil
}
return dst.Set(*value)
default:
return fmt.Errorf("Can not convert %v to Composite", src)
}
return nil
}
// AssignTo should never be called on composite value directly
func (src CompositeType) AssignTo(dst interface{}) error {
if !src.valid {
return NullAssignTo(dst)
}
switch v := dst.(type) {
case []interface{}:
if len(v) != len(src.valueTranscoders) {
return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.valueTranscoders))
}
for i := range src.valueTranscoders {
if v[i] == nil {
continue
}
err := assignToOrSet(src.valueTranscoders[i], v[i])
if err != nil {
return fmt.Errorf("unable to assign to dst[%d]: %v", i, err)
}
}
return nil
case *[]interface{}:
return src.AssignTo(*v)
default:
if isPtrStruct, err := src.assignToPtrStruct(dst); isPtrStruct {
return err
}
if nextDst, retry := GetAssignToDstType(dst); retry {
return src.AssignTo(nextDst)
}
return fmt.Errorf("unable to assign to %T", dst)
}
}
func assignToOrSet(src Value, dst interface{}) error {
assignToErr := src.AssignTo(dst)
if assignToErr != nil {
// Try to use get / set instead -- this avoids every type having to be able to AssignTo type of self.
setSucceeded := false
if setter, ok := dst.(Value); ok {
err := setter.Set(src.Get())
setSucceeded = err == nil
}
if !setSucceeded {
return assignToErr
}
}
return nil
}
func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) {
dstValue := reflect.ValueOf(dst)
if dstValue.Kind() != reflect.Ptr {
return false, nil
}
if dstValue.IsNil() {
return false, nil
}
dstElemValue := dstValue.Elem()
dstElemType := dstElemValue.Type()
if dstElemType.Kind() != reflect.Struct {
return false, nil
}
exportedFields := make([]int, 0, dstElemType.NumField())
for i := 0; i < dstElemType.NumField(); i++ {
sf := dstElemType.Field(i)
if sf.PkgPath == "" {
exportedFields = append(exportedFields, i)
}
}
if len(exportedFields) != len(src.valueTranscoders) {
return false, nil
}
for i := range exportedFields {
err := assignToOrSet(src.valueTranscoders[i], dstElemValue.Field(exportedFields[i]).Addr().Interface())
if err != nil {
return true, fmt.Errorf("unable to assign to field %s: %v", dstElemType.Field(exportedFields[i]).Name, err)
}
}
return true, nil
}
func (ct *CompositeType) BinaryFormatSupported() bool {
for _, vt := range ct.valueTranscoders {
if !vt.BinaryFormatSupported() {
return false
}
}
return true
}
func (ct *CompositeType) TextFormatSupported() bool {
for _, vt := range ct.valueTranscoders {
if !vt.TextFormatSupported() {
return false
}
}
return true
}
func (ct *CompositeType) PreferredFormat() int16 {
if ct.BinaryFormatSupported() {
return BinaryFormatCode
}
return TextFormatCode
}
func (dst *CompositeType) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error {
if src == nil {
dst.setNil()
return nil
}
switch format {
case BinaryFormatCode:
return dst.DecodeBinary(ci, src)
case TextFormatCode:
return dst.DecodeText(ci, src)
}
return fmt.Errorf("unknown format code %d", format)
}
func (src CompositeType) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) {
switch format {
case BinaryFormatCode:
return src.EncodeBinary(ci, buf)
case TextFormatCode:
return src.EncodeText(ci, buf)
}
return nil, fmt.Errorf("unknown format code %d", format)
}
func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) {
if !src.valid {
return nil, nil
}
b := NewCompositeBinaryBuilder(ci, buf)
for i := range src.valueTranscoders {
b.AppendEncoder(src.fields[i].OID, src.valueTranscoders[i])
}
return b.Finish()
}
// DecodeBinary implements BinaryDecoder interface.
// Opposite to Record, fields in a composite act as a "schema"
// and decoding fails if SQL value can't be assigned due to
// type mismatch
func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error {
scanner := NewCompositeBinaryScanner(ci, buf)
for _, f := range dst.valueTranscoders {
scanner.ScanDecoder(f)
}
if scanner.Err() != nil {
return scanner.Err()
}
dst.valid = true
return nil
}
func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error {
scanner := NewCompositeTextScanner(ci, buf)
for _, f := range dst.valueTranscoders {
scanner.ScanDecoder(f)
}
if scanner.Err() != nil {
return scanner.Err()
}
dst.valid = true
return nil
}
func (src CompositeType) EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) {
if !src.valid {
return nil, nil
}
b := NewCompositeTextBuilder(ci, buf)
for _, f := range src.valueTranscoders {
b.AppendEncoder(f)
}
return b.Finish()
}
type CompositeBinaryScanner struct {
ci *ConnInfo
rp int
src []byte
fieldCount int32
fieldBytes []byte
fieldOID uint32
err error
}
// NewCompositeBinaryScanner a scanner over a binary encoded composite balue.
func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner {
rp := 0
if len(src[rp:]) < 4 {
return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)}
}
fieldCount := int32(binary.BigEndian.Uint32(src[rp:]))
rp += 4
return &CompositeBinaryScanner{
ci: ci,
rp: rp,
src: src,
fieldCount: fieldCount,
}
}
// ScanDecoder calls Next and decodes the result with d.
func (cfs *CompositeBinaryScanner) ScanDecoder(d ResultDecoder) {
if cfs.err != nil {
return
}
if cfs.Next() {
cfs.err = d.DecodeResult(cfs.ci, 0, BinaryFormatCode, cfs.fieldBytes)
} else {
cfs.err = errors.New("read past end of composite")
}
}
// ScanDecoder calls Next and scans the result into d.
func (cfs *CompositeBinaryScanner) ScanValue(d interface{}) {
if cfs.err != nil {
return
}
if cfs.Next() {
cfs.err = cfs.ci.Scan(cfs.OID(), BinaryFormatCode, cfs.Bytes(), d)
} else {
cfs.err = errors.New("read past end of composite")
}
}
// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
// Next returns false, the Err method can be called to check if any errors occurred.
func (cfs *CompositeBinaryScanner) Next() bool {
if cfs.err != nil {
return false
}
if cfs.rp == len(cfs.src) {
return false
}
if len(cfs.src[cfs.rp:]) < 8 {
cfs.err = fmt.Errorf("Record incomplete %v", cfs.src)
return false
}
cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:])
cfs.rp += 4
fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:])))
cfs.rp += 4
if fieldLen >= 0 {
if len(cfs.src[cfs.rp:]) < fieldLen {
cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src)
return false
}
cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen]
cfs.rp += fieldLen
} else {
cfs.fieldBytes = nil
}
return true
}
func (cfs *CompositeBinaryScanner) FieldCount() int {
return int(cfs.fieldCount)
}
// Bytes returns the bytes of the field most recently read by Scan().
func (cfs *CompositeBinaryScanner) Bytes() []byte {
return cfs.fieldBytes
}
// OID returns the OID of the field most recently read by Scan().
func (cfs *CompositeBinaryScanner) OID() uint32 {
return cfs.fieldOID
}
// Err returns any error encountered by the scanner.
func (cfs *CompositeBinaryScanner) Err() error {
return cfs.err
}
type CompositeTextScanner struct {
ci *ConnInfo
rp int
src []byte
fieldBytes []byte
err error
}
// NewCompositeTextScanner a scanner over a text encoded composite value.
func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner {
if len(src) < 2 {
return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)}
}
if src[0] != '(' {
return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")}
}
if src[len(src)-1] != ')' {
return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")}
}
return &CompositeTextScanner{
ci: ci,
rp: 1,
src: src,
}
}
// ScanDecoder calls Next and decodes the result with d.
func (cfs *CompositeTextScanner) ScanDecoder(d ResultDecoder) {
if cfs.err != nil {
return
}
if cfs.Next() {
cfs.err = d.DecodeResult(cfs.ci, 0, TextFormatCode, cfs.fieldBytes)
} else {
cfs.err = errors.New("read past end of composite")
}
}
// ScanDecoder calls Next and scans the result into d.
func (cfs *CompositeTextScanner) ScanValue(d interface{}) {
if cfs.err != nil {
return
}
if cfs.Next() {
cfs.err = cfs.ci.Scan(0, TextFormatCode, cfs.Bytes(), d)
} else {
cfs.err = errors.New("read past end of composite")
}
}
// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
// Next returns false, the Err method can be called to check if any errors occurred.
func (cfs *CompositeTextScanner) Next() bool {
if cfs.err != nil {
return false
}
if cfs.rp == len(cfs.src) {
return false
}
switch cfs.src[cfs.rp] {
case ',', ')': // null
cfs.rp++
cfs.fieldBytes = nil
return true
case '"': // quoted value
cfs.rp++
cfs.fieldBytes = make([]byte, 0, 16)
for {
ch := cfs.src[cfs.rp]
if ch == '"' {
cfs.rp++
if cfs.src[cfs.rp] == '"' {
cfs.fieldBytes = append(cfs.fieldBytes, '"')
cfs.rp++
} else {
break
}
} else if ch == '\\' {
cfs.rp++
cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp])
cfs.rp++
} else {
cfs.fieldBytes = append(cfs.fieldBytes, ch)
cfs.rp++
}
}
cfs.rp++
return true
default: // unquoted value
start := cfs.rp
for {
ch := cfs.src[cfs.rp]
if ch == ',' || ch == ')' {
break
}
cfs.rp++
}
cfs.fieldBytes = cfs.src[start:cfs.rp]
cfs.rp++
return true
}
}
// Bytes returns the bytes of the field most recently read by Scan().
func (cfs *CompositeTextScanner) Bytes() []byte {
return cfs.fieldBytes
}
// Err returns any error encountered by the scanner.
func (cfs *CompositeTextScanner) Err() error {
return cfs.err
}
type CompositeBinaryBuilder struct {
ci *ConnInfo
buf []byte
startIdx int
fieldCount uint32
err error
}
func NewCompositeBinaryBuilder(ci *ConnInfo, buf []byte) *CompositeBinaryBuilder {
startIdx := len(buf)
buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields
return &CompositeBinaryBuilder{ci: ci, buf: buf, startIdx: startIdx}
}
func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) {
if b.err != nil {
return
}
dt, ok := b.ci.DataTypeForOID(oid)
if !ok {
b.err = fmt.Errorf("unknown data type for OID: %d", oid)
return
}
err := dt.Value.Set(field)
if err != nil {
b.err = err
return
}
paramEncoder, ok := dt.Value.(ParamEncoder)
if !ok {
b.err = fmt.Errorf("unable to encode for OID: %d", oid)
return
}
b.AppendEncoder(oid, paramEncoder)
}
func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field ParamEncoder) {
if b.err != nil {
return
}
b.buf = pgio.AppendUint32(b.buf, oid)
lengthPos := len(b.buf)
b.buf = pgio.AppendInt32(b.buf, -1)
fieldBuf, err := field.EncodeParam(b.ci, oid, BinaryFormatCode, b.buf)
if err != nil {
b.err = err
return
}
if fieldBuf != nil {
binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf)))
b.buf = fieldBuf
}
b.fieldCount++
}
func (b *CompositeBinaryBuilder) Finish() ([]byte, error) {
if b.err != nil {
return nil, b.err
}
binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount)
return b.buf, nil
}
type CompositeTextBuilder struct {
ci *ConnInfo
buf []byte
startIdx int
fieldCount uint32
err error
fieldBuf [32]byte
}
func NewCompositeTextBuilder(ci *ConnInfo, buf []byte) *CompositeTextBuilder {
buf = append(buf, '(') // allocate room for number of fields
return &CompositeTextBuilder{ci: ci, buf: buf}
}
func (b *CompositeTextBuilder) AppendValue(field interface{}) {
if b.err != nil {
return
}
if field == nil {
b.buf = append(b.buf, ',')
return
}
dt, ok := b.ci.DataTypeForValue(field)
if !ok {
b.err = fmt.Errorf("unknown data type for field: %v", field)
return
}
err := dt.Value.Set(field)
if err != nil {
b.err = err
return
}
paramEncoder, ok := dt.Value.(ParamEncoder)
if !ok {
b.err = fmt.Errorf("unable to encode for value: %v", field)
return
}
b.AppendEncoder(paramEncoder)
}
func (b *CompositeTextBuilder) AppendEncoder(field ParamEncoder) {
if b.err != nil {
return
}
fieldBuf, err := field.EncodeParam(b.ci, 0, TextFormatCode, b.fieldBuf[0:0])
if err != nil {
b.err = err
return
}
if fieldBuf != nil {
b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...)
}
b.buf = append(b.buf, ',')
}
func (b *CompositeTextBuilder) Finish() ([]byte, error) {
if b.err != nil {
return nil, b.err
}
b.buf[len(b.buf)-1] = ')'
return b.buf, nil
}
var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
func quoteCompositeField(src string) string {
return `"` + quoteCompositeReplacer.Replace(src) + `"`
}
func quoteCompositeFieldIfNeeded(src string) string {
if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) {
return quoteCompositeField(src)
}
return src
}

View File

@ -1,320 +0,0 @@
package pgtype_test
import (
"context"
"fmt"
"os"
"testing"
pgx "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgtype/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCompositeTypeSetAndGet(t *testing.T) {
ci := pgtype.NewConnInfo()
ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
{"a", pgtype.TextOID},
{"b", pgtype.Int4OID},
}, ci)
require.NoError(t, err)
assert.Equal(t, nil, ct.Get())
nilTests := []struct {
src interface{}
}{
{nil}, // nil interface
{(*[]interface{})(nil)}, // typed nil
}
for i, tt := range nilTests {
err := ct.Set(tt.src)
assert.NoErrorf(t, err, "%d", i)
assert.Equal(t, nil, ct.Get())
}
compatibleValuesTests := []struct {
src []interface{}
expected map[string]interface{}
}{
{
src: []interface{}{"foo", int32(42)},
expected: map[string]interface{}{"a": "foo", "b": int32(42)},
},
{
src: []interface{}{nil, nil},
expected: map[string]interface{}{"a": nil, "b": nil},
},
{
src: []interface{}{&pgtype.Text{String: "hi", Valid: true}, &pgtype.Int4{Int: 7, Valid: true}},
expected: map[string]interface{}{"a": "hi", "b": int32(7)},
},
}
for i, tt := range compatibleValuesTests {
err := ct.Set(tt.src)
assert.NoErrorf(t, err, "%d", i)
assert.EqualValues(t, tt.expected, ct.Get())
}
}
func TestCompositeTypeAssignTo(t *testing.T) {
ci := pgtype.NewConnInfo()
ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
{"a", pgtype.TextOID},
{"b", pgtype.Int4OID},
}, ci)
require.NoError(t, err)
{
err := ct.Set([]interface{}{"foo", int32(42)})
assert.NoError(t, err)
var a string
var b int32
err = ct.AssignTo([]interface{}{&a, &b})
assert.NoError(t, err)
assert.Equal(t, "foo", a)
assert.Equal(t, int32(42), b)
}
{
err := ct.Set([]interface{}{"foo", int32(42)})
assert.NoError(t, err)
var a pgtype.Text
var b pgtype.Int4
err = ct.AssignTo([]interface{}{&a, &b})
assert.NoError(t, err)
assert.Equal(t, pgtype.Text{String: "foo", Valid: true}, a)
assert.Equal(t, pgtype.Int4{Int: 42, Valid: true}, b)
}
// Allow nil destination component as no-op
{
err := ct.Set([]interface{}{"foo", int32(42)})
assert.NoError(t, err)
var b int32
err = ct.AssignTo([]interface{}{nil, &b})
assert.NoError(t, err)
assert.Equal(t, int32(42), b)
}
// *[]interface{} dest when null
{
err := ct.Set(nil)
assert.NoError(t, err)
var a pgtype.Text
var b pgtype.Int4
dst := []interface{}{&a, &b}
err = ct.AssignTo(&dst)
assert.NoError(t, err)
assert.Nil(t, dst)
}
// *[]interface{} dest when not null
{
err := ct.Set([]interface{}{"foo", int32(42)})
assert.NoError(t, err)
var a pgtype.Text
var b pgtype.Int4
dst := []interface{}{&a, &b}
err = ct.AssignTo(&dst)
assert.NoError(t, err)
assert.NotNil(t, dst)
assert.Equal(t, pgtype.Text{String: "foo", Valid: true}, a)
assert.Equal(t, pgtype.Int4{Int: 42, Valid: true}, b)
}
// Struct fields positionally via reflection
{
err := ct.Set([]interface{}{"foo", int32(42)})
assert.NoError(t, err)
s := struct {
A string
B int32
}{}
err = ct.AssignTo(&s)
if assert.NoError(t, err) {
assert.Equal(t, "foo", s.A)
assert.Equal(t, int32(42), s.B)
}
}
}
func TestCompositeTypeTranscode(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
_, err := conn.Exec(context.Background(), `drop type if exists ct_test;
create type ct_test as (
a text,
b int4
);`)
require.NoError(t, err)
defer conn.Exec(context.Background(), "drop type ct_test")
var oid uint32
err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid)
require.NoError(t, err)
defer conn.Exec(context.Background(), "drop type ct_test")
ct, err := pgtype.NewCompositeType("ct_test", []pgtype.CompositeTypeField{
{"a", pgtype.TextOID},
{"b", pgtype.Int4OID},
}, conn.ConnInfo())
require.NoError(t, err)
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid})
// Use simple protocol to force text or binary encoding
simpleProtocols := []bool{true, false}
var a string
var b int32
for _, simpleProtocol := range simpleProtocols {
err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QuerySimpleProtocol(simpleProtocol),
pgtype.CompositeFields{"hi", int32(42)},
).Scan(
[]interface{}{&a, &b},
)
if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol)
assert.EqualValuesf(t, 42, b, "Simple Protocol: %v", simpleProtocol)
}
}
}
// https://github.com/jackc/pgx/issues/874
func TestCompositeTypeTextDecodeNested(t *testing.T) {
newCompositeType := func(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType {
fields := make([]pgtype.CompositeTypeField, len(fieldNames))
for i, name := range fieldNames {
fields[i] = pgtype.CompositeTypeField{Name: name}
}
rowType, err := pgtype.NewCompositeTypeValues(name, fields, vals)
require.NoError(t, err)
return rowType
}
dimensionsType := func() pgtype.ValueTranscoder {
return newCompositeType(
"dimensions",
[]string{"width", "height"},
&pgtype.Int4{},
&pgtype.Int4{},
)
}
productImageType := func() pgtype.ValueTranscoder {
return newCompositeType(
"product_image_type",
[]string{"source", "dimensions"},
&pgtype.Text{},
dimensionsType(),
)
}
productImageSetType := newCompositeType(
"product_image_set_type",
[]string{"name", "orig_image", "images"},
&pgtype.Text{},
productImageType(),
pgtype.NewArrayType("product_image", 0, func() pgtype.ValueTranscoder {
return productImageType()
}),
)
err := productImageSetType.DecodeText(nil, []byte(`(name,"(img1,""(11,11)"")","{""(img2,\\""(22,22)\\"")"",""(img3,\\""(33,33)\\"")""}")`))
require.NoError(t, err)
}
func Example_composite() {
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
if err != nil {
fmt.Println(err)
return
}
defer conn.Close(context.Background())
_, err = conn.Exec(context.Background(), `drop type if exists mytype;`)
if err != nil {
fmt.Println(err)
return
}
_, err = conn.Exec(context.Background(), `create type mytype as (
a int4,
b text
);`)
if err != nil {
fmt.Println(err)
return
}
defer conn.Exec(context.Background(), "drop type mytype")
var oid uint32
err = conn.QueryRow(context.Background(), `select 'mytype'::regtype::oid`).Scan(&oid)
if err != nil {
fmt.Println(err)
return
}
ct, err := pgtype.NewCompositeType("mytype", []pgtype.CompositeTypeField{
{"a", pgtype.Int4OID},
{"b", pgtype.TextOID},
}, conn.ConnInfo())
if err != nil {
fmt.Println(err)
return
}
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid})
var a int
var b *string
err = conn.QueryRow(context.Background(), "select $1::mytype", []interface{}{2, "bar"}).Scan([]interface{}{&a, &b})
if err != nil {
fmt.Println(err)
return
}
fmt.Printf("First: a=%d b=%s\n", a, *b)
err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype").Scan([]interface{}{&a, &b})
if err != nil {
fmt.Println(err)
return
}
fmt.Printf("Second: a=%d b=%v\n", a, b)
scanTarget := []interface{}{&a, &b}
err = conn.QueryRow(context.Background(), "select NULL::mytype").Scan(&scanTarget)
E(err)
fmt.Printf("Third: isNull=%v\n", scanTarget == nil)
// Output:
// First: a=2 b=bar
// Second: a=1 b=<nil>
// Third: isNull=true
}

View File

@ -1,87 +0,0 @@
package pgtype_test
import (
"context"
"errors"
"fmt"
"os"
pgx "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
)
type MyType struct {
a int32 // NULL will cause decoding error
b *string // there can be NULL in this position in SQL
}
func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
if src == nil {
return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs")
}
if err := (pgtype.CompositeFields{&dst.a, &dst.b}).DecodeBinary(ci, src); err != nil {
return err
}
return nil
}
func (src MyType) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) {
a := pgtype.Int4{src.a, true}
var b pgtype.Text
if src.b != nil {
b = pgtype.Text{*src.b, true}
} else {
b = pgtype.Text{}
}
return (pgtype.CompositeFields{&a, &b}).EncodeBinary(ci, buf)
}
func ptrS(s string) *string {
return &s
}
func E(err error) {
if err != nil {
panic(err)
}
}
// ExampleCustomCompositeTypes demonstrates how support for custom types mappable to SQL
// composites can be added.
func Example_customCompositeTypes() {
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
E(err)
defer conn.Close(context.Background())
_, err = conn.Exec(context.Background(), `drop type if exists mytype;
create type mytype as (
a int4,
b text
);`)
E(err)
defer conn.Exec(context.Background(), "drop type mytype")
var result *MyType
// Demonstrates both passing and reading back composite values
err = conn.QueryRow(context.Background(), "select $1::mytype",
pgx.QueryResultFormats{pgx.BinaryFormatCode}, MyType{1, ptrS("foo")}).
Scan(&result)
E(err)
fmt.Printf("First row: a=%d b=%s\n", result.a, *result.b)
// Because we scan into &*MyType, NULLs are handled generically by assigning nil to result
err = conn.QueryRow(context.Background(), "select NULL::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result)
E(err)
fmt.Printf("Second row: %v\n", result)
// Output:
// First row: a=1 b=foo
// Second row: <nil>
}

View File

@ -343,7 +343,7 @@ func NewConnInfo() *ConnInfo {
ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID})
ci.RegisterDataType(DataType{Value: &Point{}, Name: "point", OID: PointOID})
ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID})
ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID})
// ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID})
ci.RegisterDataType(DataType{Value: &Text{}, Name: "text", OID: TextOID})
ci.RegisterDataType(DataType{Value: &TID{}, Name: "tid", OID: TIDOID})
ci.RegisterDataType(DataType{Value: &Time{}, Name: "time", OID: TimeOID})

View File

@ -1,141 +0,0 @@
package pgtype
import (
"fmt"
"reflect"
)
// Record is the generic PostgreSQL record type such as is created with the
// "row" function. Record only implements BinaryEncoder and Value. The text
// format output format from PostgreSQL does not include type information and is
// therefore impossible to decode. No encoders are implemented because
// PostgreSQL does not support input of generic records.
type Record struct {
Fields []Value
Valid bool
}
func (dst *Record) Set(src interface{}) error {
if src == nil {
*dst = Record{}
return nil
}
if value, ok := src.(interface{ Get() interface{} }); ok {
value2 := value.Get()
if value2 != value {
return dst.Set(value2)
}
}
switch value := src.(type) {
case []Value:
*dst = Record{Fields: value, Valid: true}
default:
return fmt.Errorf("cannot convert %v to Record", src)
}
return nil
}
func (dst Record) Get() interface{} {
if !dst.Valid {
return nil
}
return dst.Fields
}
func (src *Record) AssignTo(dst interface{}) error {
if !src.Valid {
return NullAssignTo(dst)
}
switch v := dst.(type) {
case *[]Value:
*v = make([]Value, len(src.Fields))
copy(*v, src.Fields)
return nil
case *[]interface{}:
*v = make([]interface{}, len(src.Fields))
for i := range *v {
(*v)[i] = src.Fields[i].Get()
}
return nil
default:
if nextDst, retry := GetAssignToDstType(dst); retry {
return src.AssignTo(nextDst)
}
return fmt.Errorf("unable to assign to %T", dst)
}
}
func prepareNewBinaryDecoder(ci *ConnInfo, fieldOID uint32, v *Value) (BinaryDecoder, error) {
var binaryDecoder BinaryDecoder
if dt, ok := ci.DataTypeForOID(fieldOID); ok {
binaryDecoder, _ = dt.Value.(BinaryDecoder)
} else {
return nil, fmt.Errorf("unknown oid while decoding record: %v", fieldOID)
}
if binaryDecoder == nil {
return nil, fmt.Errorf("no binary decoder registered for: %v", fieldOID)
}
// Duplicate struct to scan into
binaryDecoder = reflect.New(reflect.ValueOf(binaryDecoder).Elem().Type()).Interface().(BinaryDecoder)
*v = binaryDecoder.(Value)
return binaryDecoder, nil
}
func (Record) BinaryFormatSupported() bool {
return true
}
func (Record) TextFormatSupported() bool {
return false
}
func (Record) PreferredFormat() int16 {
return BinaryFormatCode
}
func (dst *Record) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error {
switch format {
case BinaryFormatCode:
return dst.DecodeBinary(ci, src)
case TextFormatCode:
return fmt.Errorf("text format is not supported")
}
return fmt.Errorf("unknown format code %d", format)
}
func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = Record{}
return nil
}
scanner := NewCompositeBinaryScanner(ci, src)
fields := make([]Value, scanner.FieldCount())
for i := 0; scanner.Next(); i++ {
binaryDecoder, err := prepareNewBinaryDecoder(ci, scanner.OID(), &fields[i])
if err != nil {
return err
}
if err = binaryDecoder.DecodeBinary(ci, scanner.Bytes()); err != nil {
return err
}
}
if scanner.Err() != nil {
return scanner.Err()
}
*dst = Record{Fields: fields, Valid: true}
return nil
}

View File

@ -1,184 +0,0 @@
package pgtype_test
import (
"context"
"fmt"
"reflect"
"testing"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgtype/testutil"
)
var recordTests = []struct {
sql string
expected pgtype.Record
}{
{
sql: `select row()`,
expected: pgtype.Record{
Fields: []pgtype.Value{},
Valid: true,
},
},
{
sql: `select row('foo'::text, 42::int4)`,
expected: pgtype.Record{
Fields: []pgtype.Value{
&pgtype.Text{String: "foo", Valid: true},
&pgtype.Int4{Int: 42, Valid: true},
},
Valid: true,
},
},
{
sql: `select row(100.0::float4, 1.09::float4)`,
expected: pgtype.Record{
Fields: []pgtype.Value{
&pgtype.Float4{Float: 100, Valid: true},
&pgtype.Float4{Float: 1.09, Valid: true},
},
Valid: true,
},
},
{
sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`,
expected: pgtype.Record{
Fields: []pgtype.Value{
&pgtype.Text{String: "foo", Valid: true},
&pgtype.Int4Array{
Elements: []pgtype.Int4{
{Int: 1, Valid: true},
{Int: 2, Valid: true},
{},
{Int: 4, Valid: true},
},
Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}},
Valid: true,
},
&pgtype.Int4{Int: 42, Valid: true},
},
Valid: true,
},
},
{
sql: `select row(null)`,
expected: pgtype.Record{
Fields: []pgtype.Value{
&pgtype.Unknown{},
},
Valid: true,
},
},
{
sql: `select null::record`,
expected: pgtype.Record{},
},
}
func TestRecordTranscode(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
for i, tt := range recordTests {
psName := fmt.Sprintf("test%d", i)
_, err := conn.Prepare(context.Background(), psName, tt.sql)
if err != nil {
t.Fatal(err)
}
t.Run(tt.sql, func(t *testing.T) {
var result pgtype.Record
if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil {
t.Errorf("%v", err)
return
}
if !reflect.DeepEqual(tt.expected, result) {
t.Errorf("expected %#v, got %#v", tt.expected, result)
}
})
}
}
func TestRecordWithUnknownOID(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
_, err := conn.Exec(context.Background(), `drop type if exists floatrange;
create type floatrange as range (
subtype = float8,
subtype_diff = float8mi
);`)
if err != nil {
t.Fatal(err)
}
defer conn.Exec(context.Background(), "drop type floatrange")
var result pgtype.Record
err = conn.QueryRow(context.Background(), "select row('foo'::text, floatrange(1, 10), 'bar'::text)").Scan(&result)
if err == nil {
t.Errorf("expected error but none")
}
}
func TestRecordAssignTo(t *testing.T) {
var valueSlice []pgtype.Value
var interfaceSlice []interface{}
simpleTests := []struct {
src pgtype.Record
dst interface{}
expected interface{}
}{
{
src: pgtype.Record{
Fields: []pgtype.Value{
&pgtype.Text{String: "foo", Valid: true},
&pgtype.Int4{Int: 42, Valid: true},
},
Valid: true,
},
dst: &valueSlice,
expected: []pgtype.Value{
&pgtype.Text{String: "foo", Valid: true},
&pgtype.Int4{Int: 42, Valid: true},
},
},
{
src: pgtype.Record{
Fields: []pgtype.Value{
&pgtype.Text{String: "foo", Valid: true},
&pgtype.Int4{Int: 42, Valid: true},
},
Valid: true,
},
dst: &interfaceSlice,
expected: []interface{}{"foo", int32(42)},
},
{
src: pgtype.Record{},
dst: &valueSlice,
expected: (([]pgtype.Value)(nil)),
},
{
src: pgtype.Record{},
dst: &interfaceSlice,
expected: (([]interface{})(nil)),
},
}
for i, tt := range simpleTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
}