Restore multi-dimensional slices

Move ArrayCode to use pgtype wrapper pattern as well
query-exec-mode
Jack Christensen 2022-02-08 10:07:40 -06:00
parent 318018504a
commit 7193e48923
7 changed files with 424 additions and 173 deletions

View File

@ -7,6 +7,7 @@ rule '.go' => '.go.erb' do |task|
end
generated_code_files = [
"pgtype/array_getter_setter.go",
"pgtype/int.go",
"pgtype/int_test.go",
"pgtype/integration_benchmark_test.go",

View File

@ -16,6 +16,9 @@ type ArrayGetter interface {
// Index returns the element at i.
Index(i int) interface{}
// IndexType returns a non-nil scan target of the type Index will return. This is used by ArrayCodec.PlanEncode.
IndexType() interface{}
}
// ArraySetter is a type can be set from a PostgreSQL array.
@ -27,6 +30,10 @@ type ArraySetter interface {
// ScanIndex returns a value usable as a scan target for i. SetDimensions must be called before ScanIndex.
ScanIndex(i int) interface{}
// ScanIndexType returns a non-nil scan target of the type ScanIndex will return. This is used by
// ArrayCodec.PlanScan.
ScanIndexType() interface{}
}
// ArrayCodec is a codec for any array type.
@ -43,6 +50,18 @@ func (c *ArrayCodec) PreferredFormat() int16 {
}
func (c *ArrayCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan {
arrayValuer, ok := value.(ArrayGetter)
if !ok {
return nil
}
elementType := arrayValuer.IndexType()
elementEncodePlan := ci.PlanEncode(c.ElementDataType.OID, format, elementType)
if elementEncodePlan == nil {
return nil
}
switch format {
case BinaryFormatCode:
return &encodePlanArrayCodecBinary{ac: c, ci: ci, oid: oid}
@ -60,10 +79,7 @@ type encodePlanArrayCodecText struct {
}
func (p *encodePlanArrayCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
array, err := makeArrayGetter(value)
if err != nil {
return nil, err
}
array := value.(ArrayGetter)
dimensions := array.Dimensions()
if dimensions == nil {
@ -142,10 +158,7 @@ type encodePlanArrayCodecBinary struct {
}
func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
array, err := makeArrayGetter(value)
if err != nil {
return nil, err
}
array := value.(ArrayGetter)
dimensions := array.Dimensions()
if dimensions == nil {
@ -198,8 +211,15 @@ func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newB
}
func (c *ArrayCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan {
_, err := makeArraySetter(target)
if err != nil {
arrayScanner, ok := target.(ArraySetter)
if !ok {
return nil
}
elementType := arrayScanner.ScanIndexType()
elementScanPlan := ci.PlanScan(c.ElementDataType.OID, format, elementType)
if _, ok := elementScanPlan.(*scanPlanFail); ok {
return nil
}
@ -300,10 +320,11 @@ func (c *ArrayCodec) decodeText(ci *ConnInfo, arrayOID uint32, src []byte, array
}
type scanPlanArrayCodec struct {
arrayCodec *ArrayCodec
ci *ConnInfo
oid uint32
formatCode int16
arrayCodec *ArrayCodec
ci *ConnInfo
oid uint32
formatCode int16
elementScanPlan ScanPlan
}
func (spac *scanPlanArrayCodec) Scan(src []byte, dst interface{}) error {
@ -312,11 +333,7 @@ func (spac *scanPlanArrayCodec) Scan(src []byte, dst interface{}) error {
oid := spac.oid
formatCode := spac.formatCode
array, err := makeArraySetter(dst)
if err != nil {
newPlan := ci.PlanScan(oid, formatCode, dst)
return newPlan.Scan(src, dst)
}
array := dst.(ArraySetter)
if src == nil {
return array.SetDimensions(nil)
@ -358,3 +375,26 @@ func (c *ArrayCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []b
err := ci.PlanScan(oid, format, &slice).Scan(src, &slice)
return slice, err
}
func isRagged(slice reflect.Value) bool {
if slice.Type().Elem().Kind() != reflect.Slice {
return false
}
sliceLen := slice.Len()
innerLen := 0
for i := 0; i < sliceLen; i++ {
if i == 0 {
innerLen = slice.Index(i).Len()
} else {
if slice.Index(i).Len() != innerLen {
return true
}
}
if isRagged(slice.Index(i)) {
return true
}
}
return false
}

View File

@ -108,3 +108,60 @@ func TestArrayCodecDecodeValue(t *testing.T) {
})
}
}
func TestArrayCodecScanMultipleDimensions(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
rows, err := conn.Query(context.Background(), `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`)
require.NoError(t, err)
for rows.Next() {
var ss [][]int32
err := rows.Scan(&ss)
require.NoError(t, err)
require.Equal(t, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, ss)
}
require.NoError(t, rows.Err())
}
func TestArrayCodecScanWrongMultipleDimensions(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
rows, err := conn.Query(context.Background(), `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`)
require.NoError(t, err)
for rows.Next() {
var ss [][][]int32
err := rows.Scan(&ss)
require.Error(t, err, "can't scan into dest[0]: PostgreSQL array has 2 dimensions but slice has 3 dimensions")
}
}
func TestArrayCodecEncodeMultipleDimensions(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
rows, err := conn.Query(context.Background(), `select $1::int4[]`, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}})
require.NoError(t, err)
for rows.Next() {
var ss [][]int32
err := rows.Scan(&ss)
require.NoError(t, err)
require.Equal(t, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, ss)
}
require.NoError(t, rows.Err())
}
func TestArrayCodecEncodeMultipleDimensionsRagged(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
rows, err := conn.Query(context.Background(), `select $1::int4[]`, [][]int32{{1, 2, 3, 4}, {5}, {9, 10, 11, 12}})
require.Error(t, err, "cannot convert [][]int32 to ArrayGetter because it is a ragged multi-dimensional")
defer rows.Close()
}

View File

@ -1,10 +1,6 @@
// Do not edit. Generated from pgtype/array_getter_setter.go.erb
package pgtype
import (
"fmt"
"reflect"
)
type int16Array []int16
func (a int16Array) Dimensions() []ArrayDimension {
@ -19,6 +15,11 @@ func (a int16Array) Index(i int) interface{} {
return a[i]
}
func (a int16Array) IndexType() interface{} {
var el int16
return el
}
func (a *int16Array) SetDimensions(dimensions []ArrayDimension) error {
if dimensions == nil {
a = nil
@ -34,6 +35,10 @@ func (a int16Array) ScanIndex(i int) interface{} {
return &a[i]
}
func (a int16Array) ScanIndexType() interface{} {
return new(int16)
}
type uint16Array []uint16
func (a uint16Array) Dimensions() []ArrayDimension {
@ -48,6 +53,11 @@ func (a uint16Array) Index(i int) interface{} {
return a[i]
}
func (a uint16Array) IndexType() interface{} {
var el uint16
return el
}
func (a *uint16Array) SetDimensions(dimensions []ArrayDimension) error {
if dimensions == nil {
a = nil
@ -63,81 +73,6 @@ func (a uint16Array) ScanIndex(i int) interface{} {
return &a[i]
}
type anySliceArray struct {
slice reflect.Value
}
func (a anySliceArray) Dimensions() []ArrayDimension {
if a.slice.IsNil() {
return nil
}
return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}}
}
func (a anySliceArray) Index(i int) interface{} {
return a.slice.Index(i).Interface()
}
func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error {
sliceType := a.slice.Type()
if dimensions == nil {
a.slice.Set(reflect.Zero(sliceType))
return nil
}
elementCount := cardinality(dimensions)
slice := reflect.MakeSlice(sliceType, elementCount, elementCount)
a.slice.Set(slice)
return nil
}
func (a anySliceArray) ScanIndex(i int) interface{} {
return a.slice.Index(i).Addr().Interface()
}
func makeArrayGetter(a interface{}) (ArrayGetter, error) {
switch a := a.(type) {
case ArrayGetter:
return a, nil
case []int16:
return (*int16Array)(&a), nil
case []uint16:
return (*uint16Array)(&a), nil
}
reflectValue := reflect.ValueOf(a)
if reflectValue.Kind() == reflect.Slice {
return &anySliceArray{slice: reflectValue}, nil
}
return nil, fmt.Errorf("cannot convert %T to ArrayGetter", a)
}
func makeArraySetter(a interface{}) (ArraySetter, error) {
switch a := a.(type) {
case ArraySetter:
return a, nil
case *[]int16:
return (*int16Array)(a), nil
case *[]uint16:
return (*uint16Array)(a), nil
}
value := reflect.ValueOf(a)
if value.Kind() == reflect.Ptr {
elemValue := value.Elem()
if elemValue.Kind() == reflect.Slice {
return &anySliceArray{slice: elemValue}, nil
}
}
return nil, fmt.Errorf("cannot convert %T to ArraySetter", a)
func (a uint16Array) ScanIndexType() interface{} {
return new(uint16)
}

View File

@ -27,6 +27,11 @@ import (
return a[i]
}
func (a <%= array_type %>) IndexType() interface{} {
var el <%= element_type %>
return el
}
func (a *<%= array_type %>) SetDimensions(dimensions []ArrayDimension) error {
if dimensions == nil {
a = nil
@ -41,77 +46,8 @@ import (
func (a <%= array_type %>) ScanIndex(i int) interface{} {
return &a[i]
}
<% end %>
type anySliceArray struct {
slice reflect.Value
}
func (a anySliceArray) Dimensions() []ArrayDimension {
if a.slice.IsNil() {
return nil
}
return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}}
}
func (a anySliceArray) Index(i int) interface{} {
return a.slice.Index(i).Interface()
}
func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error {
sliceType := a.slice.Type()
if dimensions == nil {
a.slice.Set(reflect.Zero(sliceType))
return nil
}
elementCount := cardinality(dimensions)
slice := reflect.MakeSlice(sliceType, elementCount, elementCount)
a.slice.Set(slice)
return nil
}
func (a anySliceArray) ScanIndex(i int) interface{} {
return a.slice.Index(i).Addr().Interface()
}
func makeArrayGetter(a interface{}) (ArrayGetter, error) {
switch a := a.(type) {
case ArrayGetter:
return a, nil
<% types.each do |array_type, element_type| %>
case []<%= element_type %>:
return (*<%= array_type %>)(&a), nil
<% end %>
}
reflectValue := reflect.ValueOf(a)
if reflectValue.Kind() == reflect.Slice {
return &anySliceArray{slice: reflectValue}, nil
func (a <%= array_type %>) ScanIndexType() interface{} {
return new(<%= element_type %>)
}
return nil, fmt.Errorf("cannot convert %T to ArrayGetter", a)
}
func makeArraySetter(a interface{}) (ArraySetter, error) {
switch a := a.(type) {
case ArraySetter:
return a, nil
<% types.each do |array_type, element_type| %>
case *[]<%= element_type %>:
return (*<%= array_type %>)(a), nil
<% end %>
}
value := reflect.ValueOf(a)
if value.Kind() == reflect.Ptr {
elemValue := value.Elem()
if elemValue.Kind() == reflect.Slice {
return &anySliceArray{slice: elemValue}, nil
}
}
return nil, fmt.Errorf("cannot convert %T to ArraySetter", a)
}
<% end %>

View File

@ -656,3 +656,168 @@ func (w *ptrStructWrapper) ScanIndex(i int) interface{} {
return w.exportedFields[i].Addr().Interface()
}
type anySliceArray struct {
slice reflect.Value
}
func (a anySliceArray) Dimensions() []ArrayDimension {
if a.slice.IsNil() {
return nil
}
return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}}
}
func (a anySliceArray) Index(i int) interface{} {
return a.slice.Index(i).Interface()
}
func (a anySliceArray) IndexType() interface{} {
return reflect.New(a.slice.Type().Elem()).Elem().Interface()
}
func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error {
sliceType := a.slice.Type()
if dimensions == nil {
a.slice.Set(reflect.Zero(sliceType))
return nil
}
elementCount := cardinality(dimensions)
slice := reflect.MakeSlice(sliceType, elementCount, elementCount)
a.slice.Set(slice)
return nil
}
func (a *anySliceArray) ScanIndex(i int) interface{} {
return a.slice.Index(i).Addr().Interface()
}
func (a *anySliceArray) ScanIndexType() interface{} {
return reflect.New(a.slice.Type().Elem()).Interface()
}
type anyMultiDimSliceArray struct {
slice reflect.Value
dims []ArrayDimension
}
func (a *anyMultiDimSliceArray) Dimensions() []ArrayDimension {
if a.slice.IsNil() {
return nil
}
s := a.slice
for {
a.dims = append(a.dims, ArrayDimension{Length: int32(s.Len()), LowerBound: 1})
if s.Len() > 0 {
s = s.Index(0)
} else {
break
}
if s.Type().Kind() == reflect.Slice {
} else {
break
}
}
return a.dims
}
func (a *anyMultiDimSliceArray) Index(i int) interface{} {
if len(a.dims) == 1 {
return a.slice.Index(i).Interface()
}
indexes := make([]int, len(a.dims))
for j := len(a.dims) - 1; j >= 0; j-- {
dimLen := int(a.dims[j].Length)
indexes[j] = i % dimLen
i = i / dimLen
}
v := a.slice
for _, si := range indexes {
v = v.Index(si)
}
return v.Interface()
}
func (a *anyMultiDimSliceArray) IndexType() interface{} {
lowestSliceType := a.slice.Type()
for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() {
}
return reflect.New(lowestSliceType.Elem()).Elem().Interface()
}
func (a *anyMultiDimSliceArray) SetDimensions(dimensions []ArrayDimension) error {
sliceType := a.slice.Type()
if dimensions == nil {
a.slice.Set(reflect.Zero(sliceType))
return nil
}
switch len(dimensions) {
case 0:
return fmt.Errorf("impossible: non-nil dimensions but zero elements")
case 1:
elementCount := cardinality(dimensions)
slice := reflect.MakeSlice(sliceType, elementCount, elementCount)
a.slice.Set(slice)
return nil
default:
sliceDimensionCount := 1
lowestSliceType := sliceType
for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() {
sliceDimensionCount++
}
if sliceDimensionCount != len(dimensions) {
return fmt.Errorf("PostgreSQL array has %d dimensions but slice has %d dimensions", len(dimensions), sliceDimensionCount)
}
elementCount := cardinality(dimensions)
flatSlice := reflect.MakeSlice(lowestSliceType, elementCount, elementCount)
multiDimSlice := a.makeMultidimensionalSlice(sliceType, dimensions, flatSlice, 0)
a.slice.Set(multiDimSlice)
// Now that a.slice is a multi-dimensional slice with the underlying data pointed at flatSlice change a.slice to
// flatSlice so ScanIndex only has to handle simple one dimensional slices.
a.slice = flatSlice
return nil
}
}
func (a *anyMultiDimSliceArray) makeMultidimensionalSlice(sliceType reflect.Type, dimensions []ArrayDimension, flatSlice reflect.Value, flatSliceIdx int) reflect.Value {
if len(dimensions) == 1 {
endIdx := flatSliceIdx + int(dimensions[0].Length)
return flatSlice.Slice3(flatSliceIdx, endIdx, endIdx)
}
sliceLen := int(dimensions[0].Length)
slice := reflect.MakeSlice(sliceType, sliceLen, sliceLen)
for i := 0; i < sliceLen; i++ {
subSlice := a.makeMultidimensionalSlice(sliceType.Elem(), dimensions[1:], flatSlice, flatSliceIdx+(i*int(dimensions[1].Length)))
slice.Index(i).Set(subSlice)
}
return slice
}
func (a *anyMultiDimSliceArray) ScanIndex(i int) interface{} {
return a.slice.Index(i).Addr().Interface()
}
func (a *anyMultiDimSliceArray) ScanIndexType() interface{} {
lowestSliceType := a.slice.Type()
for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() {
}
return reflect.New(lowestSliceType.Elem()).Interface()
}

View File

@ -204,6 +204,8 @@ func NewConnInfo() *ConnInfo {
TryWrapBuiltinTypeEncodePlan,
TryWrapFindUnderlyingTypeEncodePlan,
TryWrapStructEncodePlan,
TryWrapSliceEncodePlan,
TryWrapMultiDimSliceEncodePlan,
},
TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{
@ -211,6 +213,8 @@ func NewConnInfo() *ConnInfo {
TryWrapBuiltinTypeScanPlan,
TryFindUnderlyingTypeScanPlan,
TryWrapStructScanPlan,
TryWrapPtrSliceScanPlan,
TryWrapPtrMultiDimSliceScanPlan,
},
}
@ -930,6 +934,62 @@ func (plan *wrapAnyPtrStructScanPlan) Scan(src []byte, target interface{}) error
return plan.next.Scan(src, &w)
}
// TryWrapPtrSliceScanPlan tries to wrap a pointer to a single dimension slice.
func TryWrapPtrSliceScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextValue interface{}, ok bool) {
targetValue := reflect.ValueOf(target)
if targetValue.Kind() != reflect.Ptr {
return nil, nil, false
}
targetElemValue := targetValue.Elem()
if targetElemValue.Kind() == reflect.Slice {
return &wrapPtrSliceScanPlan{}, &anySliceArray{slice: targetElemValue}, true
}
return nil, nil, false
}
type wrapPtrSliceScanPlan struct {
next ScanPlan
}
func (plan *wrapPtrSliceScanPlan) SetNext(next ScanPlan) { plan.next = next }
func (plan *wrapPtrSliceScanPlan) Scan(src []byte, target interface{}) error {
return plan.next.Scan(src, &anySliceArray{slice: reflect.ValueOf(target).Elem()})
}
// TryWrapPtrMultiDimSliceScanPlan tries to wrap a pointer to a multi-dimension slice.
func TryWrapPtrMultiDimSliceScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextValue interface{}, ok bool) {
targetValue := reflect.ValueOf(target)
if targetValue.Kind() != reflect.Ptr {
return nil, nil, false
}
targetElemValue := targetValue.Elem()
if targetElemValue.Kind() == reflect.Slice {
elemElemKind := targetElemValue.Type().Elem().Kind()
if elemElemKind == reflect.Slice {
if !isRagged(targetElemValue) {
return &wrapPtrMultiDimSliceScanPlan{}, &anyMultiDimSliceArray{slice: targetValue.Elem()}, true
}
}
}
return nil, nil, false
}
type wrapPtrMultiDimSliceScanPlan struct {
next ScanPlan
}
func (plan *wrapPtrMultiDimSliceScanPlan) SetNext(next ScanPlan) { plan.next = next }
func (plan *wrapPtrMultiDimSliceScanPlan) Scan(src []byte, target interface{}) error {
return plan.next.Scan(src, &anyMultiDimSliceArray{slice: reflect.ValueOf(target).Elem()})
}
// PlanScan prepares a plan to scan a value into target.
func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, target interface{}) ScanPlan {
if _, ok := target.(*UndecodedBytes); ok {
@ -1495,6 +1555,63 @@ func getExportedFieldValues(structValue reflect.Value) []reflect.Value {
return exportedFields
}
func TryWrapSliceEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) {
if reflect.TypeOf(value).Kind() == reflect.Slice {
w := anySliceArray{
slice: reflect.ValueOf(value),
}
return &wrapSliceEncodePlan{}, w, true
}
return nil, nil, false
}
type wrapSliceEncodePlan struct {
next EncodePlan
}
func (plan *wrapSliceEncodePlan) SetNext(next EncodePlan) { plan.next = next }
func (plan *wrapSliceEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
w := anySliceArray{
slice: reflect.ValueOf(value),
}
return plan.next.Encode(w, buf)
}
func TryWrapMultiDimSliceEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) {
sliceValue := reflect.ValueOf(value)
if sliceValue.Kind() == reflect.Slice {
valueElemType := sliceValue.Type().Elem()
if valueElemType.Kind() == reflect.Slice {
if !isRagged(sliceValue) {
w := anyMultiDimSliceArray{
slice: reflect.ValueOf(value),
}
return &wrapMultiDimSliceEncodePlan{}, &w, true
}
}
}
return nil, nil, false
}
type wrapMultiDimSliceEncodePlan struct {
next EncodePlan
}
func (plan *wrapMultiDimSliceEncodePlan) SetNext(next EncodePlan) { plan.next = next }
func (plan *wrapMultiDimSliceEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
w := anyMultiDimSliceArray{
slice: reflect.ValueOf(value),
}
return plan.next.Encode(&w, buf)
}
// Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return
// (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data
// written.