diff --git a/pgtype/convert.go b/pgtype/convert.go index 3f3d9e5f..e35e2310 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -122,6 +122,28 @@ func underlyingSliceType(val interface{}) (interface{}, bool) { return nil, false } +func underlyingPtrSliceType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + if refVal.Kind() != reflect.Ptr { + return nil, false + } + if refVal.IsNil() { + return nil, false + } + + sliceVal := refVal.Elem().Interface() + baseSliceType := reflect.SliceOf(reflect.TypeOf(sliceVal).Elem()) + ptrBaseSliceType := reflect.PtrTo(baseSliceType) + + if refVal.Type().ConvertibleTo(ptrBaseSliceType) { + convVal := refVal.Convert(ptrBaseSliceType) + return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() + } + + return nil, false +} + func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { if srcStatus == Present { switch v := dst.(type) { diff --git a/pgtype/int2array.go b/pgtype/int2array.go index 4ac0c409..e6809c1e 100644 --- a/pgtype/int2array.go +++ b/pgtype/int2array.go @@ -89,6 +89,9 @@ func (src *Int2Array) AssignTo(dst interface{}) error { *v = nil } default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } return fmt.Errorf("cannot put decode %v into %T", src, dst) } diff --git a/pgtype/int2array_test.go b/pgtype/int2array_test.go index 5ea81990..ced0eab4 100644 --- a/pgtype/int2array_test.go +++ b/pgtype/int2array_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "reflect" "testing" "github.com/jackc/pgx/pgtype" @@ -50,38 +51,126 @@ func TestInt2ArrayTranscode(t *testing.T) { }) } -// func TestInt2ConvertFrom(t *testing.T) { -// type _int8 int8 +func TestInt2ArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int2Array + }{ + { + source: []int16{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint16{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]int16)(nil)), + result: pgtype.Int2Array{Status: pgtype.Null}, + }, + } -// successfulTests := []struct { -// source interface{} -// result pgtype.Int2 -// }{ -// {source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, -// {source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, -// {source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, -// {source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, -// {source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// } + for i, tt := range successfulTests { + var r pgtype.Int2Array + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } -// for i, tt := range successfulTests { -// var r pgtype.Int2 -// err := r.ConvertFrom(tt.source) -// if err != nil { -// t.Errorf("%d: %v", i, err) -// } + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} -// if r != tt.result { -// t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) -// } -// } -// } +func TestInt2ArrayAssignTo(t *testing.T) { + var int16Slice []int16 + var uint16Slice []uint16 + var namedInt16Slice _int16Slice + + simpleTests := []struct { + src pgtype.Int2Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int16Slice, + expected: []int16{1}, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint16Slice, + expected: []uint16{1}, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedInt16Slice, + expected: _int16Slice{1}, + }, + { + src: pgtype.Int2Array{Status: pgtype.Null}, + dst: &int16Slice, + expected: (([]int16)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int2Array + dst interface{} + }{ + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int16Slice, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: -1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint16Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 32ebebfe..a727e2e5 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -15,6 +15,7 @@ import ( type _bool bool type _int8 int8 type _int16 int16 +type _int16Slice []int16 func mustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL"))