diff --git a/convert.go b/convert.go index 6d5ea0c9..45f117bc 100644 --- a/convert.go +++ b/convert.go @@ -434,14 +434,15 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { return nil, false } -// ScanRowValue assigns ROW()'s fields to destination Values. -// Argument types are checked and error is returned if SQL field value -// can't be assigned to corresponding destionation Value without loss -// of information. Number of fields have to match number of destination values. +// ScanRowValue decodes ROW()'s and composite type +// from src argument using provided decoders. Decoders should match +// order and count of fields of record being decoded. +// +// In practice you can pass pgtype.Value types as decoders, as +// most of them implement BinaryDecoder interface. // -// Values must implement BinaryDecoder interface otherwise error is returned. // ScanRowValue takes ownership of src, caller MUST not use it after call -func ScanRowValue(ci *ConnInfo, src []byte, dst ...Value) error { +func ScanRowValue(ci *ConnInfo, src []byte, dst ...BinaryDecoder) error { fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src) if err != nil { return err @@ -457,12 +458,7 @@ func ScanRowValue(ci *ConnInfo, src []byte, dst ...Value) error { return err } - binaryDecoder, ok := dst[i].(BinaryDecoder) - if !ok { - return errors.Errorf("record field doesn't implement binary decoding: %s", reflect.TypeOf(dst[i]).Name()) - } - - if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { + if err = dst[i].DecodeBinary(ci, fieldBytes); err != nil { return err } diff --git a/record_test.go b/record_test.go index af2105c7..9516612e 100644 --- a/record_test.go +++ b/record_test.go @@ -93,7 +93,10 @@ func TestScanRowValue(t *testing.T) { t.Fatal(err) } t.Run(tt.sql, func(t *testing.T) { - desc := append([]pgtype.Value(nil), tt.expected.Fields...) + desc := []pgtype.BinaryDecoder{} + for _, f := range tt.expected.Fields { + desc = append(desc, f.(pgtype.BinaryDecoder)) + } var raw pgtype.GenericBinary @@ -113,7 +116,10 @@ func TestScanRowValue(t *testing.T) { } // borrow fields from a neighbor test, this makes scan always fail - desc = append([]pgtype.Value(nil), recordTests[(i+1)%len(recordTests)].expected.Fields...) + desc = desc[:0] + for _, f := range recordTests[(i+1)%len(recordTests)].expected.Fields { + desc = append(desc, f.(pgtype.BinaryDecoder)) + } if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err == nil { t.Error("Matching scan didn't fail, despite fields not mathching query result") }