Add Array and FlatArray container types

non-blocking
Jack Christensen 2022-04-16 11:28:37 -05:00
parent d4abe83edb
commit f1a4ae3070
8 changed files with 214 additions and 153 deletions

View File

@ -7,7 +7,6 @@ rule '.go' => '.go.erb' do |task|
end
generated_code_files = [
"pgtype/array_getter_setter.go",
"pgtype/int.go",
"pgtype/int_test.go",
"pgtype/integration_benchmark_test.go",

View File

@ -394,3 +394,88 @@ func findDimensionsFromValue(value reflect.Value, dimensions []ArrayDimension, e
}
return dimensions, elementsLength, true
}
// Array represents a PostgreSQL array for T. It implements the ArrayGetter and ArraySetter interfaces. It preserves
// PostgreSQL dimensions and custom lower bounds. Use FlatArray if these are not needed.
type Array[T any] struct {
Elements []T
Dims []ArrayDimension
Valid bool
}
func (a Array[T]) Dimensions() []ArrayDimension {
return a.Dims
}
func (a Array[T]) Index(i int) any {
return a.Elements[i]
}
func (a Array[T]) IndexType() any {
var el T
return el
}
func (a *Array[T]) SetDimensions(dimensions []ArrayDimension) error {
if dimensions == nil {
*a = Array[T]{}
return nil
}
elementCount := cardinality(dimensions)
*a = Array[T]{
Elements: make([]T, elementCount),
Dims: dimensions,
Valid: true,
}
return nil
}
func (a Array[T]) ScanIndex(i int) any {
return &a.Elements[i]
}
func (a Array[T]) ScanIndexType() any {
return new(T)
}
// FlatArray implements the ArrayGetter and ArraySetter interfaces for any slice of T. It ignores PostgreSQL dimensions
// and custom lower bounds. Use Array to preserve these.
type FlatArray[T any] []T
func (a FlatArray[T]) Dimensions() []ArrayDimension {
if a == nil {
return nil
}
return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}}
}
func (a FlatArray[T]) Index(i int) any {
return a[i]
}
func (a FlatArray[T]) IndexType() any {
var el T
return el
}
func (a *FlatArray[T]) SetDimensions(dimensions []ArrayDimension) error {
if dimensions == nil {
a = nil
return nil
}
elementCount := cardinality(dimensions)
*a = make(FlatArray[T], elementCount)
return nil
}
func (a FlatArray[T]) ScanIndex(i int) any {
return &a[i]
}
func (a FlatArray[T]) ScanIndexType() any {
return new(T)
}

View File

@ -23,9 +23,9 @@ type ArrayGetter interface {
// ArraySetter is a type can be set from a PostgreSQL array.
type ArraySetter interface {
// SetDimensions prepares the value such that ScanIndex can be called for each element. dimensions may be nil to
// indicate a NULL array. If unable to exactly preserve dimensions SetDimensions may return an error or silently
// flatten the array dimensions.
// SetDimensions prepares the value such that ScanIndex can be called for each element. This will remove any existing
// elements. dimensions may be nil to indicate a NULL array. If unable to exactly preserve dimensions SetDimensions
// may return an error or silently flatten the array dimensions.
SetDimensions(dimensions []ArrayDimension) error
// ScanIndex returns a value usable as a scan target for i. SetDimensions must be called before ScanIndex.

View File

@ -5,6 +5,7 @@ import (
"testing"
pgx "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -47,6 +48,53 @@ func TestArrayCodec(t *testing.T) {
})
}
func TestArrayCodecFlatArray(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
for i, tt := range []struct {
expected any
}{
{pgtype.FlatArray[int32](nil)},
{pgtype.FlatArray[int32]{}},
{pgtype.FlatArray[int32]{1, 2, 3}},
} {
var actual pgtype.FlatArray[int32]
err := conn.QueryRow(
ctx,
"select $1::int[]",
tt.expected,
).Scan(&actual)
assert.NoErrorf(t, err, "%d", i)
assert.Equalf(t, tt.expected, actual, "%d", i)
}
})
}
func TestArrayCodecArray(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
for i, tt := range []struct {
expected any
}{
{pgtype.Array[int32]{
Elements: []int32{1, 2, 3, 4},
Dims: []pgtype.ArrayDimension{
{Length: 2, LowerBound: 2},
{Length: 2, LowerBound: 2},
},
Valid: true,
}},
} {
var actual pgtype.Array[int32]
err := conn.QueryRow(
ctx,
"select $1::int[]",
tt.expected,
).Scan(&actual)
assert.NoErrorf(t, err, "%d", i)
assert.Equalf(t, tt.expected, actual, "%d", i)
}
})
}
func TestArrayCodecAnySlice(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
type _int16Slice []int16

View File

@ -1,78 +0,0 @@
// Do not edit. Generated from pgtype/array_getter_setter.go.erb
package pgtype
type int16Array []int16
func (a int16Array) Dimensions() []ArrayDimension {
if a == nil {
return nil
}
return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}}
}
func (a int16Array) Index(i int) any {
return a[i]
}
func (a int16Array) IndexType() any {
var el int16
return el
}
func (a *int16Array) SetDimensions(dimensions []ArrayDimension) error {
if dimensions == nil {
a = nil
return nil
}
elementCount := cardinality(dimensions)
*a = make(int16Array, elementCount)
return nil
}
func (a int16Array) ScanIndex(i int) any {
return &a[i]
}
func (a int16Array) ScanIndexType() any {
return new(int16)
}
type uint16Array []uint16
func (a uint16Array) Dimensions() []ArrayDimension {
if a == nil {
return nil
}
return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}}
}
func (a uint16Array) Index(i int) any {
return a[i]
}
func (a uint16Array) IndexType() any {
var el uint16
return el
}
func (a *uint16Array) SetDimensions(dimensions []ArrayDimension) error {
if dimensions == nil {
a = nil
return nil
}
elementCount := cardinality(dimensions)
*a = make(uint16Array, elementCount)
return nil
}
func (a uint16Array) ScanIndex(i int) any {
return &a[i]
}
func (a uint16Array) ScanIndexType() any {
return new(uint16)
}

View File

@ -1,53 +0,0 @@
package pgtype
import (
"fmt"
"reflect"
)
<%
types = [
["int16Array", "int16"],
["uint16Array", "uint16"],
]
%>
<% types.each do |array_type, element_type| %>
type <%= array_type %> []<%= element_type %>
func (a <%= array_type %>) Dimensions() []ArrayDimension {
if a == nil {
return nil
}
return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}}
}
func (a <%= array_type %>) Index(i int) any {
return a[i]
}
func (a <%= array_type %>) IndexType() any {
var el <%= element_type %>
return el
}
func (a *<%= array_type %>) SetDimensions(dimensions []ArrayDimension) error {
if dimensions == nil {
a = nil
return nil
}
elementCount := cardinality(dimensions)
*a = make(<%= array_type %>, elementCount)
return nil
}
func (a <%= array_type %>) ScanIndex(i int) any {
return &a[i]
}
func (a <%= array_type %>) ScanIndexType() any {
return new(<%= element_type %>)
}
<% end %>

View File

@ -637,11 +637,11 @@ func (w *ptrStructWrapper) ScanIndex(i int) any {
return w.exportedFields[i].Addr().Interface()
}
type anySliceArray struct {
type anySliceArrayReflect struct {
slice reflect.Value
}
func (a anySliceArray) Dimensions() []ArrayDimension {
func (a anySliceArrayReflect) Dimensions() []ArrayDimension {
if a.slice.IsNil() {
return nil
}
@ -649,15 +649,15 @@ func (a anySliceArray) Dimensions() []ArrayDimension {
return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}}
}
func (a anySliceArray) Index(i int) any {
func (a anySliceArrayReflect) Index(i int) any {
return a.slice.Index(i).Interface()
}
func (a anySliceArray) IndexType() any {
func (a anySliceArrayReflect) IndexType() any {
return reflect.New(a.slice.Type().Elem()).Elem().Interface()
}
func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error {
func (a *anySliceArrayReflect) SetDimensions(dimensions []ArrayDimension) error {
sliceType := a.slice.Type()
if dimensions == nil {
@ -671,11 +671,11 @@ func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error {
return nil
}
func (a *anySliceArray) ScanIndex(i int) any {
func (a *anySliceArrayReflect) ScanIndex(i int) any {
return a.slice.Index(i).Addr().Interface()
}
func (a *anySliceArray) ScanIndexType() any {
func (a *anySliceArrayReflect) ScanIndexType() any {
return reflect.New(a.slice.Type().Elem()).Interface()
}

View File

@ -993,6 +993,24 @@ func (plan *wrapAnyPtrStructScanPlan) Scan(src []byte, target any) error {
// TryWrapPtrSliceScanPlan tries to wrap a pointer to a single dimension slice.
func TryWrapPtrSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) {
// Avoid using reflect path for common types.
switch target := target.(type) {
case *[]int16:
return &wrapPtrSliceScanPlan[int16]{}, (*FlatArray[int16])(target), true
case *[]int32:
return &wrapPtrSliceScanPlan[int32]{}, (*FlatArray[int32])(target), true
case *[]int64:
return &wrapPtrSliceScanPlan[int64]{}, (*FlatArray[int64])(target), true
case *[]float32:
return &wrapPtrSliceScanPlan[float32]{}, (*FlatArray[float32])(target), true
case *[]float64:
return &wrapPtrSliceScanPlan[float64]{}, (*FlatArray[float64])(target), true
case *[]string:
return &wrapPtrSliceScanPlan[string]{}, (*FlatArray[string])(target), true
case *[]time.Time:
return &wrapPtrSliceScanPlan[time.Time]{}, (*FlatArray[time.Time])(target), true
}
targetValue := reflect.ValueOf(target)
if targetValue.Kind() != reflect.Ptr {
return nil, nil, false
@ -1001,19 +1019,29 @@ func TryWrapPtrSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextVa
targetElemValue := targetValue.Elem()
if targetElemValue.Kind() == reflect.Slice {
return &wrapPtrSliceScanPlan{}, &anySliceArray{slice: targetElemValue}, true
return &wrapPtrSliceReflectScanPlan{}, &anySliceArrayReflect{slice: targetElemValue}, true
}
return nil, nil, false
}
type wrapPtrSliceScanPlan struct {
type wrapPtrSliceScanPlan[T any] struct {
next ScanPlan
}
func (plan *wrapPtrSliceScanPlan) SetNext(next ScanPlan) { plan.next = next }
func (plan *wrapPtrSliceScanPlan[T]) SetNext(next ScanPlan) { plan.next = next }
func (plan *wrapPtrSliceScanPlan) Scan(src []byte, target any) error {
return plan.next.Scan(src, &anySliceArray{slice: reflect.ValueOf(target).Elem()})
func (plan *wrapPtrSliceScanPlan[T]) Scan(src []byte, target any) error {
return plan.next.Scan(src, (*FlatArray[T])(target.(*[]T)))
}
type wrapPtrSliceReflectScanPlan struct {
next ScanPlan
}
func (plan *wrapPtrSliceReflectScanPlan) SetNext(next ScanPlan) { plan.next = next }
func (plan *wrapPtrSliceReflectScanPlan) Scan(src []byte, target any) error {
return plan.next.Scan(src, &anySliceArrayReflect{slice: reflect.ValueOf(target).Elem()})
}
// TryWrapPtrMultiDimSliceScanPlan tries to wrap a pointer to a multi-dimension slice.
@ -1660,24 +1688,56 @@ func getExportedFieldValues(structValue reflect.Value) []reflect.Value {
}
func TryWrapSliceEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) {
// Avoid using reflect path for common types.
switch value := value.(type) {
case []int16:
return &wrapSliceEncodePlan[int16]{}, (FlatArray[int16])(value), true
case []int32:
return &wrapSliceEncodePlan[int32]{}, (FlatArray[int32])(value), true
case []int64:
return &wrapSliceEncodePlan[int64]{}, (FlatArray[int64])(value), true
case []float32:
return &wrapSliceEncodePlan[float32]{}, (FlatArray[float32])(value), true
case []float64:
return &wrapSliceEncodePlan[float64]{}, (FlatArray[float64])(value), true
case []string:
return &wrapSliceEncodePlan[string]{}, (FlatArray[string])(value), true
case []time.Time:
return &wrapSliceEncodePlan[time.Time]{}, (FlatArray[time.Time])(value), true
}
if reflect.TypeOf(value).Kind() == reflect.Slice {
w := anySliceArray{
w := anySliceArrayReflect{
slice: reflect.ValueOf(value),
}
return &wrapSliceEncodePlan{}, w, true
return &wrapSliceEncodeReflectPlan{}, w, true
}
return nil, nil, false
}
type wrapSliceEncodePlan struct {
type wrapSliceEncodePlan[T any] struct {
next EncodePlan
}
func (plan *wrapSliceEncodePlan) SetNext(next EncodePlan) { plan.next = next }
func (plan *wrapSliceEncodePlan[T]) SetNext(next EncodePlan) { plan.next = next }
func (plan *wrapSliceEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) {
w := anySliceArray{
func (plan *wrapSliceEncodePlan[T]) Encode(value any, buf []byte) (newBuf []byte, err error) {
w := anySliceArrayReflect{
slice: reflect.ValueOf(value),
}
return plan.next.Encode(w, buf)
}
type wrapSliceEncodeReflectPlan struct {
next EncodePlan
}
func (plan *wrapSliceEncodeReflectPlan) SetNext(next EncodePlan) { plan.next = next }
func (plan *wrapSliceEncodeReflectPlan) Encode(value any, buf []byte) (newBuf []byte, err error) {
w := anySliceArrayReflect{
slice: reflect.ValueOf(value),
}