Add CompositeFields type

This adds support for the text format and removes the need for the
ScanRowValue function.
non-blocking
Jack Christensen 2020-05-11 17:21:21 -05:00
parent 1b3d694469
commit 36dbbd983d
6 changed files with 293 additions and 81 deletions

View File

@ -233,6 +233,96 @@ func (cfs *CompositeBinaryScanner) Err() error {
return cfs.err
}
type CompositeTextScanner struct {
rp int
src []byte
fieldBytes []byte
err error
}
// NewCompositeTextScanner a scanner over a text encoded composite balue.
func NewCompositeTextScanner(src []byte) (CompositeTextScanner, error) {
if len(src) < 2 {
return CompositeTextScanner{}, errors.Errorf("Record incomplete %v", src)
}
if src[0] != '(' {
return CompositeTextScanner{}, errors.Errorf("composite text format must start with '('")
}
if src[len(src)-1] != ')' {
return CompositeTextScanner{}, errors.Errorf("composite text format must end with ')'")
}
return CompositeTextScanner{
rp: 1,
src: src,
}, nil
}
// Scan advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
// Scan returns false, the Err method can be called to check if any errors occurred.
func (cfs *CompositeTextScanner) Scan() bool {
if cfs.err != nil {
return false
}
if cfs.rp == len(cfs.src) {
return false
}
switch cfs.src[cfs.rp] {
case ',', ')': // null
cfs.rp++
cfs.fieldBytes = nil
return true
case '"': // quoted value
cfs.rp++
cfs.fieldBytes = make([]byte, 0, 16)
for {
ch := cfs.src[cfs.rp]
if ch == '"' {
cfs.rp++
if cfs.src[cfs.rp] == '"' {
cfs.fieldBytes = append(cfs.fieldBytes, '"')
cfs.rp++
} else {
break
}
} else {
cfs.fieldBytes = append(cfs.fieldBytes, ch)
cfs.rp++
}
}
cfs.rp++
return true
default: // unquoted value
start := cfs.rp
for {
ch := cfs.src[cfs.rp]
if ch == ',' || ch == ')' {
break
}
cfs.rp++
}
cfs.fieldBytes = cfs.src[start:cfs.rp]
cfs.rp++
return true
}
}
// Bytes returns the bytes of the field most recently read by Scan().
func (cfs *CompositeTextScanner) Bytes() []byte {
return cfs.fieldBytes
}
// Err returns any error encountered by the scanner.
func (cfs *CompositeTextScanner) Err() error {
return cfs.err
}
// RecordStart adds record header to the buf
func RecordStart(buf []byte, fieldCount int) []byte {
return pgio.AppendUint32(buf, uint32(fieldCount))

76
composite_fields.go Normal file
View File

@ -0,0 +1,76 @@
package pgtype
import (
errors "golang.org/x/xerrors"
)
// CompositeFields scans the fields of a composite type into the elements of the CompositeFields value. To scan a
// nullable value use a *CompositeFields. It will be set to nil in case of null.
type CompositeFields []interface{}
func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error {
if len(cf) == 0 {
return errors.Errorf("cannot decode into empty CompositeFields")
}
if src == nil {
return errors.Errorf("cannot decode unexpected null into CompositeFields")
}
scanner, err := NewCompositeBinaryScanner(src)
if err != nil {
return err
}
if len(cf) != scanner.FieldCount() {
return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), scanner.FieldCount())
}
for i := 0; scanner.Scan(); i++ {
err := ci.Scan(scanner.OID(), BinaryFormatCode, scanner.Bytes(), cf[i])
if err != nil {
return err
}
}
if scanner.Err() != nil {
return scanner.Err()
}
return nil
}
func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error {
if len(cf) == 0 {
return errors.Errorf("cannot decode into empty CompositeFields")
}
if src == nil {
return errors.Errorf("cannot decode unexpected null into CompositeFields")
}
scanner, err := NewCompositeTextScanner(src)
if err != nil {
return err
}
fieldCount := 0
for i := 0; scanner.Scan(); i++ {
err := ci.Scan(0, TextFormatCode, scanner.Bytes(), cf[i])
if err != nil {
return err
}
fieldCount += 1
}
if scanner.Err() != nil {
return scanner.Err()
}
if len(cf) != fieldCount {
return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), fieldCount)
}
return nil
}

126
composite_fields_test.go Normal file
View File

@ -0,0 +1,126 @@
package pgtype_test
import (
"context"
"testing"
"github.com/jackc/pgtype"
"github.com/jackc/pgtype/testutil"
"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/assert"
)
func TestCompositeFieldsDecode(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
formats := []int16{pgx.TextFormatCode, pgx.BinaryFormatCode}
// Assorted values
{
var a int32
var b string
var c float64
for _, format := range formats {
err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan(
pgtype.CompositeFields{&a, &b, &c},
)
if !assert.NoErrorf(t, err, "Format: %v", format) {
continue
}
assert.EqualValuesf(t, 1, a, "Format: %v", format)
assert.EqualValuesf(t, "hi", b, "Format: %v", format)
assert.EqualValuesf(t, 2.1, c, "Format: %v", format)
}
}
// nulls, string "null", and empty string fields
{
var a pgtype.Text
var b string
var c pgtype.Text
var d string
var e pgtype.Text
for _, format := range formats {
err := conn.QueryRow(context.Background(), "select row(null,'null',null,'',null)", pgx.QueryResultFormats{format}).Scan(
pgtype.CompositeFields{&a, &b, &c, &d, &e},
)
if !assert.NoErrorf(t, err, "Format: %v", format) {
continue
}
assert.Nilf(t, a.Get(), "Format: %v", format)
assert.EqualValuesf(t, "null", b, "Format: %v", format)
assert.Nilf(t, c.Get(), "Format: %v", format)
assert.EqualValuesf(t, "", d, "Format: %v", format)
assert.Nilf(t, e.Get(), "Format: %v", format)
}
}
// null record
{
var a pgtype.Text
var b string
cf := pgtype.CompositeFields{&a, &b}
for _, format := range formats {
// Cannot scan nil into
err := conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan(
cf,
)
if assert.Errorf(t, err, "Format: %v", format) {
continue
}
assert.NotNilf(t, cf, "Format: %v", format)
// But can scan nil into *pgtype.CompositeFields
err = conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan(
&cf,
)
if assert.Errorf(t, err, "Format: %v", format) {
continue
}
assert.Nilf(t, cf, "Format: %v", format)
}
}
// quotes and special characters
{
var a, b, c, d string
for _, format := range formats {
err := conn.QueryRow(context.Background(), `select row('"', 'foo bar', 'foo''bar', 'baz)bar')`, pgx.QueryResultFormats{format}).Scan(
pgtype.CompositeFields{&a, &b, &c, &d},
)
if !assert.NoErrorf(t, err, "Format: %v", format) {
continue
}
assert.Equalf(t, `"`, a, "Format: %v", format)
assert.Equalf(t, `foo bar`, b, "Format: %v", format)
assert.Equalf(t, `foo'bar`, c, "Format: %v", format)
assert.Equalf(t, `baz)bar`, d, "Format: %v", format)
}
}
// arrays
{
var a []string
var b []int64
for _, format := range formats {
err := conn.QueryRow(context.Background(), `select row(array['foo', 'bar', 'baz'], array[1,2,3])`, pgx.QueryResultFormats{format}).Scan(
pgtype.CompositeFields{&a, &b},
)
if !assert.NoErrorf(t, err, "Format: %v", format) {
continue
}
assert.EqualValuesf(t, []string{"foo", "bar", "baz"}, a, "Format: %v", format)
assert.EqualValuesf(t, []int64{1, 2, 3}, b, "Format: %v", format)
}
}
}

View File

@ -433,38 +433,6 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) {
return nil, false
}
// ScanRowValue decodes ROW()'s and composite type
// from src argument using provided decoders. Decoders should match
// order and count of fields of record being decoded.
//
// In practice you can pass pgtype.Value types as decoders, as
// most of them implement BinaryDecoder interface.
//
// ScanRowValue takes ownership of src, caller MUST not use it after call
func ScanRowValue(ci *ConnInfo, src []byte, dst ...interface{}) error {
scanner, err := NewCompositeBinaryScanner(src)
if err != nil {
return err
}
if len(dst) != scanner.FieldCount() {
return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=%d", scanner.FieldCount(), len(dst))
}
for i := 0; scanner.Scan(); i++ {
err := ci.Scan(scanner.OID(), BinaryFormatCode, scanner.Bytes(), dst[i])
if err != nil {
return err
}
}
if scanner.Err() != nil {
return scanner.Err()
}
return nil
}
// EncodeRow builds a binary representation of row values (row(), composite types)
func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err error) {
fieldBytes := make([]byte, 0, 128)

View File

@ -20,7 +20,7 @@ func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs")
}
if err := pgtype.ScanRowValue(ci, src, &dst.a, &dst.b); err != nil {
if err := (pgtype.CompositeFields{&dst.a, &dst.b}).DecodeBinary(ci, src); err != nil {
return err
}

View File

@ -79,54 +79,6 @@ var recordTests = []struct {
},
}
// row values are binary compatible with records, so we test our helper
// routines here
func TestScanRowValue(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
for i := 0; i < len(recordTests); i++ {
tt := recordTests[i]
psName := fmt.Sprintf("test%d", i)
_, err := conn.Prepare(context.Background(), psName, tt.sql)
if err != nil {
t.Fatal(err)
}
t.Run(tt.sql, func(t *testing.T) {
desc := []interface{}{}
for _, f := range tt.expected.Fields {
desc = append(desc, f.(pgtype.BinaryDecoder))
}
var raw pgtype.GenericBinary
if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&raw); err != nil {
t.Error(err)
return
}
if raw.Status == pgtype.Null {
// ScanRowValue deals with complete rows only, NULL values (but NOT null fields)
// should be handled by the calling code
return
}
if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err != nil {
t.Error(err)
}
// borrow fields from a neighbor test, this makes scan always fail
desc = desc[:0]
for _, f := range recordTests[(i+1)%len(recordTests)].expected.Fields {
desc = append(desc, f.(pgtype.BinaryDecoder))
}
if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err == nil {
t.Error("Matching scan didn't fail, despite fields not mathching query result")
}
})
}
}
func TestRecordTranscode(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)