diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go index f9faab20..e8386ae7 100644 --- a/pgtype/aclitem.go +++ b/pgtype/aclitem.go @@ -3,7 +3,6 @@ package pgtype import ( "fmt" "io" - "reflect" ) // Aclitem is used for PostgreSQL's aclitem data type. A sample aclitem @@ -55,39 +54,22 @@ func (dst *Aclitem) Get() interface{} { } func (src *Aclitem) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *string: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.String - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) - case reflect.String: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - el.SetString(src.String) - return nil + switch src.Status { + case Present: + switch v := dst.(type) { + case *string: + *v = src.String + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Aclitem) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go index f02d339e..1c97e74f 100644 --- a/pgtype/aclitem_array.go +++ b/pgtype/aclitem_array.go @@ -58,28 +58,29 @@ func (dst *AclitemArray) Get() interface{} { } func (src *AclitemArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]string: - if src.Status == Present { + case *[]string: *v = make([]string, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *AclitemArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/bool.go b/pgtype/bool.go index 87316381..608a6f95 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -3,7 +3,6 @@ package pgtype import ( "fmt" "io" - "reflect" "strconv" ) @@ -44,39 +43,22 @@ func (dst *Bool) Get() interface{} { } func (src *Bool) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *bool: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Bool - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) - case reflect.Bool: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - el.SetBool(src.Bool) - return nil + switch src.Status { + case Present: + switch v := dst.(type) { + case *bool: + *v = src.Bool + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index 6adfbb00..cdfe9685 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -59,28 +59,29 @@ func (dst *BoolArray) Get() interface{} { } func (src *BoolArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]bool: - if src.Status == Present { + case *[]bool: *v = make([]bool, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/bytea.go b/pgtype/bytea.go index dc1e9c07..00bed8e8 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -4,7 +4,6 @@ import ( "encoding/hex" "fmt" "io" - "reflect" ) type Bytea struct { @@ -42,38 +41,24 @@ func (dst *Bytea) Get() interface{} { } func (src *Bytea) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *[]byte: - if src.Status == Present { - *v = src.Bytes - } else { - *v = nil - } - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) - } + switch src.Status { + case Present: + switch v := dst.(type) { + case *[]byte: + buf := make([]byte, len(src.Bytes)) + copy(buf, src.Bytes) + *v = buf + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } // DecodeText only supports the hex format. This has been the default since diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index d318fa3b..175ca2f6 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -59,28 +59,29 @@ func (dst *ByteaArray) Get() interface{} { } func (src *ByteaArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[][]byte: - if src.Status == Present { + case *[][]byte: *v = make([][]byte, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index 3ab83ecd..49a2728b 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -79,40 +79,38 @@ func (dst *CidrArray) Get() interface{} { } func (src *CidrArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]*net.IPNet: - if src.Status == Present { + case *[]*net.IPNet: *v = make([]*net.IPNet, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - case *[]net.IP: - if src.Status == Present { + case *[]net.IP: *v = make([]net.IP, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *CidrArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/convert.go b/pgtype/convert.go index 648209f5..4fba8430 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -184,28 +184,6 @@ func underlyingSliceType(val interface{}) (interface{}, bool) { return nil, false } -func underlyingPtrSliceType(val interface{}) (interface{}, bool) { - refVal := reflect.ValueOf(val) - - if refVal.Kind() != reflect.Ptr { - return nil, false - } - if refVal.IsNil() { - return nil, false - } - - sliceVal := refVal.Elem().Interface() - baseSliceType := reflect.SliceOf(reflect.TypeOf(sliceVal).Elem()) - ptrBaseSliceType := reflect.PtrTo(baseSliceType) - - if refVal.Type().ConvertibleTo(ptrBaseSliceType) { - convVal := refVal.Convert(ptrBaseSliceType) - return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() - } - - return nil, false -} - func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { if srcStatus == Present { switch v := dst.(type) { @@ -363,3 +341,83 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) } + +func nullAssignTo(dst interface{}) error { + dstPtr := reflect.ValueOf(dst) + + // AssignTo dst must always be a pointer + if dstPtr.Kind() != reflect.Ptr { + return fmt.Errorf("cannot assign NULL to %T", dst) + } + + dstVal := dstPtr.Elem() + + switch dstVal.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map: + dstVal.Set(reflect.Zero(dstVal.Type())) + return nil + } + + return fmt.Errorf("cannot assign NULL to %T", dst) +} + +var kindTypes map[reflect.Kind]reflect.Type + +// GetAssignToDstType attempts to convert dst to something AssignTo can assign +// to. If dst is a pointer to pointer it allocates a value and returns the +// dereferences pointer. If dst is a named type such as *Foo where Foo is type +// Foo int16, it converts dst to *int16. +// +// GetAssignToDstType returns the converted dst and a bool representing if any +// change was made. +func GetAssignToDstType(dst interface{}) (interface{}, bool) { + dstPtr := reflect.ValueOf(dst) + + // AssignTo dst must always be a pointer + if dstPtr.Kind() != reflect.Ptr { + return nil, false + } + + dstVal := dstPtr.Elem() + + // if dst is a pointer to pointer, allocate space try again with the dereferenced pointer + if dstVal.Kind() == reflect.Ptr { + dstVal.Set(reflect.New(dstVal.Type().Elem())) + return dstVal.Interface(), true + } + + // if dst is pointer to a base type that has been renamed + if baseValType, ok := kindTypes[dstVal.Kind()]; ok { + nextDst := dstPtr.Convert(reflect.PtrTo(baseValType)) + return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + } + + if dstVal.Kind() == reflect.Slice { + if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { + baseSliceType := reflect.PtrTo(reflect.SliceOf(baseElemType)) + nextDst := dstPtr.Convert(baseSliceType) + return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + } + } + + return nil, false +} + +func init() { + kindTypes = map[reflect.Kind]reflect.Type{ + reflect.Bool: reflect.TypeOf(false), + reflect.Float32: reflect.TypeOf(float32(0)), + reflect.Float64: reflect.TypeOf(float64(0)), + reflect.Int: reflect.TypeOf(int(0)), + reflect.Int8: reflect.TypeOf(int8(0)), + reflect.Int16: reflect.TypeOf(int16(0)), + reflect.Int32: reflect.TypeOf(int32(0)), + reflect.Int64: reflect.TypeOf(int64(0)), + reflect.Uint: reflect.TypeOf(uint(0)), + reflect.Uint8: reflect.TypeOf(uint8(0)), + reflect.Uint16: reflect.TypeOf(uint16(0)), + reflect.Uint32: reflect.TypeOf(uint32(0)), + reflect.Uint64: reflect.TypeOf(uint64(0)), + reflect.String: reflect.TypeOf(""), + } +} diff --git a/pgtype/date.go b/pgtype/date.go index b6cc8329..ab854eb2 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "fmt" "io" - "reflect" "time" "github.com/jackc/pgx/pgio" @@ -50,33 +49,25 @@ func (dst *Date) Get() interface{} { } func (src *Date) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *time.Time: - if src.Status != Present || src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/date_array.go b/pgtype/date_array.go index 8bc8ff72..bf791677 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -60,28 +60,29 @@ func (dst *DateArray) Get() interface{} { } func (src *DateArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]time.Time: - if src.Status == Present { + case *[]time.Time: *v = make([]time.Time, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index 6abc1a31..b4d05c55 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -59,28 +59,29 @@ func (dst *Float4Array) Get() interface{} { } func (src *Float4Array) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]float32: - if src.Status == Present { + case *[]float32: *v = make([]float32, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index 050efa3f..e000807e 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -59,28 +59,29 @@ func (dst *Float8Array) Get() interface{} { } func (src *Float8Array) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]float64: - if src.Status == Present { + case *[]float64: *v = make([]float64, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/hstore.go b/pgtype/hstore.go index d771d6e6..8dc5b4d8 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -47,10 +47,10 @@ func (dst *Hstore) Get() interface{} { } func (src *Hstore) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *map[string]string: - switch src.Status { - case Present: + switch src.Status { + case Present: + switch v := dst.(type) { + case *map[string]string: *v = make(map[string]string, len(src.Map)) for k, val := range src.Map { if val.Status != Present { @@ -58,16 +58,17 @@ func (src *Hstore) AssignTo(dst interface{}) error { } (*v)[k] = val.String } - case Null: - *v = nil + return nil default: - return fmt.Errorf("cannot decode %v into %T", src, dst) + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - default: - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go index ba192462..9bd0ed3b 100644 --- a/pgtype/hstore_array.go +++ b/pgtype/hstore_array.go @@ -59,28 +59,29 @@ func (dst *HstoreArray) Get() interface{} { } func (src *HstoreArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]map[string]string: - if src.Status == Present { + case *[]map[string]string: *v = make([]map[string]string, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/inet.go b/pgtype/inet.go index b83bd1c9..13764814 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -4,7 +4,6 @@ import ( "fmt" "io" "net" - "reflect" "github.com/jackc/pgx/pgio" ) @@ -61,43 +60,28 @@ func (dst *Inet) Get() interface{} { } func (src *Inet) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *net.IPNet: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = *src.IPNet - case *net.IP: - if src.Status == Present { - + switch src.Status { + case Present: + switch v := dst.(type) { + case *net.IPNet: + *v = *src.IPNet + return nil + case *net.IP: if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { return fmt.Errorf("cannot assign %v to %T", src, dst) } *v = src.IPNet.IP - } else { - *v = nil - } - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index d893a724..1988a145 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -79,40 +79,38 @@ func (dst *InetArray) Get() interface{} { } func (src *InetArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]*net.IPNet: - if src.Status == Present { + case *[]*net.IPNet: *v = make([]*net.IPNet, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - case *[]net.IP: - if src.Status == Present { + case *[]net.IP: *v = make([]net.IP, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index b93a4fa3..531e7dd6 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -78,40 +78,38 @@ func (dst *Int2Array) Get() interface{} { } func (src *Int2Array) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]int16: - if src.Status == Present { + case *[]int16: *v = make([]int16, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - case *[]uint16: - if src.Status == Present { + case *[]uint16: *v = make([]uint16, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index 0b96b7a4..3617050f 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -78,40 +78,38 @@ func (dst *Int4Array) Get() interface{} { } func (src *Int4Array) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]int32: - if src.Status == Present { + case *[]int32: *v = make([]int32, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - case *[]uint32: - if src.Status == Present { + case *[]uint32: *v = make([]uint32, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index 02a240f4..4f04b660 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -78,40 +78,38 @@ func (dst *Int8Array) Get() interface{} { } func (src *Int8Array) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]int64: - if src.Status == Present { + case *[]int64: *v = make([]int64, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - case *[]uint64: - if src.Status == Present { + case *[]uint64: *v = make([]uint64, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/record.go b/pgtype/record.go index 1bfd05b9..89e081ca 100644 --- a/pgtype/record.go +++ b/pgtype/record.go @@ -38,34 +38,29 @@ func (dst *Record) Get() interface{} { } func (src *Record) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *[]Value: - switch src.Status { - case Present: + switch src.Status { + case Present: + switch v := dst.(type) { + case *[]Value: *v = make([]Value, len(src.Fields)) copy(*v, src.Fields) - case Null: - *v = nil - default: - return fmt.Errorf("cannot decode %v into %T", src, dst) - } - case *[]interface{}: - switch src.Status { - case Present: + return nil + case *[]interface{}: *v = make([]interface{}, len(src.Fields)) for i := range *v { (*v)[i] = src.Fields[i].Get() } - case Null: - *v = nil + return nil default: - return fmt.Errorf("cannot decode %v into %T", src, dst) + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - default: - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { diff --git a/pgtype/text.go b/pgtype/text.go index af7f16fc..dbc9362b 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -3,7 +3,6 @@ package pgtype import ( "fmt" "io" - "reflect" ) type Text struct { @@ -43,49 +42,26 @@ func (dst *Text) Get() interface{} { } func (src *Text) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *string: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.String - case *[]byte: - switch src.Status { - case Present: + switch src.Status { + case Present: + switch v := dst.(type) { + case *string: + *v = src.String + return nil + case *[]byte: *v = make([]byte, len(src.String)) copy(*v, src.String) - case Null: - *v = nil + return nil default: - return fmt.Errorf("unknown status") - } - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) - case reflect.String: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - el.SetString(src.String) - return nil + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/text_array.go b/pgtype/text_array.go index 9f25727e..6e8ead26 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -59,28 +59,29 @@ func (dst *TextArray) Get() interface{} { } func (src *TextArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]string: - if src.Status == Present { + case *[]string: *v = make([]string, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 9a9e74ea..4b42f3cf 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "fmt" "io" - "reflect" "time" "github.com/jackc/pgx/pgio" @@ -54,33 +53,25 @@ func (dst *Timestamp) Get() interface{} { } func (src *Timestamp) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *time.Time: - if src.Status != Present || src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot assign %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } // DecodeText decodes from src into dst. The decoded time is considered to diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index bb19e502..6a6950c7 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -60,28 +60,29 @@ func (dst *TimestampArray) Get() interface{} { } func (src *TimestampArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]time.Time: - if src.Status == Present { + case *[]time.Time: *v = make([]time.Time, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 7f57f4b7..ba849ac8 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "fmt" "io" - "reflect" "time" "github.com/jackc/pgx/pgio" @@ -55,33 +54,25 @@ func (dst *Timestamptz) Get() interface{} { } func (src *Timestamptz) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *time.Time: - if src.Status != Present || src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot assign %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index 6a85cefa..347d0b8b 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -60,28 +60,29 @@ func (dst *TimestamptzArray) Get() interface{} { } func (src *TimestamptzArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]time.Time: - if src.Status == Present { + case *[]time.Time: *v = make([]time.Time, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 2b81666e..26c4671c 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -58,28 +58,29 @@ func (dst *<%= pgtype_array_type %>) Get() interface{} { } func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { - switch v := dst.(type) { - <% go_array_types.split(",").each do |t| %> - case *<%= t %>: - if src.Status == Present { + switch src.Status { + case Present: + switch v := dst.(type) { + <% go_array_types.split(",").each do |t| %> + case *<%= t %>: *v = make(<%= t %>, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil + return nil + <% end %> + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - <% end %> - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) - } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index 158ece94..e1dd3910 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -59,28 +59,29 @@ func (dst *VarcharArray) Get() interface{} { } func (src *VarcharArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]string: - if src.Status == Present { + case *[]string: *v = make([]string, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error {