Tighten ScanRowValue input types

ScanRowValue  needs not  Value, but BinaryEncoder
This commit is contained in:
Maxim Ivanov 2020-04-18 14:08:28 +01:00
parent 54a03cb143
commit b88a3e0765
2 changed files with 16 additions and 14 deletions

View File

@ -434,14 +434,15 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) {
return nil, false
}
// ScanRowValue assigns ROW()'s fields to destination Values.
// Argument types are checked and error is returned if SQL field value
// can't be assigned to corresponding destionation Value without loss
// of information. Number of fields have to match number of destination values.
// ScanRowValue decodes ROW()'s and composite type
// from src argument using provided decoders. Decoders should match
// order and count of fields of record being decoded.
//
// In practice you can pass pgtype.Value types as decoders, as
// most of them implement BinaryDecoder interface.
//
// Values must implement BinaryDecoder interface otherwise error is returned.
// ScanRowValue takes ownership of src, caller MUST not use it after call
func ScanRowValue(ci *ConnInfo, src []byte, dst ...Value) error {
func ScanRowValue(ci *ConnInfo, src []byte, dst ...BinaryDecoder) error {
fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src)
if err != nil {
return err
@ -457,12 +458,7 @@ func ScanRowValue(ci *ConnInfo, src []byte, dst ...Value) error {
return err
}
binaryDecoder, ok := dst[i].(BinaryDecoder)
if !ok {
return errors.Errorf("record field doesn't implement binary decoding: %s", reflect.TypeOf(dst[i]).Name())
}
if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil {
if err = dst[i].DecodeBinary(ci, fieldBytes); err != nil {
return err
}

View File

@ -93,7 +93,10 @@ func TestScanRowValue(t *testing.T) {
t.Fatal(err)
}
t.Run(tt.sql, func(t *testing.T) {
desc := append([]pgtype.Value(nil), tt.expected.Fields...)
desc := []pgtype.BinaryDecoder{}
for _, f := range tt.expected.Fields {
desc = append(desc, f.(pgtype.BinaryDecoder))
}
var raw pgtype.GenericBinary
@ -113,7 +116,10 @@ func TestScanRowValue(t *testing.T) {
}
// borrow fields from a neighbor test, this makes scan always fail
desc = append([]pgtype.Value(nil), recordTests[(i+1)%len(recordTests)].expected.Fields...)
desc = desc[:0]
for _, f := range recordTests[(i+1)%len(recordTests)].expected.Fields {
desc = append(desc, f.(pgtype.BinaryDecoder))
}
if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err == nil {
t.Error("Matching scan didn't fail, despite fields not mathching query result")
}