Fix scan array of record to pointer to slice of struct

https://github.com/jackc/pgx/issues/1570
pull/1571/head
Jack Christensen 2023-04-08 14:39:48 -05:00
parent f72a147db3
commit 847f888631
2 changed files with 26 additions and 5 deletions

View File

@ -1089,15 +1089,16 @@ func TryWrapPtrSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextVa
return &wrapPtrSliceScanPlan[time.Time]{}, (*FlatArray[time.Time])(target), true
}
targetValue := reflect.ValueOf(target)
if targetValue.Kind() != reflect.Ptr {
targetType := reflect.TypeOf(target)
if targetType.Kind() != reflect.Ptr {
return nil, nil, false
}
targetElemValue := targetValue.Elem()
targetElemType := targetType.Elem()
if targetElemValue.Kind() == reflect.Slice {
return &wrapPtrSliceReflectScanPlan{}, &anySliceArrayReflect{slice: targetElemValue}, true
if targetElemType.Kind() == reflect.Slice {
slice := reflect.New(targetElemType).Elem()
return &wrapPtrSliceReflectScanPlan{}, &anySliceArrayReflect{slice: slice}, true
}
return nil, nil, false
}
@ -1198,6 +1199,10 @@ func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan {
}
func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan {
if target == nil {
return &scanPlanFail{m: m, oid: oid, formatCode: formatCode}
}
if _, ok := target.(*UndecodedBytes); ok {
return scanPlanAnyToUndecodedBytes{}
}

View File

@ -330,6 +330,22 @@ func TestMapScanPtrToPtrToSlice(t *testing.T) {
require.Equal(t, []string{"foo", "bar"}, *v)
}
func TestMapScanPtrToPtrToSliceOfStruct(t *testing.T) {
type Team struct {
TeamID int
Name string
}
// Have to use binary format because text format doesn't include type information.
m := pgtype.NewMap()
src := []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x8, 0xc9, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x1e, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x17, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x19, 0x0, 0x0, 0x0, 0x6, 0x74, 0x65, 0x61, 0x6d, 0x20, 0x31, 0x0, 0x0, 0x0, 0x1e, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x17, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x19, 0x0, 0x0, 0x0, 0x6, 0x74, 0x65, 0x61, 0x6d, 0x20, 0x32}
var v *[]Team
plan := m.PlanScan(pgtype.RecordArrayOID, pgtype.BinaryFormatCode, &v)
err := plan.Scan(src, &v)
require.NoError(t, err)
require.Equal(t, []Team{{1, "team 1"}, {2, "team 2"}}, *v)
}
type databaseValuerString string
func (s databaseValuerString) Value() (driver.Value, error) {