mirror of https://github.com/jackc/pgx.git
Temporarily remove composite and record support
parent
ffa1fdd66e
commit
40fb889605
|
@ -1,192 +0,0 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type MyCompositeRaw struct {
|
||||
A int32
|
||||
B *string
|
||||
}
|
||||
|
||||
func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
|
||||
buf = pgio.AppendUint32(buf, 2)
|
||||
|
||||
buf = pgio.AppendUint32(buf, pgtype.Int4OID)
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
buf = pgio.AppendInt32(buf, src.A)
|
||||
|
||||
buf = pgio.AppendUint32(buf, pgtype.TextOID)
|
||||
if src.B != nil {
|
||||
buf = pgio.AppendInt32(buf, int32(len(*src.B)))
|
||||
buf = append(buf, (*src.B)...)
|
||||
} else {
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
|
||||
a := pgtype.Int4{}
|
||||
b := pgtype.Text{}
|
||||
|
||||
scanner := pgtype.NewCompositeBinaryScanner(ci, src)
|
||||
scanner.ScanDecoder(&a)
|
||||
scanner.ScanDecoder(&b)
|
||||
|
||||
if scanner.Err() != nil {
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
dst.A = a.Int
|
||||
if b.Valid {
|
||||
dst.B = &b.String
|
||||
} else {
|
||||
dst.B = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var x []byte
|
||||
|
||||
func BenchmarkBinaryEncodingManual(b *testing.B) {
|
||||
buf := make([]byte, 0, 128)
|
||||
ci := pgtype.NewConnInfo()
|
||||
v := MyCompositeRaw{4, ptrS("ABCDEFG")}
|
||||
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
buf, _ = v.EncodeBinary(ci, buf[:0])
|
||||
}
|
||||
x = buf
|
||||
}
|
||||
|
||||
func BenchmarkBinaryEncodingHelper(b *testing.B) {
|
||||
buf := make([]byte, 0, 128)
|
||||
ci := pgtype.NewConnInfo()
|
||||
v := MyType{4, ptrS("ABCDEFG")}
|
||||
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
buf, _ = v.EncodeBinary(ci, buf[:0])
|
||||
}
|
||||
x = buf
|
||||
}
|
||||
|
||||
func BenchmarkBinaryEncodingComposite(b *testing.B) {
|
||||
buf := make([]byte, 0, 128)
|
||||
ci := pgtype.NewConnInfo()
|
||||
f1 := 2
|
||||
f2 := ptrS("bar")
|
||||
c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
|
||||
{"a", pgtype.Int4OID},
|
||||
{"b", pgtype.TextOID},
|
||||
}, ci)
|
||||
require.NoError(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
c.Set([]interface{}{f1, f2})
|
||||
buf, _ = c.EncodeBinary(ci, buf[:0])
|
||||
}
|
||||
x = buf
|
||||
}
|
||||
|
||||
func BenchmarkBinaryEncodingJSON(b *testing.B) {
|
||||
buf := make([]byte, 0, 128)
|
||||
ci := pgtype.NewConnInfo()
|
||||
v := MyCompositeRaw{4, ptrS("ABCDEFG")}
|
||||
j := pgtype.JSON{}
|
||||
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
j.Set(v)
|
||||
buf, _ = j.EncodeBinary(ci, buf[:0])
|
||||
}
|
||||
x = buf
|
||||
}
|
||||
|
||||
var dstRaw MyCompositeRaw
|
||||
|
||||
func BenchmarkBinaryDecodingManual(b *testing.B) {
|
||||
ci := pgtype.NewConnInfo()
|
||||
buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil)
|
||||
dst := MyCompositeRaw{}
|
||||
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
err := dst.DecodeBinary(ci, buf)
|
||||
E(err)
|
||||
}
|
||||
dstRaw = dst
|
||||
}
|
||||
|
||||
var dstMyType MyType
|
||||
|
||||
func BenchmarkBinaryDecodingHelpers(b *testing.B) {
|
||||
ci := pgtype.NewConnInfo()
|
||||
buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil)
|
||||
dst := MyType{}
|
||||
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
err := dst.DecodeBinary(ci, buf)
|
||||
E(err)
|
||||
}
|
||||
dstMyType = dst
|
||||
}
|
||||
|
||||
var gf1 int
|
||||
var gf2 *string
|
||||
|
||||
func BenchmarkBinaryDecodingCompositeScan(b *testing.B) {
|
||||
ci := pgtype.NewConnInfo()
|
||||
buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil)
|
||||
var f1 int
|
||||
var f2 *string
|
||||
|
||||
c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
|
||||
{"a", pgtype.Int4OID},
|
||||
{"b", pgtype.TextOID},
|
||||
}, ci)
|
||||
require.NoError(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
err := c.DecodeBinary(ci, buf)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
err = c.AssignTo([]interface{}{&f1, &f2})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
gf1 = f1
|
||||
gf2 = f2
|
||||
}
|
||||
|
||||
func BenchmarkBinaryDecodingJSON(b *testing.B) {
|
||||
ci := pgtype.NewConnInfo()
|
||||
j := pgtype.JSON{}
|
||||
j.Set(MyCompositeRaw{4, ptrS("ABCDEFG")})
|
||||
buf, _ := j.EncodeBinary(ci, nil)
|
||||
|
||||
j = pgtype.JSON{}
|
||||
dst := MyCompositeRaw{}
|
||||
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
err := j.DecodeBinary(ci, buf)
|
||||
E(err)
|
||||
err = j.AssignTo(&dst)
|
||||
E(err)
|
||||
}
|
||||
dstRaw = dst
|
||||
}
|
|
@ -1,107 +0,0 @@
|
|||
package pgtype
|
||||
|
||||
import "fmt"
|
||||
|
||||
// CompositeFields scans the fields of a composite type into the elements of the CompositeFields value. To scan a
|
||||
// nullable value use a *CompositeFields. It will be set to nil in case of null.
|
||||
//
|
||||
// CompositeFields implements EncodeBinary and EncodeText. However, functionality is limited due to CompositeFields not
|
||||
// knowing the PostgreSQL schema of the composite type. Prefer using a registered CompositeType.
|
||||
type CompositeFields []interface{}
|
||||
|
||||
func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||
if len(cf) == 0 {
|
||||
return fmt.Errorf("cannot decode into empty CompositeFields")
|
||||
}
|
||||
|
||||
if src == nil {
|
||||
return fmt.Errorf("cannot decode unexpected null into CompositeFields")
|
||||
}
|
||||
|
||||
scanner := NewCompositeBinaryScanner(ci, src)
|
||||
|
||||
for _, f := range cf {
|
||||
scanner.ScanValue(f)
|
||||
}
|
||||
|
||||
if scanner.Err() != nil {
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error {
|
||||
if len(cf) == 0 {
|
||||
return fmt.Errorf("cannot decode into empty CompositeFields")
|
||||
}
|
||||
|
||||
if src == nil {
|
||||
return fmt.Errorf("cannot decode unexpected null into CompositeFields")
|
||||
}
|
||||
|
||||
scanner := NewCompositeTextScanner(ci, src)
|
||||
|
||||
for _, f := range cf {
|
||||
scanner.ScanValue(f)
|
||||
}
|
||||
|
||||
if scanner.Err() != nil {
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodeText encodes composite fields into the text format. Prefer registering a CompositeType to using
|
||||
// CompositeFields to encode directly.
|
||||
func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
b := NewCompositeTextBuilder(ci, buf)
|
||||
|
||||
for _, f := range cf {
|
||||
if paramEncoder, ok := f.(ParamEncoder); ok {
|
||||
b.AppendEncoder(paramEncoder)
|
||||
} else {
|
||||
b.AppendValue(f)
|
||||
}
|
||||
}
|
||||
|
||||
return b.Finish()
|
||||
}
|
||||
|
||||
// EncodeBinary encodes composite fields into the binary format. Unlike CompositeType the schema of the destination is
|
||||
// unknown. Prefer registering a CompositeType to using CompositeFields to encode directly. Because the binary
|
||||
// composite format requires the OID of each field to be specified the only types that will work are those known to
|
||||
// ConnInfo.
|
||||
//
|
||||
// In particular:
|
||||
//
|
||||
// * Nil cannot be used because there is no way to determine what type it.
|
||||
// * Integer types must be exact matches. e.g. A Go int32 into a PostgreSQL bigint will fail.
|
||||
// * No dereferencing will be done. e.g. *Text must be used instead of Text.
|
||||
func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
b := NewCompositeBinaryBuilder(ci, buf)
|
||||
|
||||
for _, f := range cf {
|
||||
dt, ok := ci.DataTypeForValue(f)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Unknown OID for %#v", f)
|
||||
}
|
||||
|
||||
if paramEncoder, ok := f.(ParamEncoder); ok {
|
||||
b.AppendEncoder(dt.OID, paramEncoder)
|
||||
} else {
|
||||
err := dt.Value.Set(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if paramEncoder, ok := dt.Value.(ParamEncoder); ok {
|
||||
b.AppendEncoder(dt.OID, paramEncoder)
|
||||
} else {
|
||||
return nil, fmt.Errorf("Cannot encode binary format for %v", f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return b.Finish()
|
||||
}
|
|
@ -1,273 +0,0 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgtype/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCompositeFieldsDecode(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
||||
formats := []int16{pgx.TextFormatCode, pgx.BinaryFormatCode}
|
||||
|
||||
// Assorted values
|
||||
{
|
||||
var a int32
|
||||
var b string
|
||||
var c float64
|
||||
|
||||
for _, format := range formats {
|
||||
err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan(
|
||||
pgtype.CompositeFields{&a, &b, &c},
|
||||
)
|
||||
if !assert.NoErrorf(t, err, "Format: %v", format) {
|
||||
continue
|
||||
}
|
||||
|
||||
assert.EqualValuesf(t, 1, a, "Format: %v", format)
|
||||
assert.EqualValuesf(t, "hi", b, "Format: %v", format)
|
||||
assert.EqualValuesf(t, 2.1, c, "Format: %v", format)
|
||||
}
|
||||
}
|
||||
|
||||
// nulls, string "null", and empty string fields
|
||||
{
|
||||
var a pgtype.Text
|
||||
var b string
|
||||
var c pgtype.Text
|
||||
var d string
|
||||
var e pgtype.Text
|
||||
|
||||
for _, format := range formats {
|
||||
err := conn.QueryRow(context.Background(), "select row(null,'null',null,'',null)", pgx.QueryResultFormats{format}).Scan(
|
||||
pgtype.CompositeFields{&a, &b, &c, &d, &e},
|
||||
)
|
||||
if !assert.NoErrorf(t, err, "Format: %v", format) {
|
||||
continue
|
||||
}
|
||||
|
||||
assert.Nilf(t, a.Get(), "Format: %v", format)
|
||||
assert.EqualValuesf(t, "null", b, "Format: %v", format)
|
||||
assert.Nilf(t, c.Get(), "Format: %v", format)
|
||||
assert.EqualValuesf(t, "", d, "Format: %v", format)
|
||||
assert.Nilf(t, e.Get(), "Format: %v", format)
|
||||
}
|
||||
}
|
||||
|
||||
// null record
|
||||
{
|
||||
var a pgtype.Text
|
||||
var b string
|
||||
cf := pgtype.CompositeFields{&a, &b}
|
||||
|
||||
for _, format := range formats {
|
||||
// Cannot scan nil into
|
||||
err := conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan(
|
||||
cf,
|
||||
)
|
||||
if assert.Errorf(t, err, "Format: %v", format) {
|
||||
continue
|
||||
}
|
||||
assert.NotNilf(t, cf, "Format: %v", format)
|
||||
|
||||
// But can scan nil into *pgtype.CompositeFields
|
||||
err = conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan(
|
||||
&cf,
|
||||
)
|
||||
if assert.Errorf(t, err, "Format: %v", format) {
|
||||
continue
|
||||
}
|
||||
assert.Nilf(t, cf, "Format: %v", format)
|
||||
}
|
||||
}
|
||||
|
||||
// quotes and special characters
|
||||
{
|
||||
var a, b, c, d string
|
||||
|
||||
for _, format := range formats {
|
||||
err := conn.QueryRow(context.Background(), `select row('"', 'foo bar', 'foo''bar', 'baz)bar')`, pgx.QueryResultFormats{format}).Scan(
|
||||
pgtype.CompositeFields{&a, &b, &c, &d},
|
||||
)
|
||||
if !assert.NoErrorf(t, err, "Format: %v", format) {
|
||||
continue
|
||||
}
|
||||
|
||||
assert.Equalf(t, `"`, a, "Format: %v", format)
|
||||
assert.Equalf(t, `foo bar`, b, "Format: %v", format)
|
||||
assert.Equalf(t, `foo'bar`, c, "Format: %v", format)
|
||||
assert.Equalf(t, `baz)bar`, d, "Format: %v", format)
|
||||
}
|
||||
}
|
||||
|
||||
// arrays
|
||||
{
|
||||
var a []string
|
||||
var b []int64
|
||||
|
||||
for _, format := range formats {
|
||||
err := conn.QueryRow(context.Background(), `select row(array['foo', 'bar', 'baz'], array[1,2,3])`, pgx.QueryResultFormats{format}).Scan(
|
||||
pgtype.CompositeFields{&a, &b},
|
||||
)
|
||||
if !assert.NoErrorf(t, err, "Format: %v", format) {
|
||||
continue
|
||||
}
|
||||
|
||||
assert.EqualValuesf(t, []string{"foo", "bar", "baz"}, a, "Format: %v", format)
|
||||
assert.EqualValuesf(t, []int64{1, 2, 3}, b, "Format: %v", format)
|
||||
}
|
||||
}
|
||||
|
||||
// Skip nil fields
|
||||
{
|
||||
var a int32
|
||||
var c float64
|
||||
|
||||
for _, format := range formats {
|
||||
err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan(
|
||||
pgtype.CompositeFields{&a, nil, &c},
|
||||
)
|
||||
if !assert.NoErrorf(t, err, "Format: %v", format) {
|
||||
continue
|
||||
}
|
||||
|
||||
assert.EqualValuesf(t, 1, a, "Format: %v", format)
|
||||
assert.EqualValuesf(t, 2.1, c, "Format: %v", format)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositeFieldsEncode(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
||||
_, err := conn.Exec(context.Background(), `drop type if exists cf_encode;
|
||||
|
||||
create type cf_encode as (
|
||||
a text,
|
||||
b int4,
|
||||
c text,
|
||||
d float8,
|
||||
e text
|
||||
);`)
|
||||
require.NoError(t, err)
|
||||
defer conn.Exec(context.Background(), "drop type cf_encode")
|
||||
|
||||
// Use simple protocol to force text or binary encoding
|
||||
simpleProtocols := []bool{true, false}
|
||||
|
||||
// Assorted values
|
||||
{
|
||||
var a string
|
||||
var b int32
|
||||
var c string
|
||||
var d float64
|
||||
var e string
|
||||
|
||||
for _, simpleProtocol := range simpleProtocols {
|
||||
err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol),
|
||||
pgtype.CompositeFields{"hi", int32(1), "ok", float64(2.1), "bye"},
|
||||
).Scan(
|
||||
pgtype.CompositeFields{&a, &b, &c, &d, &e},
|
||||
)
|
||||
if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
|
||||
assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, "ok", c, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, 2.1, d, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, "bye", e, "Simple Protocol: %v", simpleProtocol)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// untyped nil
|
||||
{
|
||||
var a pgtype.Text
|
||||
var b int32
|
||||
var c string
|
||||
var d pgtype.Float8
|
||||
var e pgtype.Text
|
||||
|
||||
simpleProtocol := true
|
||||
err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol),
|
||||
pgtype.CompositeFields{nil, int32(1), "null", nil, nil},
|
||||
).Scan(
|
||||
pgtype.CompositeFields{&a, &b, &c, &d, &e},
|
||||
)
|
||||
if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
|
||||
assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol)
|
||||
assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol)
|
||||
}
|
||||
|
||||
// untyped nil cannot be represented in binary format because CompositeFields does not know the PostgreSQL schema
|
||||
// of the composite type.
|
||||
simpleProtocol = false
|
||||
err = conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol),
|
||||
pgtype.CompositeFields{nil, int32(1), "null", nil, nil},
|
||||
).Scan(
|
||||
pgtype.CompositeFields{&a, &b, &c, &d, &e},
|
||||
)
|
||||
assert.Errorf(t, err, "Simple Protocol: %v", simpleProtocol)
|
||||
}
|
||||
|
||||
// nulls, string "null", and empty string fields
|
||||
{
|
||||
var a pgtype.Text
|
||||
var b int32
|
||||
var c string
|
||||
var d pgtype.Float8
|
||||
var e pgtype.Text
|
||||
|
||||
for _, simpleProtocol := range simpleProtocols {
|
||||
err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol),
|
||||
pgtype.CompositeFields{&pgtype.Text{}, int32(1), "null", &pgtype.Float8{}, &pgtype.Text{}},
|
||||
).Scan(
|
||||
pgtype.CompositeFields{&a, &b, &c, &d, &e},
|
||||
)
|
||||
if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
|
||||
assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol)
|
||||
assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// quotes and special characters
|
||||
{
|
||||
var a string
|
||||
var b int32
|
||||
var c string
|
||||
var d float64
|
||||
var e string
|
||||
|
||||
for _, simpleProtocol := range simpleProtocols {
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
`select $1::cf_encode`,
|
||||
pgx.QuerySimpleProtocol(simpleProtocol),
|
||||
pgtype.CompositeFields{`"`, int32(42), `foo'bar`, float64(1.2), `baz)bar`},
|
||||
).Scan(
|
||||
pgtype.CompositeFields{&a, &b, &c, &d, &e},
|
||||
)
|
||||
if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
|
||||
assert.Equalf(t, `"`, a, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.Equalf(t, int32(42), b, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.Equalf(t, `foo'bar`, c, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.Equalf(t, float64(1.2), d, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.Equalf(t, `baz)bar`, e, "Simple Protocol: %v", simpleProtocol)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,715 +0,0 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type CompositeTypeField struct {
|
||||
Name string
|
||||
OID uint32
|
||||
}
|
||||
|
||||
type CompositeType struct {
|
||||
valid bool
|
||||
|
||||
typeName string
|
||||
|
||||
fields []CompositeTypeField
|
||||
valueTranscoders []ValueTranscoder
|
||||
}
|
||||
|
||||
// NewCompositeType creates a CompositeType from fields and ci. ci is used to find the ValueTranscoders used
|
||||
// for fields. All field OIDs must be previously registered in ci.
|
||||
func NewCompositeType(typeName string, fields []CompositeTypeField, ci *ConnInfo) (*CompositeType, error) {
|
||||
valueTranscoders := make([]ValueTranscoder, len(fields))
|
||||
|
||||
for i := range fields {
|
||||
dt, ok := ci.DataTypeForOID(fields[i].OID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no data type registered for oid: %d", fields[i].OID)
|
||||
}
|
||||
|
||||
value := NewValue(dt.Value)
|
||||
valueTranscoder, ok := value.(ValueTranscoder)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("data type for oid does not implement ValueTranscoder: %d", fields[i].OID)
|
||||
}
|
||||
|
||||
valueTranscoders[i] = valueTranscoder
|
||||
}
|
||||
|
||||
return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: valueTranscoders}, nil
|
||||
}
|
||||
|
||||
// NewCompositeTypeValues creates a CompositeType from fields and values. fields and values must have the same length.
|
||||
// Prefer NewCompositeType unless overriding the transcoding of fields is required.
|
||||
func NewCompositeTypeValues(typeName string, fields []CompositeTypeField, values []ValueTranscoder) (*CompositeType, error) {
|
||||
if len(fields) != len(values) {
|
||||
return nil, errors.New("fields and valueTranscoders must have same length")
|
||||
}
|
||||
|
||||
return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: values}, nil
|
||||
}
|
||||
|
||||
func (src CompositeType) Get() interface{} {
|
||||
if !src.valid {
|
||||
return nil
|
||||
}
|
||||
|
||||
results := make(map[string]interface{}, len(src.valueTranscoders))
|
||||
for i := range src.valueTranscoders {
|
||||
results[src.fields[i].Name] = src.valueTranscoders[i].Get()
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
func (ct *CompositeType) NewTypeValue() Value {
|
||||
a := &CompositeType{
|
||||
typeName: ct.typeName,
|
||||
fields: ct.fields,
|
||||
valueTranscoders: make([]ValueTranscoder, len(ct.valueTranscoders)),
|
||||
}
|
||||
|
||||
for i := range ct.valueTranscoders {
|
||||
a.valueTranscoders[i] = NewValue(ct.valueTranscoders[i]).(ValueTranscoder)
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func (ct *CompositeType) TypeName() string {
|
||||
return ct.typeName
|
||||
}
|
||||
|
||||
func (ct *CompositeType) Fields() []CompositeTypeField {
|
||||
return ct.fields
|
||||
}
|
||||
|
||||
func (dst *CompositeType) setNil() {
|
||||
dst.valid = false
|
||||
}
|
||||
|
||||
func (dst *CompositeType) Set(src interface{}) error {
|
||||
if src == nil {
|
||||
dst.setNil()
|
||||
return nil
|
||||
}
|
||||
|
||||
switch value := src.(type) {
|
||||
case []interface{}:
|
||||
if len(value) != len(dst.valueTranscoders) {
|
||||
return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.valueTranscoders))
|
||||
}
|
||||
for i, v := range value {
|
||||
if err := dst.valueTranscoders[i].Set(v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
dst.valid = true
|
||||
case *[]interface{}:
|
||||
if value == nil {
|
||||
dst.setNil()
|
||||
return nil
|
||||
}
|
||||
return dst.Set(*value)
|
||||
default:
|
||||
return fmt.Errorf("Can not convert %v to Composite", src)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AssignTo should never be called on composite value directly
|
||||
func (src CompositeType) AssignTo(dst interface{}) error {
|
||||
if !src.valid {
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
switch v := dst.(type) {
|
||||
case []interface{}:
|
||||
if len(v) != len(src.valueTranscoders) {
|
||||
return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.valueTranscoders))
|
||||
}
|
||||
for i := range src.valueTranscoders {
|
||||
if v[i] == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
err := assignToOrSet(src.valueTranscoders[i], v[i])
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to assign to dst[%d]: %v", i, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case *[]interface{}:
|
||||
return src.AssignTo(*v)
|
||||
default:
|
||||
if isPtrStruct, err := src.assignToPtrStruct(dst); isPtrStruct {
|
||||
return err
|
||||
}
|
||||
|
||||
if nextDst, retry := GetAssignToDstType(dst); retry {
|
||||
return src.AssignTo(nextDst)
|
||||
}
|
||||
return fmt.Errorf("unable to assign to %T", dst)
|
||||
}
|
||||
}
|
||||
|
||||
func assignToOrSet(src Value, dst interface{}) error {
|
||||
assignToErr := src.AssignTo(dst)
|
||||
if assignToErr != nil {
|
||||
// Try to use get / set instead -- this avoids every type having to be able to AssignTo type of self.
|
||||
setSucceeded := false
|
||||
if setter, ok := dst.(Value); ok {
|
||||
err := setter.Set(src.Get())
|
||||
setSucceeded = err == nil
|
||||
}
|
||||
if !setSucceeded {
|
||||
return assignToErr
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) {
|
||||
dstValue := reflect.ValueOf(dst)
|
||||
if dstValue.Kind() != reflect.Ptr {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if dstValue.IsNil() {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
dstElemValue := dstValue.Elem()
|
||||
dstElemType := dstElemValue.Type()
|
||||
|
||||
if dstElemType.Kind() != reflect.Struct {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
exportedFields := make([]int, 0, dstElemType.NumField())
|
||||
for i := 0; i < dstElemType.NumField(); i++ {
|
||||
sf := dstElemType.Field(i)
|
||||
if sf.PkgPath == "" {
|
||||
exportedFields = append(exportedFields, i)
|
||||
}
|
||||
}
|
||||
|
||||
if len(exportedFields) != len(src.valueTranscoders) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for i := range exportedFields {
|
||||
err := assignToOrSet(src.valueTranscoders[i], dstElemValue.Field(exportedFields[i]).Addr().Interface())
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("unable to assign to field %s: %v", dstElemType.Field(exportedFields[i]).Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (ct *CompositeType) BinaryFormatSupported() bool {
|
||||
for _, vt := range ct.valueTranscoders {
|
||||
if !vt.BinaryFormatSupported() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (ct *CompositeType) TextFormatSupported() bool {
|
||||
for _, vt := range ct.valueTranscoders {
|
||||
if !vt.TextFormatSupported() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (ct *CompositeType) PreferredFormat() int16 {
|
||||
if ct.BinaryFormatSupported() {
|
||||
return BinaryFormatCode
|
||||
}
|
||||
return TextFormatCode
|
||||
}
|
||||
|
||||
func (dst *CompositeType) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error {
|
||||
if src == nil {
|
||||
dst.setNil()
|
||||
return nil
|
||||
}
|
||||
|
||||
switch format {
|
||||
case BinaryFormatCode:
|
||||
return dst.DecodeBinary(ci, src)
|
||||
case TextFormatCode:
|
||||
return dst.DecodeText(ci, src)
|
||||
}
|
||||
return fmt.Errorf("unknown format code %d", format)
|
||||
}
|
||||
|
||||
func (src CompositeType) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) {
|
||||
switch format {
|
||||
case BinaryFormatCode:
|
||||
return src.EncodeBinary(ci, buf)
|
||||
case TextFormatCode:
|
||||
return src.EncodeText(ci, buf)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown format code %d", format)
|
||||
}
|
||||
|
||||
func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) {
|
||||
if !src.valid {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
b := NewCompositeBinaryBuilder(ci, buf)
|
||||
for i := range src.valueTranscoders {
|
||||
b.AppendEncoder(src.fields[i].OID, src.valueTranscoders[i])
|
||||
}
|
||||
|
||||
return b.Finish()
|
||||
}
|
||||
|
||||
// DecodeBinary implements BinaryDecoder interface.
|
||||
// Opposite to Record, fields in a composite act as a "schema"
|
||||
// and decoding fails if SQL value can't be assigned due to
|
||||
// type mismatch
|
||||
func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error {
|
||||
scanner := NewCompositeBinaryScanner(ci, buf)
|
||||
|
||||
for _, f := range dst.valueTranscoders {
|
||||
scanner.ScanDecoder(f)
|
||||
}
|
||||
|
||||
if scanner.Err() != nil {
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
dst.valid = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error {
|
||||
scanner := NewCompositeTextScanner(ci, buf)
|
||||
|
||||
for _, f := range dst.valueTranscoders {
|
||||
scanner.ScanDecoder(f)
|
||||
}
|
||||
|
||||
if scanner.Err() != nil {
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
dst.valid = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src CompositeType) EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) {
|
||||
if !src.valid {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
b := NewCompositeTextBuilder(ci, buf)
|
||||
for _, f := range src.valueTranscoders {
|
||||
b.AppendEncoder(f)
|
||||
}
|
||||
|
||||
return b.Finish()
|
||||
}
|
||||
|
||||
type CompositeBinaryScanner struct {
|
||||
ci *ConnInfo
|
||||
rp int
|
||||
src []byte
|
||||
|
||||
fieldCount int32
|
||||
fieldBytes []byte
|
||||
fieldOID uint32
|
||||
err error
|
||||
}
|
||||
|
||||
// NewCompositeBinaryScanner a scanner over a binary encoded composite balue.
|
||||
func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner {
|
||||
rp := 0
|
||||
if len(src[rp:]) < 4 {
|
||||
return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)}
|
||||
}
|
||||
|
||||
fieldCount := int32(binary.BigEndian.Uint32(src[rp:]))
|
||||
rp += 4
|
||||
|
||||
return &CompositeBinaryScanner{
|
||||
ci: ci,
|
||||
rp: rp,
|
||||
src: src,
|
||||
fieldCount: fieldCount,
|
||||
}
|
||||
}
|
||||
|
||||
// ScanDecoder calls Next and decodes the result with d.
|
||||
func (cfs *CompositeBinaryScanner) ScanDecoder(d ResultDecoder) {
|
||||
if cfs.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if cfs.Next() {
|
||||
cfs.err = d.DecodeResult(cfs.ci, 0, BinaryFormatCode, cfs.fieldBytes)
|
||||
} else {
|
||||
cfs.err = errors.New("read past end of composite")
|
||||
}
|
||||
}
|
||||
|
||||
// ScanDecoder calls Next and scans the result into d.
|
||||
func (cfs *CompositeBinaryScanner) ScanValue(d interface{}) {
|
||||
if cfs.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if cfs.Next() {
|
||||
cfs.err = cfs.ci.Scan(cfs.OID(), BinaryFormatCode, cfs.Bytes(), d)
|
||||
} else {
|
||||
cfs.err = errors.New("read past end of composite")
|
||||
}
|
||||
}
|
||||
|
||||
// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
|
||||
// Next returns false, the Err method can be called to check if any errors occurred.
|
||||
func (cfs *CompositeBinaryScanner) Next() bool {
|
||||
if cfs.err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if cfs.rp == len(cfs.src) {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(cfs.src[cfs.rp:]) < 8 {
|
||||
cfs.err = fmt.Errorf("Record incomplete %v", cfs.src)
|
||||
return false
|
||||
}
|
||||
cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:])
|
||||
cfs.rp += 4
|
||||
|
||||
fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:])))
|
||||
cfs.rp += 4
|
||||
|
||||
if fieldLen >= 0 {
|
||||
if len(cfs.src[cfs.rp:]) < fieldLen {
|
||||
cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src)
|
||||
return false
|
||||
}
|
||||
cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen]
|
||||
cfs.rp += fieldLen
|
||||
} else {
|
||||
cfs.fieldBytes = nil
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (cfs *CompositeBinaryScanner) FieldCount() int {
|
||||
return int(cfs.fieldCount)
|
||||
}
|
||||
|
||||
// Bytes returns the bytes of the field most recently read by Scan().
|
||||
func (cfs *CompositeBinaryScanner) Bytes() []byte {
|
||||
return cfs.fieldBytes
|
||||
}
|
||||
|
||||
// OID returns the OID of the field most recently read by Scan().
|
||||
func (cfs *CompositeBinaryScanner) OID() uint32 {
|
||||
return cfs.fieldOID
|
||||
}
|
||||
|
||||
// Err returns any error encountered by the scanner.
|
||||
func (cfs *CompositeBinaryScanner) Err() error {
|
||||
return cfs.err
|
||||
}
|
||||
|
||||
type CompositeTextScanner struct {
|
||||
ci *ConnInfo
|
||||
rp int
|
||||
src []byte
|
||||
|
||||
fieldBytes []byte
|
||||
err error
|
||||
}
|
||||
|
||||
// NewCompositeTextScanner a scanner over a text encoded composite value.
|
||||
func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner {
|
||||
if len(src) < 2 {
|
||||
return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)}
|
||||
}
|
||||
|
||||
if src[0] != '(' {
|
||||
return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")}
|
||||
}
|
||||
|
||||
if src[len(src)-1] != ')' {
|
||||
return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")}
|
||||
}
|
||||
|
||||
return &CompositeTextScanner{
|
||||
ci: ci,
|
||||
rp: 1,
|
||||
src: src,
|
||||
}
|
||||
}
|
||||
|
||||
// ScanDecoder calls Next and decodes the result with d.
|
||||
func (cfs *CompositeTextScanner) ScanDecoder(d ResultDecoder) {
|
||||
if cfs.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if cfs.Next() {
|
||||
cfs.err = d.DecodeResult(cfs.ci, 0, TextFormatCode, cfs.fieldBytes)
|
||||
} else {
|
||||
cfs.err = errors.New("read past end of composite")
|
||||
}
|
||||
}
|
||||
|
||||
// ScanDecoder calls Next and scans the result into d.
|
||||
func (cfs *CompositeTextScanner) ScanValue(d interface{}) {
|
||||
if cfs.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if cfs.Next() {
|
||||
cfs.err = cfs.ci.Scan(0, TextFormatCode, cfs.Bytes(), d)
|
||||
} else {
|
||||
cfs.err = errors.New("read past end of composite")
|
||||
}
|
||||
}
|
||||
|
||||
// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
|
||||
// Next returns false, the Err method can be called to check if any errors occurred.
|
||||
func (cfs *CompositeTextScanner) Next() bool {
|
||||
if cfs.err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if cfs.rp == len(cfs.src) {
|
||||
return false
|
||||
}
|
||||
|
||||
switch cfs.src[cfs.rp] {
|
||||
case ',', ')': // null
|
||||
cfs.rp++
|
||||
cfs.fieldBytes = nil
|
||||
return true
|
||||
case '"': // quoted value
|
||||
cfs.rp++
|
||||
cfs.fieldBytes = make([]byte, 0, 16)
|
||||
for {
|
||||
ch := cfs.src[cfs.rp]
|
||||
|
||||
if ch == '"' {
|
||||
cfs.rp++
|
||||
if cfs.src[cfs.rp] == '"' {
|
||||
cfs.fieldBytes = append(cfs.fieldBytes, '"')
|
||||
cfs.rp++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
} else if ch == '\\' {
|
||||
cfs.rp++
|
||||
cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp])
|
||||
cfs.rp++
|
||||
} else {
|
||||
cfs.fieldBytes = append(cfs.fieldBytes, ch)
|
||||
cfs.rp++
|
||||
}
|
||||
}
|
||||
cfs.rp++
|
||||
return true
|
||||
default: // unquoted value
|
||||
start := cfs.rp
|
||||
for {
|
||||
ch := cfs.src[cfs.rp]
|
||||
if ch == ',' || ch == ')' {
|
||||
break
|
||||
}
|
||||
cfs.rp++
|
||||
}
|
||||
cfs.fieldBytes = cfs.src[start:cfs.rp]
|
||||
cfs.rp++
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Bytes returns the bytes of the field most recently read by Scan().
|
||||
func (cfs *CompositeTextScanner) Bytes() []byte {
|
||||
return cfs.fieldBytes
|
||||
}
|
||||
|
||||
// Err returns any error encountered by the scanner.
|
||||
func (cfs *CompositeTextScanner) Err() error {
|
||||
return cfs.err
|
||||
}
|
||||
|
||||
type CompositeBinaryBuilder struct {
|
||||
ci *ConnInfo
|
||||
buf []byte
|
||||
startIdx int
|
||||
fieldCount uint32
|
||||
err error
|
||||
}
|
||||
|
||||
func NewCompositeBinaryBuilder(ci *ConnInfo, buf []byte) *CompositeBinaryBuilder {
|
||||
startIdx := len(buf)
|
||||
buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields
|
||||
return &CompositeBinaryBuilder{ci: ci, buf: buf, startIdx: startIdx}
|
||||
}
|
||||
|
||||
func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) {
|
||||
if b.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
dt, ok := b.ci.DataTypeForOID(oid)
|
||||
if !ok {
|
||||
b.err = fmt.Errorf("unknown data type for OID: %d", oid)
|
||||
return
|
||||
}
|
||||
|
||||
err := dt.Value.Set(field)
|
||||
if err != nil {
|
||||
b.err = err
|
||||
return
|
||||
}
|
||||
|
||||
paramEncoder, ok := dt.Value.(ParamEncoder)
|
||||
if !ok {
|
||||
b.err = fmt.Errorf("unable to encode for OID: %d", oid)
|
||||
return
|
||||
}
|
||||
|
||||
b.AppendEncoder(oid, paramEncoder)
|
||||
}
|
||||
|
||||
func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field ParamEncoder) {
|
||||
if b.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
b.buf = pgio.AppendUint32(b.buf, oid)
|
||||
lengthPos := len(b.buf)
|
||||
b.buf = pgio.AppendInt32(b.buf, -1)
|
||||
fieldBuf, err := field.EncodeParam(b.ci, oid, BinaryFormatCode, b.buf)
|
||||
if err != nil {
|
||||
b.err = err
|
||||
return
|
||||
}
|
||||
if fieldBuf != nil {
|
||||
binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf)))
|
||||
b.buf = fieldBuf
|
||||
}
|
||||
|
||||
b.fieldCount++
|
||||
}
|
||||
|
||||
func (b *CompositeBinaryBuilder) Finish() ([]byte, error) {
|
||||
if b.err != nil {
|
||||
return nil, b.err
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount)
|
||||
return b.buf, nil
|
||||
}
|
||||
|
||||
type CompositeTextBuilder struct {
|
||||
ci *ConnInfo
|
||||
buf []byte
|
||||
startIdx int
|
||||
fieldCount uint32
|
||||
err error
|
||||
fieldBuf [32]byte
|
||||
}
|
||||
|
||||
func NewCompositeTextBuilder(ci *ConnInfo, buf []byte) *CompositeTextBuilder {
|
||||
buf = append(buf, '(') // allocate room for number of fields
|
||||
return &CompositeTextBuilder{ci: ci, buf: buf}
|
||||
}
|
||||
|
||||
func (b *CompositeTextBuilder) AppendValue(field interface{}) {
|
||||
if b.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if field == nil {
|
||||
b.buf = append(b.buf, ',')
|
||||
return
|
||||
}
|
||||
|
||||
dt, ok := b.ci.DataTypeForValue(field)
|
||||
if !ok {
|
||||
b.err = fmt.Errorf("unknown data type for field: %v", field)
|
||||
return
|
||||
}
|
||||
|
||||
err := dt.Value.Set(field)
|
||||
if err != nil {
|
||||
b.err = err
|
||||
return
|
||||
}
|
||||
|
||||
paramEncoder, ok := dt.Value.(ParamEncoder)
|
||||
if !ok {
|
||||
b.err = fmt.Errorf("unable to encode for value: %v", field)
|
||||
return
|
||||
}
|
||||
|
||||
b.AppendEncoder(paramEncoder)
|
||||
}
|
||||
|
||||
func (b *CompositeTextBuilder) AppendEncoder(field ParamEncoder) {
|
||||
if b.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
fieldBuf, err := field.EncodeParam(b.ci, 0, TextFormatCode, b.fieldBuf[0:0])
|
||||
if err != nil {
|
||||
b.err = err
|
||||
return
|
||||
}
|
||||
if fieldBuf != nil {
|
||||
b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...)
|
||||
}
|
||||
|
||||
b.buf = append(b.buf, ',')
|
||||
}
|
||||
|
||||
func (b *CompositeTextBuilder) Finish() ([]byte, error) {
|
||||
if b.err != nil {
|
||||
return nil, b.err
|
||||
}
|
||||
|
||||
b.buf[len(b.buf)-1] = ')'
|
||||
return b.buf, nil
|
||||
}
|
||||
|
||||
var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
|
||||
|
||||
func quoteCompositeField(src string) string {
|
||||
return `"` + quoteCompositeReplacer.Replace(src) + `"`
|
||||
}
|
||||
|
||||
func quoteCompositeFieldIfNeeded(src string) string {
|
||||
if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) {
|
||||
return quoteCompositeField(src)
|
||||
}
|
||||
return src
|
||||
}
|
|
@ -1,320 +0,0 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
pgx "github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgtype/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCompositeTypeSetAndGet(t *testing.T) {
|
||||
ci := pgtype.NewConnInfo()
|
||||
ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
|
||||
{"a", pgtype.TextOID},
|
||||
{"b", pgtype.Int4OID},
|
||||
}, ci)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, nil, ct.Get())
|
||||
|
||||
nilTests := []struct {
|
||||
src interface{}
|
||||
}{
|
||||
{nil}, // nil interface
|
||||
{(*[]interface{})(nil)}, // typed nil
|
||||
}
|
||||
|
||||
for i, tt := range nilTests {
|
||||
err := ct.Set(tt.src)
|
||||
assert.NoErrorf(t, err, "%d", i)
|
||||
assert.Equal(t, nil, ct.Get())
|
||||
}
|
||||
|
||||
compatibleValuesTests := []struct {
|
||||
src []interface{}
|
||||
expected map[string]interface{}
|
||||
}{
|
||||
{
|
||||
src: []interface{}{"foo", int32(42)},
|
||||
expected: map[string]interface{}{"a": "foo", "b": int32(42)},
|
||||
},
|
||||
{
|
||||
src: []interface{}{nil, nil},
|
||||
expected: map[string]interface{}{"a": nil, "b": nil},
|
||||
},
|
||||
{
|
||||
src: []interface{}{&pgtype.Text{String: "hi", Valid: true}, &pgtype.Int4{Int: 7, Valid: true}},
|
||||
expected: map[string]interface{}{"a": "hi", "b": int32(7)},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range compatibleValuesTests {
|
||||
err := ct.Set(tt.src)
|
||||
assert.NoErrorf(t, err, "%d", i)
|
||||
assert.EqualValues(t, tt.expected, ct.Get())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositeTypeAssignTo(t *testing.T) {
|
||||
ci := pgtype.NewConnInfo()
|
||||
ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
|
||||
{"a", pgtype.TextOID},
|
||||
{"b", pgtype.Int4OID},
|
||||
}, ci)
|
||||
require.NoError(t, err)
|
||||
|
||||
{
|
||||
err := ct.Set([]interface{}{"foo", int32(42)})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var a string
|
||||
var b int32
|
||||
|
||||
err = ct.AssignTo([]interface{}{&a, &b})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "foo", a)
|
||||
assert.Equal(t, int32(42), b)
|
||||
}
|
||||
|
||||
{
|
||||
err := ct.Set([]interface{}{"foo", int32(42)})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var a pgtype.Text
|
||||
var b pgtype.Int4
|
||||
|
||||
err = ct.AssignTo([]interface{}{&a, &b})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, pgtype.Text{String: "foo", Valid: true}, a)
|
||||
assert.Equal(t, pgtype.Int4{Int: 42, Valid: true}, b)
|
||||
}
|
||||
|
||||
// Allow nil destination component as no-op
|
||||
{
|
||||
err := ct.Set([]interface{}{"foo", int32(42)})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var b int32
|
||||
|
||||
err = ct.AssignTo([]interface{}{nil, &b})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, int32(42), b)
|
||||
}
|
||||
|
||||
// *[]interface{} dest when null
|
||||
{
|
||||
err := ct.Set(nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var a pgtype.Text
|
||||
var b pgtype.Int4
|
||||
dst := []interface{}{&a, &b}
|
||||
|
||||
err = ct.AssignTo(&dst)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Nil(t, dst)
|
||||
}
|
||||
|
||||
// *[]interface{} dest when not null
|
||||
{
|
||||
err := ct.Set([]interface{}{"foo", int32(42)})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var a pgtype.Text
|
||||
var b pgtype.Int4
|
||||
dst := []interface{}{&a, &b}
|
||||
|
||||
err = ct.AssignTo(&dst)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.NotNil(t, dst)
|
||||
assert.Equal(t, pgtype.Text{String: "foo", Valid: true}, a)
|
||||
assert.Equal(t, pgtype.Int4{Int: 42, Valid: true}, b)
|
||||
}
|
||||
|
||||
// Struct fields positionally via reflection
|
||||
{
|
||||
err := ct.Set([]interface{}{"foo", int32(42)})
|
||||
assert.NoError(t, err)
|
||||
|
||||
s := struct {
|
||||
A string
|
||||
B int32
|
||||
}{}
|
||||
|
||||
err = ct.AssignTo(&s)
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, "foo", s.A)
|
||||
assert.Equal(t, int32(42), s.B)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositeTypeTranscode(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
||||
_, err := conn.Exec(context.Background(), `drop type if exists ct_test;
|
||||
|
||||
create type ct_test as (
|
||||
a text,
|
||||
b int4
|
||||
);`)
|
||||
require.NoError(t, err)
|
||||
defer conn.Exec(context.Background(), "drop type ct_test")
|
||||
|
||||
var oid uint32
|
||||
err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer conn.Exec(context.Background(), "drop type ct_test")
|
||||
|
||||
ct, err := pgtype.NewCompositeType("ct_test", []pgtype.CompositeTypeField{
|
||||
{"a", pgtype.TextOID},
|
||||
{"b", pgtype.Int4OID},
|
||||
}, conn.ConnInfo())
|
||||
require.NoError(t, err)
|
||||
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid})
|
||||
|
||||
// Use simple protocol to force text or binary encoding
|
||||
simpleProtocols := []bool{true, false}
|
||||
|
||||
var a string
|
||||
var b int32
|
||||
|
||||
for _, simpleProtocol := range simpleProtocols {
|
||||
err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QuerySimpleProtocol(simpleProtocol),
|
||||
pgtype.CompositeFields{"hi", int32(42)},
|
||||
).Scan(
|
||||
[]interface{}{&a, &b},
|
||||
)
|
||||
if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
|
||||
assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, 42, b, "Simple Protocol: %v", simpleProtocol)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/874
|
||||
func TestCompositeTypeTextDecodeNested(t *testing.T) {
|
||||
newCompositeType := func(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType {
|
||||
fields := make([]pgtype.CompositeTypeField, len(fieldNames))
|
||||
for i, name := range fieldNames {
|
||||
fields[i] = pgtype.CompositeTypeField{Name: name}
|
||||
}
|
||||
|
||||
rowType, err := pgtype.NewCompositeTypeValues(name, fields, vals)
|
||||
require.NoError(t, err)
|
||||
return rowType
|
||||
}
|
||||
|
||||
dimensionsType := func() pgtype.ValueTranscoder {
|
||||
return newCompositeType(
|
||||
"dimensions",
|
||||
[]string{"width", "height"},
|
||||
&pgtype.Int4{},
|
||||
&pgtype.Int4{},
|
||||
)
|
||||
}
|
||||
productImageType := func() pgtype.ValueTranscoder {
|
||||
return newCompositeType(
|
||||
"product_image_type",
|
||||
[]string{"source", "dimensions"},
|
||||
&pgtype.Text{},
|
||||
dimensionsType(),
|
||||
)
|
||||
}
|
||||
productImageSetType := newCompositeType(
|
||||
"product_image_set_type",
|
||||
[]string{"name", "orig_image", "images"},
|
||||
&pgtype.Text{},
|
||||
productImageType(),
|
||||
pgtype.NewArrayType("product_image", 0, func() pgtype.ValueTranscoder {
|
||||
return productImageType()
|
||||
}),
|
||||
)
|
||||
|
||||
err := productImageSetType.DecodeText(nil, []byte(`(name,"(img1,""(11,11)"")","{""(img2,\\""(22,22)\\"")"",""(img3,\\""(33,33)\\"")""}")`))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func Example_composite() {
|
||||
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
defer conn.Close(context.Background())
|
||||
_, err = conn.Exec(context.Background(), `drop type if exists mytype;`)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.Exec(context.Background(), `create type mytype as (
|
||||
a int4,
|
||||
b text
|
||||
);`)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
defer conn.Exec(context.Background(), "drop type mytype")
|
||||
|
||||
var oid uint32
|
||||
err = conn.QueryRow(context.Background(), `select 'mytype'::regtype::oid`).Scan(&oid)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
ct, err := pgtype.NewCompositeType("mytype", []pgtype.CompositeTypeField{
|
||||
{"a", pgtype.Int4OID},
|
||||
{"b", pgtype.TextOID},
|
||||
}, conn.ConnInfo())
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid})
|
||||
|
||||
var a int
|
||||
var b *string
|
||||
|
||||
err = conn.QueryRow(context.Background(), "select $1::mytype", []interface{}{2, "bar"}).Scan([]interface{}{&a, &b})
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("First: a=%d b=%s\n", a, *b)
|
||||
|
||||
err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype").Scan([]interface{}{&a, &b})
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Second: a=%d b=%v\n", a, b)
|
||||
|
||||
scanTarget := []interface{}{&a, &b}
|
||||
err = conn.QueryRow(context.Background(), "select NULL::mytype").Scan(&scanTarget)
|
||||
E(err)
|
||||
|
||||
fmt.Printf("Third: isNull=%v\n", scanTarget == nil)
|
||||
|
||||
// Output:
|
||||
// First: a=2 b=bar
|
||||
// Second: a=1 b=<nil>
|
||||
// Third: isNull=true
|
||||
}
|
|
@ -1,87 +0,0 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
pgx "github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
type MyType struct {
|
||||
a int32 // NULL will cause decoding error
|
||||
b *string // there can be NULL in this position in SQL
|
||||
}
|
||||
|
||||
func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs")
|
||||
}
|
||||
|
||||
if err := (pgtype.CompositeFields{&dst.a, &dst.b}).DecodeBinary(ci, src); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src MyType) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) {
|
||||
a := pgtype.Int4{src.a, true}
|
||||
var b pgtype.Text
|
||||
if src.b != nil {
|
||||
b = pgtype.Text{*src.b, true}
|
||||
} else {
|
||||
b = pgtype.Text{}
|
||||
}
|
||||
|
||||
return (pgtype.CompositeFields{&a, &b}).EncodeBinary(ci, buf)
|
||||
}
|
||||
|
||||
func ptrS(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func E(err error) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// ExampleCustomCompositeTypes demonstrates how support for custom types mappable to SQL
|
||||
// composites can be added.
|
||||
func Example_customCompositeTypes() {
|
||||
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
E(err)
|
||||
|
||||
defer conn.Close(context.Background())
|
||||
_, err = conn.Exec(context.Background(), `drop type if exists mytype;
|
||||
|
||||
create type mytype as (
|
||||
a int4,
|
||||
b text
|
||||
);`)
|
||||
E(err)
|
||||
defer conn.Exec(context.Background(), "drop type mytype")
|
||||
|
||||
var result *MyType
|
||||
|
||||
// Demonstrates both passing and reading back composite values
|
||||
err = conn.QueryRow(context.Background(), "select $1::mytype",
|
||||
pgx.QueryResultFormats{pgx.BinaryFormatCode}, MyType{1, ptrS("foo")}).
|
||||
Scan(&result)
|
||||
E(err)
|
||||
|
||||
fmt.Printf("First row: a=%d b=%s\n", result.a, *result.b)
|
||||
|
||||
// Because we scan into &*MyType, NULLs are handled generically by assigning nil to result
|
||||
err = conn.QueryRow(context.Background(), "select NULL::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result)
|
||||
E(err)
|
||||
|
||||
fmt.Printf("Second row: %v\n", result)
|
||||
|
||||
// Output:
|
||||
// First row: a=1 b=foo
|
||||
// Second row: <nil>
|
||||
}
|
|
@ -343,7 +343,7 @@ func NewConnInfo() *ConnInfo {
|
|||
ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID})
|
||||
ci.RegisterDataType(DataType{Value: &Point{}, Name: "point", OID: PointOID})
|
||||
ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID})
|
||||
ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID})
|
||||
// ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID})
|
||||
ci.RegisterDataType(DataType{Value: &Text{}, Name: "text", OID: TextOID})
|
||||
ci.RegisterDataType(DataType{Value: &TID{}, Name: "tid", OID: TIDOID})
|
||||
ci.RegisterDataType(DataType{Value: &Time{}, Name: "time", OID: TimeOID})
|
||||
|
|
141
pgtype/record.go
141
pgtype/record.go
|
@ -1,141 +0,0 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Record is the generic PostgreSQL record type such as is created with the
|
||||
// "row" function. Record only implements BinaryEncoder and Value. The text
|
||||
// format output format from PostgreSQL does not include type information and is
|
||||
// therefore impossible to decode. No encoders are implemented because
|
||||
// PostgreSQL does not support input of generic records.
|
||||
type Record struct {
|
||||
Fields []Value
|
||||
Valid bool
|
||||
}
|
||||
|
||||
func (dst *Record) Set(src interface{}) error {
|
||||
if src == nil {
|
||||
*dst = Record{}
|
||||
return nil
|
||||
}
|
||||
|
||||
if value, ok := src.(interface{ Get() interface{} }); ok {
|
||||
value2 := value.Get()
|
||||
if value2 != value {
|
||||
return dst.Set(value2)
|
||||
}
|
||||
}
|
||||
|
||||
switch value := src.(type) {
|
||||
case []Value:
|
||||
*dst = Record{Fields: value, Valid: true}
|
||||
default:
|
||||
return fmt.Errorf("cannot convert %v to Record", src)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst Record) Get() interface{} {
|
||||
if !dst.Valid {
|
||||
return nil
|
||||
}
|
||||
return dst.Fields
|
||||
}
|
||||
|
||||
func (src *Record) AssignTo(dst interface{}) error {
|
||||
if !src.Valid {
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
switch v := dst.(type) {
|
||||
case *[]Value:
|
||||
*v = make([]Value, len(src.Fields))
|
||||
copy(*v, src.Fields)
|
||||
return nil
|
||||
case *[]interface{}:
|
||||
*v = make([]interface{}, len(src.Fields))
|
||||
for i := range *v {
|
||||
(*v)[i] = src.Fields[i].Get()
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
if nextDst, retry := GetAssignToDstType(dst); retry {
|
||||
return src.AssignTo(nextDst)
|
||||
}
|
||||
return fmt.Errorf("unable to assign to %T", dst)
|
||||
}
|
||||
}
|
||||
|
||||
func prepareNewBinaryDecoder(ci *ConnInfo, fieldOID uint32, v *Value) (BinaryDecoder, error) {
|
||||
var binaryDecoder BinaryDecoder
|
||||
|
||||
if dt, ok := ci.DataTypeForOID(fieldOID); ok {
|
||||
binaryDecoder, _ = dt.Value.(BinaryDecoder)
|
||||
} else {
|
||||
return nil, fmt.Errorf("unknown oid while decoding record: %v", fieldOID)
|
||||
}
|
||||
|
||||
if binaryDecoder == nil {
|
||||
return nil, fmt.Errorf("no binary decoder registered for: %v", fieldOID)
|
||||
}
|
||||
|
||||
// Duplicate struct to scan into
|
||||
binaryDecoder = reflect.New(reflect.ValueOf(binaryDecoder).Elem().Type()).Interface().(BinaryDecoder)
|
||||
*v = binaryDecoder.(Value)
|
||||
return binaryDecoder, nil
|
||||
}
|
||||
|
||||
func (Record) BinaryFormatSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (Record) TextFormatSupported() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (Record) PreferredFormat() int16 {
|
||||
return BinaryFormatCode
|
||||
}
|
||||
|
||||
func (dst *Record) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error {
|
||||
switch format {
|
||||
case BinaryFormatCode:
|
||||
return dst.DecodeBinary(ci, src)
|
||||
case TextFormatCode:
|
||||
return fmt.Errorf("text format is not supported")
|
||||
}
|
||||
return fmt.Errorf("unknown format code %d", format)
|
||||
}
|
||||
|
||||
func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Record{}
|
||||
return nil
|
||||
}
|
||||
|
||||
scanner := NewCompositeBinaryScanner(ci, src)
|
||||
|
||||
fields := make([]Value, scanner.FieldCount())
|
||||
|
||||
for i := 0; scanner.Next(); i++ {
|
||||
binaryDecoder, err := prepareNewBinaryDecoder(ci, scanner.OID(), &fields[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = binaryDecoder.DecodeBinary(ci, scanner.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if scanner.Err() != nil {
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
*dst = Record{Fields: fields, Valid: true}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,184 +0,0 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgtype/testutil"
|
||||
)
|
||||
|
||||
var recordTests = []struct {
|
||||
sql string
|
||||
expected pgtype.Record
|
||||
}{
|
||||
{
|
||||
sql: `select row()`,
|
||||
expected: pgtype.Record{
|
||||
Fields: []pgtype.Value{},
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
sql: `select row('foo'::text, 42::int4)`,
|
||||
expected: pgtype.Record{
|
||||
Fields: []pgtype.Value{
|
||||
&pgtype.Text{String: "foo", Valid: true},
|
||||
&pgtype.Int4{Int: 42, Valid: true},
|
||||
},
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
sql: `select row(100.0::float4, 1.09::float4)`,
|
||||
expected: pgtype.Record{
|
||||
Fields: []pgtype.Value{
|
||||
&pgtype.Float4{Float: 100, Valid: true},
|
||||
&pgtype.Float4{Float: 1.09, Valid: true},
|
||||
},
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`,
|
||||
expected: pgtype.Record{
|
||||
Fields: []pgtype.Value{
|
||||
&pgtype.Text{String: "foo", Valid: true},
|
||||
&pgtype.Int4Array{
|
||||
Elements: []pgtype.Int4{
|
||||
{Int: 1, Valid: true},
|
||||
{Int: 2, Valid: true},
|
||||
{},
|
||||
{Int: 4, Valid: true},
|
||||
},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}},
|
||||
Valid: true,
|
||||
},
|
||||
&pgtype.Int4{Int: 42, Valid: true},
|
||||
},
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
sql: `select row(null)`,
|
||||
expected: pgtype.Record{
|
||||
Fields: []pgtype.Value{
|
||||
&pgtype.Unknown{},
|
||||
},
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
sql: `select null::record`,
|
||||
expected: pgtype.Record{},
|
||||
},
|
||||
}
|
||||
|
||||
func TestRecordTranscode(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
||||
for i, tt := range recordTests {
|
||||
psName := fmt.Sprintf("test%d", i)
|
||||
_, err := conn.Prepare(context.Background(), psName, tt.sql)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run(tt.sql, func(t *testing.T) {
|
||||
var result pgtype.Record
|
||||
if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil {
|
||||
t.Errorf("%v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tt.expected, result) {
|
||||
t.Errorf("expected %#v, got %#v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordWithUnknownOID(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
||||
_, err := conn.Exec(context.Background(), `drop type if exists floatrange;
|
||||
|
||||
create type floatrange as range (
|
||||
subtype = float8,
|
||||
subtype_diff = float8mi
|
||||
);`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Exec(context.Background(), "drop type floatrange")
|
||||
|
||||
var result pgtype.Record
|
||||
err = conn.QueryRow(context.Background(), "select row('foo'::text, floatrange(1, 10), 'bar'::text)").Scan(&result)
|
||||
if err == nil {
|
||||
t.Errorf("expected error but none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAssignTo(t *testing.T) {
|
||||
var valueSlice []pgtype.Value
|
||||
var interfaceSlice []interface{}
|
||||
|
||||
simpleTests := []struct {
|
||||
src pgtype.Record
|
||||
dst interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
src: pgtype.Record{
|
||||
Fields: []pgtype.Value{
|
||||
&pgtype.Text{String: "foo", Valid: true},
|
||||
&pgtype.Int4{Int: 42, Valid: true},
|
||||
},
|
||||
Valid: true,
|
||||
},
|
||||
dst: &valueSlice,
|
||||
expected: []pgtype.Value{
|
||||
&pgtype.Text{String: "foo", Valid: true},
|
||||
&pgtype.Int4{Int: 42, Valid: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
src: pgtype.Record{
|
||||
Fields: []pgtype.Value{
|
||||
&pgtype.Text{String: "foo", Valid: true},
|
||||
&pgtype.Int4{Int: 42, Valid: true},
|
||||
},
|
||||
Valid: true,
|
||||
},
|
||||
dst: &interfaceSlice,
|
||||
expected: []interface{}{"foo", int32(42)},
|
||||
},
|
||||
{
|
||||
src: pgtype.Record{},
|
||||
dst: &valueSlice,
|
||||
expected: (([]pgtype.Value)(nil)),
|
||||
},
|
||||
{
|
||||
src: pgtype.Record{},
|
||||
dst: &interfaceSlice,
|
||||
expected: (([]interface{})(nil)),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range simpleTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue