From ccb207cba5b4dd520b9cc40994eeb161ce140737 Mon Sep 17 00:00:00 2001 From: WGH Date: Mon, 28 Mar 2022 05:06:16 +0300 Subject: [PATCH] Add support for record array Like Record itself, it only implements BinaryDecoder, doesn't implement BinaryEncoder, and has no support for the text protocol. --- record_array.go | 318 +++++++++++++++++++++++++++++++++++++++++++ record_array_test.go | 104 ++++++++++++++ typed_array_gen.sh | 2 + 3 files changed, 424 insertions(+) create mode 100644 record_array.go create mode 100644 record_array_test.go diff --git a/record_array.go b/record_array.go new file mode 100644 index 00000000..2271717a --- /dev/null +++ b/record_array.go @@ -0,0 +1,318 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "encoding/binary" + "fmt" + "reflect" +) + +type RecordArray struct { + Elements []Record + Dimensions []ArrayDimension + Status Status +} + +func (dst *RecordArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = RecordArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case [][]Value: + if value == nil { + *dst = RecordArray{Status: Null} + } else if len(value) == 0 { + *dst = RecordArray{Status: Present} + } else { + elements := make([]Record, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = RecordArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Record: + if value == nil { + *dst = RecordArray{Status: Null} + } else if len(value) == 0 { + *dst = RecordArray{Status: Present} + } else { + *dst = RecordArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = RecordArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for RecordArray", src) + } + if elementsLength == 0 { + *dst = RecordArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to RecordArray", src) + } + + *dst = RecordArray{ + Elements: make([]Record, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Record, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to RecordArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *RecordArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to RecordArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in RecordArray", err) + } + index++ + + return index, nil +} + +func (dst RecordArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *RecordArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[][]Value: + *v = make([][]Value, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *RecordArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from RecordArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from RecordArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *RecordArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = RecordArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = RecordArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Record, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = RecordArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} diff --git a/record_array_test.go b/record_array_test.go new file mode 100644 index 00000000..9c92e333 --- /dev/null +++ b/record_array_test.go @@ -0,0 +1,104 @@ +package pgtype_test + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" +) + +var recordArrayTests = []struct { + sql string + expected pgtype.RecordArray +}{ + { + sql: `select array_agg((x::int4, x+100::int8)) from generate_series(0, 1) x;`, + expected: pgtype.RecordArray{ + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + }, + Elements: []pgtype.Record{ + { + Fields: []pgtype.Value{ + &pgtype.Int4{Int: 0, Status: pgtype.Present}, + &pgtype.Int8{Int: 100, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + { + Fields: []pgtype.Value{ + &pgtype.Int4{Int: 1, Status: pgtype.Present}, + &pgtype.Int8{Int: 101, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + Status: pgtype.Present, + }, + }, +} + +func TestRecordArrayTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + for i, tt := range recordArrayTests { + psName := fmt.Sprintf("test%d", i) + _, err := conn.Prepare(context.Background(), psName, tt.sql) + require.NoError(t, err) + + t.Run(tt.sql, func(t *testing.T) { + var result pgtype.RecordArray + err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result) + require.NoError(t, err) + + require.Equal(t, tt.expected, result) + }) + + } +} + +func TestRecordArrayAssignTo(t *testing.T) { + src := pgtype.RecordArray{ + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + }, + Elements: []pgtype.Record{ + { + Fields: []pgtype.Value{ + &pgtype.Int4{Int: 0, Status: pgtype.Present}, + &pgtype.Int8{Int: 100, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + { + Fields: []pgtype.Value{ + &pgtype.Int4{Int: 1, Status: pgtype.Present}, + &pgtype.Int8{Int: 101, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + Status: pgtype.Present, + } + dst := [][]pgtype.Value{} + err := src.AssignTo(&dst) + require.NoError(t, err) + + expected := [][]pgtype.Value{ + { + &pgtype.Int4{Int: 0, Status: pgtype.Present}, + &pgtype.Int8{Int: 100, Status: pgtype.Present}, + }, + { + &pgtype.Int4{Int: 1, Status: pgtype.Present}, + &pgtype.Int8{Int: 101, Status: pgtype.Present}, + }, + } + require.Equal(t, expected, dst) +} diff --git a/typed_array_gen.sh b/typed_array_gen.sh index a9090cd9..d922f1cb 100755 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -25,4 +25,6 @@ erb pgtype_array_type=JSONBArray pgtype_element_type=JSONB go_array_types=[]stri # While the binary format is theoretically possible it is only practical to use the text format. erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string,[]*string binary_format=false typed_array.go.erb > enum_array.go +erb pgtype_array_type=RecordArray pgtype_element_type=Record go_array_types=[][]Value element_type_name=record text_null=NULL encode_binary=false text_format=false typed_array.go.erb > record_array.go + goimports -w *_array.go