mirror of https://github.com/jackc/pgx.git
Add CompositeFields type
This adds support for the text format and removes the need for the ScanRowValue function.non-blocking
parent
1b3d694469
commit
36dbbd983d
90
composite.go
90
composite.go
|
@ -233,6 +233,96 @@ func (cfs *CompositeBinaryScanner) Err() error {
|
|||
return cfs.err
|
||||
}
|
||||
|
||||
type CompositeTextScanner struct {
|
||||
rp int
|
||||
src []byte
|
||||
|
||||
fieldBytes []byte
|
||||
err error
|
||||
}
|
||||
|
||||
// NewCompositeTextScanner a scanner over a text encoded composite balue.
|
||||
func NewCompositeTextScanner(src []byte) (CompositeTextScanner, error) {
|
||||
if len(src) < 2 {
|
||||
return CompositeTextScanner{}, errors.Errorf("Record incomplete %v", src)
|
||||
}
|
||||
|
||||
if src[0] != '(' {
|
||||
return CompositeTextScanner{}, errors.Errorf("composite text format must start with '('")
|
||||
}
|
||||
|
||||
if src[len(src)-1] != ')' {
|
||||
return CompositeTextScanner{}, errors.Errorf("composite text format must end with ')'")
|
||||
}
|
||||
|
||||
return CompositeTextScanner{
|
||||
rp: 1,
|
||||
src: src,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Scan advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
|
||||
// Scan returns false, the Err method can be called to check if any errors occurred.
|
||||
func (cfs *CompositeTextScanner) Scan() bool {
|
||||
if cfs.err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if cfs.rp == len(cfs.src) {
|
||||
return false
|
||||
}
|
||||
|
||||
switch cfs.src[cfs.rp] {
|
||||
case ',', ')': // null
|
||||
cfs.rp++
|
||||
cfs.fieldBytes = nil
|
||||
return true
|
||||
case '"': // quoted value
|
||||
cfs.rp++
|
||||
cfs.fieldBytes = make([]byte, 0, 16)
|
||||
for {
|
||||
ch := cfs.src[cfs.rp]
|
||||
|
||||
if ch == '"' {
|
||||
cfs.rp++
|
||||
if cfs.src[cfs.rp] == '"' {
|
||||
cfs.fieldBytes = append(cfs.fieldBytes, '"')
|
||||
cfs.rp++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
cfs.fieldBytes = append(cfs.fieldBytes, ch)
|
||||
cfs.rp++
|
||||
}
|
||||
}
|
||||
cfs.rp++
|
||||
return true
|
||||
default: // unquoted value
|
||||
start := cfs.rp
|
||||
for {
|
||||
ch := cfs.src[cfs.rp]
|
||||
if ch == ',' || ch == ')' {
|
||||
break
|
||||
}
|
||||
cfs.rp++
|
||||
}
|
||||
cfs.fieldBytes = cfs.src[start:cfs.rp]
|
||||
cfs.rp++
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Bytes returns the bytes of the field most recently read by Scan().
|
||||
func (cfs *CompositeTextScanner) Bytes() []byte {
|
||||
return cfs.fieldBytes
|
||||
}
|
||||
|
||||
// Err returns any error encountered by the scanner.
|
||||
func (cfs *CompositeTextScanner) Err() error {
|
||||
return cfs.err
|
||||
}
|
||||
|
||||
// RecordStart adds record header to the buf
|
||||
func RecordStart(buf []byte, fieldCount int) []byte {
|
||||
return pgio.AppendUint32(buf, uint32(fieldCount))
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
errors "golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// CompositeFields scans the fields of a composite type into the elements of the CompositeFields value. To scan a
|
||||
// nullable value use a *CompositeFields. It will be set to nil in case of null.
|
||||
type CompositeFields []interface{}
|
||||
|
||||
func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||
if len(cf) == 0 {
|
||||
return errors.Errorf("cannot decode into empty CompositeFields")
|
||||
}
|
||||
|
||||
if src == nil {
|
||||
return errors.Errorf("cannot decode unexpected null into CompositeFields")
|
||||
}
|
||||
|
||||
scanner, err := NewCompositeBinaryScanner(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(cf) != scanner.FieldCount() {
|
||||
return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), scanner.FieldCount())
|
||||
}
|
||||
|
||||
for i := 0; scanner.Scan(); i++ {
|
||||
err := ci.Scan(scanner.OID(), BinaryFormatCode, scanner.Bytes(), cf[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if scanner.Err() != nil {
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error {
|
||||
if len(cf) == 0 {
|
||||
return errors.Errorf("cannot decode into empty CompositeFields")
|
||||
}
|
||||
|
||||
if src == nil {
|
||||
return errors.Errorf("cannot decode unexpected null into CompositeFields")
|
||||
}
|
||||
|
||||
scanner, err := NewCompositeTextScanner(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fieldCount := 0
|
||||
|
||||
for i := 0; scanner.Scan(); i++ {
|
||||
err := ci.Scan(0, TextFormatCode, scanner.Bytes(), cf[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fieldCount += 1
|
||||
}
|
||||
|
||||
if scanner.Err() != nil {
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
if len(cf) != fieldCount {
|
||||
return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), fieldCount)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,126 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/jackc/pgtype/testutil"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCompositeFieldsDecode(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
||||
formats := []int16{pgx.TextFormatCode, pgx.BinaryFormatCode}
|
||||
|
||||
// Assorted values
|
||||
{
|
||||
var a int32
|
||||
var b string
|
||||
var c float64
|
||||
|
||||
for _, format := range formats {
|
||||
err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan(
|
||||
pgtype.CompositeFields{&a, &b, &c},
|
||||
)
|
||||
if !assert.NoErrorf(t, err, "Format: %v", format) {
|
||||
continue
|
||||
}
|
||||
|
||||
assert.EqualValuesf(t, 1, a, "Format: %v", format)
|
||||
assert.EqualValuesf(t, "hi", b, "Format: %v", format)
|
||||
assert.EqualValuesf(t, 2.1, c, "Format: %v", format)
|
||||
}
|
||||
}
|
||||
|
||||
// nulls, string "null", and empty string fields
|
||||
{
|
||||
var a pgtype.Text
|
||||
var b string
|
||||
var c pgtype.Text
|
||||
var d string
|
||||
var e pgtype.Text
|
||||
|
||||
for _, format := range formats {
|
||||
err := conn.QueryRow(context.Background(), "select row(null,'null',null,'',null)", pgx.QueryResultFormats{format}).Scan(
|
||||
pgtype.CompositeFields{&a, &b, &c, &d, &e},
|
||||
)
|
||||
if !assert.NoErrorf(t, err, "Format: %v", format) {
|
||||
continue
|
||||
}
|
||||
|
||||
assert.Nilf(t, a.Get(), "Format: %v", format)
|
||||
assert.EqualValuesf(t, "null", b, "Format: %v", format)
|
||||
assert.Nilf(t, c.Get(), "Format: %v", format)
|
||||
assert.EqualValuesf(t, "", d, "Format: %v", format)
|
||||
assert.Nilf(t, e.Get(), "Format: %v", format)
|
||||
}
|
||||
}
|
||||
|
||||
// null record
|
||||
{
|
||||
var a pgtype.Text
|
||||
var b string
|
||||
cf := pgtype.CompositeFields{&a, &b}
|
||||
|
||||
for _, format := range formats {
|
||||
// Cannot scan nil into
|
||||
err := conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan(
|
||||
cf,
|
||||
)
|
||||
if assert.Errorf(t, err, "Format: %v", format) {
|
||||
continue
|
||||
}
|
||||
assert.NotNilf(t, cf, "Format: %v", format)
|
||||
|
||||
// But can scan nil into *pgtype.CompositeFields
|
||||
err = conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan(
|
||||
&cf,
|
||||
)
|
||||
if assert.Errorf(t, err, "Format: %v", format) {
|
||||
continue
|
||||
}
|
||||
assert.Nilf(t, cf, "Format: %v", format)
|
||||
}
|
||||
}
|
||||
|
||||
// quotes and special characters
|
||||
{
|
||||
var a, b, c, d string
|
||||
|
||||
for _, format := range formats {
|
||||
err := conn.QueryRow(context.Background(), `select row('"', 'foo bar', 'foo''bar', 'baz)bar')`, pgx.QueryResultFormats{format}).Scan(
|
||||
pgtype.CompositeFields{&a, &b, &c, &d},
|
||||
)
|
||||
if !assert.NoErrorf(t, err, "Format: %v", format) {
|
||||
continue
|
||||
}
|
||||
|
||||
assert.Equalf(t, `"`, a, "Format: %v", format)
|
||||
assert.Equalf(t, `foo bar`, b, "Format: %v", format)
|
||||
assert.Equalf(t, `foo'bar`, c, "Format: %v", format)
|
||||
assert.Equalf(t, `baz)bar`, d, "Format: %v", format)
|
||||
}
|
||||
}
|
||||
|
||||
// arrays
|
||||
{
|
||||
var a []string
|
||||
var b []int64
|
||||
|
||||
for _, format := range formats {
|
||||
err := conn.QueryRow(context.Background(), `select row(array['foo', 'bar', 'baz'], array[1,2,3])`, pgx.QueryResultFormats{format}).Scan(
|
||||
pgtype.CompositeFields{&a, &b},
|
||||
)
|
||||
if !assert.NoErrorf(t, err, "Format: %v", format) {
|
||||
continue
|
||||
}
|
||||
|
||||
assert.EqualValuesf(t, []string{"foo", "bar", "baz"}, a, "Format: %v", format)
|
||||
assert.EqualValuesf(t, []int64{1, 2, 3}, b, "Format: %v", format)
|
||||
}
|
||||
}
|
||||
}
|
32
convert.go
32
convert.go
|
@ -433,38 +433,6 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) {
|
|||
return nil, false
|
||||
}
|
||||
|
||||
// ScanRowValue decodes ROW()'s and composite type
|
||||
// from src argument using provided decoders. Decoders should match
|
||||
// order and count of fields of record being decoded.
|
||||
//
|
||||
// In practice you can pass pgtype.Value types as decoders, as
|
||||
// most of them implement BinaryDecoder interface.
|
||||
//
|
||||
// ScanRowValue takes ownership of src, caller MUST not use it after call
|
||||
func ScanRowValue(ci *ConnInfo, src []byte, dst ...interface{}) error {
|
||||
scanner, err := NewCompositeBinaryScanner(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(dst) != scanner.FieldCount() {
|
||||
return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=%d", scanner.FieldCount(), len(dst))
|
||||
}
|
||||
|
||||
for i := 0; scanner.Scan(); i++ {
|
||||
err := ci.Scan(scanner.OID(), BinaryFormatCode, scanner.Bytes(), dst[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if scanner.Err() != nil {
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodeRow builds a binary representation of row values (row(), composite types)
|
||||
func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err error) {
|
||||
fieldBytes := make([]byte, 0, 128)
|
||||
|
|
|
@ -20,7 +20,7 @@ func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
|
|||
return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs")
|
||||
}
|
||||
|
||||
if err := pgtype.ScanRowValue(ci, src, &dst.a, &dst.b); err != nil {
|
||||
if err := (pgtype.CompositeFields{&dst.a, &dst.b}).DecodeBinary(ci, src); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -79,54 +79,6 @@ var recordTests = []struct {
|
|||
},
|
||||
}
|
||||
|
||||
// row values are binary compatible with records, so we test our helper
|
||||
// routines here
|
||||
func TestScanRowValue(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
||||
for i := 0; i < len(recordTests); i++ {
|
||||
tt := recordTests[i]
|
||||
psName := fmt.Sprintf("test%d", i)
|
||||
_, err := conn.Prepare(context.Background(), psName, tt.sql)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Run(tt.sql, func(t *testing.T) {
|
||||
desc := []interface{}{}
|
||||
for _, f := range tt.expected.Fields {
|
||||
desc = append(desc, f.(pgtype.BinaryDecoder))
|
||||
}
|
||||
|
||||
var raw pgtype.GenericBinary
|
||||
|
||||
if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&raw); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if raw.Status == pgtype.Null {
|
||||
// ScanRowValue deals with complete rows only, NULL values (but NOT null fields)
|
||||
// should be handled by the calling code
|
||||
return
|
||||
}
|
||||
|
||||
if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// borrow fields from a neighbor test, this makes scan always fail
|
||||
desc = desc[:0]
|
||||
for _, f := range recordTests[(i+1)%len(recordTests)].expected.Fields {
|
||||
desc = append(desc, f.(pgtype.BinaryDecoder))
|
||||
}
|
||||
if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err == nil {
|
||||
t.Error("Matching scan didn't fail, despite fields not mathching query result")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordTranscode(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
|
Loading…
Reference in New Issue