pgx/convert.go
Maxim Ivanov 3ce29f9e05 Add Composite type for inplace row() values handling
Composite() function returns a private type, which should
be registered with ConnInfo.RegisterDataType for the composite
type's OID.

All subsequent interaction with Composite types is to be done
via Row(...) function. Function return value can be either
passed as a query argument to build SQL composite value out of
individual fields or passed to Scan to read SQL composite value
back.

When passed to Scan, Row() should have first argument of type
*bool to flag NULL values returned from query.
2020-04-13 17:41:44 +01:00

525 lines
15 KiB
Go

package pgtype
import (
"math"
"reflect"
"time"
"github.com/jackc/pgio"
errors "golang.org/x/xerrors"
)
// Platform-dependent limits of the native uint and int types, all derived
// from the all-ones bit pattern of uint.
const (
	// maxUint is the largest value a uint can hold.
	maxUint = ^uint(0)
	// maxInt is the largest value an int can hold (maxUint without its top bit).
	maxInt = int(maxUint >> 1)
	// minInt is the smallest value an int can hold.
	minInt = -maxInt - 1
)
// underlyingNumberType reduces val to the built-in type it is based on so it
// can be converted to Int2, Int4, Int8, Float4, or Float8.
//
// A non-nil pointer is dereferenced and reported as changed. A value whose
// kind is an integer, unsigned integer, float or string is converted to the
// matching built-in type; the bool result is true only when that built-in type
// differs from val's own type (i.e. val was a named type). Any other input,
// including a nil pointer, yields (nil, false).
func underlyingNumberType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)

	var base interface{}
	switch rv.Kind() {
	case reflect.Ptr:
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	case reflect.Int:
		base = int(rv.Int())
	case reflect.Int8:
		base = int8(rv.Int())
	case reflect.Int16:
		base = int16(rv.Int())
	case reflect.Int32:
		base = int32(rv.Int())
	case reflect.Int64:
		base = int64(rv.Int())
	case reflect.Uint:
		base = uint(rv.Uint())
	case reflect.Uint8:
		base = uint8(rv.Uint())
	case reflect.Uint16:
		base = uint16(rv.Uint())
	case reflect.Uint32:
		base = uint32(rv.Uint())
	case reflect.Uint64:
		base = uint64(rv.Uint())
	case reflect.Float32:
		base = float32(rv.Float())
	case reflect.Float64:
		base = rv.Float()
	case reflect.String:
		base = rv.String()
	default:
		return nil, false
	}

	// Only report a change when val was a named type over the built-in one.
	return base, reflect.TypeOf(base) != rv.Type()
}
// underlyingBoolType reduces val to something that can be converted to Bool.
//
// A non-nil pointer is dereferenced and reported as changed. A value of bool
// kind is returned as a plain bool; the second result is true only when val
// was a named type rather than bool itself. Anything else, including a nil
// pointer, yields (nil, false).
func underlyingBoolType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)
	kind := rv.Kind()

	if kind == reflect.Ptr {
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	}

	if kind == reflect.Bool {
		plain := rv.Bool()
		return plain, reflect.TypeOf(plain) != rv.Type()
	}

	return nil, false
}
// underlyingBytesType reduces val to something that can be converted to
// []byte.
//
// A non-nil pointer is dereferenced and reported as changed. A slice whose
// element kind is uint8 is returned as a []byte; the second result is true
// only when val's type was not already []byte. Anything else, including a nil
// pointer or a slice of another element kind, yields (nil, false).
func underlyingBytesType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)
	kind := rv.Kind()

	if kind == reflect.Ptr {
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	}

	if kind != reflect.Slice || rv.Type().Elem().Kind() != reflect.Uint8 {
		return nil, false
	}

	raw := rv.Bytes()
	return raw, reflect.TypeOf(raw) != rv.Type()
}
// underlyingStringType reduces val to something that can be converted to
// String.
//
// A non-nil pointer is dereferenced and reported as changed. A value of string
// kind is returned as a plain string; the second result is true only when val
// was a named type rather than string itself. Anything else, including a nil
// pointer, yields (nil, false).
func underlyingStringType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)
	kind := rv.Kind()

	if kind == reflect.Ptr {
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	}

	if kind == reflect.String {
		plain := rv.String()
		return plain, reflect.TypeOf(plain) != rv.Type()
	}

	return nil, false
}
// underlyingPtrType dereferences a pointer. It returns the pointed-to value
// and true when val is a non-nil pointer, and (nil, false) for a nil pointer
// or any non-pointer input.
func underlyingPtrType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)
	if rv.Kind() != reflect.Ptr || rv.IsNil() {
		return nil, false
	}
	return rv.Elem().Interface(), true
}
// underlyingTimeType gets the underlying type that can be converted to
// time.Time.
//
// A non-nil pointer is dereferenced and reported as changed. Any other value
// whose type is convertible to time.Time is converted and returned with true
// (this includes a plain time.Time). An untyped nil, a nil pointer, or a value
// that is not convertible yields (nil, false).
func underlyingTimeType(val interface{}) (interface{}, bool) {
	refVal := reflect.ValueOf(val)

	switch refVal.Kind() {
	case reflect.Invalid:
		// val is an untyped nil: refVal is the zero reflect.Value and calling
		// Type() on it below would panic.
		return nil, false
	case reflect.Ptr:
		if refVal.IsNil() {
			return nil, false
		}
		convVal := refVal.Elem().Interface()
		return convVal, true
	}

	timeType := reflect.TypeOf(time.Time{})
	if refVal.Type().ConvertibleTo(timeType) {
		return refVal.Convert(timeType).Interface(), true
	}

	return nil, false
}
// underlyingUUIDType gets the underlying type that can be converted to
// [16]byte.
//
// A non-nil pointer is dereferenced and reported as changed. Any other value
// whose type is convertible to [16]byte is converted and returned with true
// (this includes a plain [16]byte). An untyped nil, a nil pointer, or a value
// that is not convertible yields (nil, false).
func underlyingUUIDType(val interface{}) (interface{}, bool) {
	refVal := reflect.ValueOf(val)

	switch refVal.Kind() {
	case reflect.Invalid:
		// val is an untyped nil: refVal is the zero reflect.Value and calling
		// Type() on it below would panic.
		return nil, false
	case reflect.Ptr:
		if refVal.IsNil() {
			// Return nil like every other failure path in this function; the
			// previous time.Time{} here was a copy-paste leftover.
			return nil, false
		}
		convVal := refVal.Elem().Interface()
		return convVal, true
	}

	uuidType := reflect.TypeOf([16]byte{})
	if refVal.Type().ConvertibleTo(uuidType) {
		return refVal.Convert(uuidType).Interface(), true
	}

	return nil, false
}
// underlyingSliceType reduces val to its unnamed slice type.
//
// A non-nil pointer is dereferenced and reported as changed. A slice is
// converted to the unnamed slice of its element type; the second result is
// true only when that differs from val's own type (i.e. val was a named slice
// type). Anything else, including a nil pointer, yields (nil, false).
func underlyingSliceType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)
	kind := rv.Kind()

	if kind == reflect.Ptr {
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	}

	if kind != reflect.Slice {
		return nil, false
	}

	origType := rv.Type()
	plainType := reflect.SliceOf(origType.Elem())
	if !origType.ConvertibleTo(plainType) {
		return nil, false
	}

	converted := rv.Convert(plainType)
	return converted.Interface(), converted.Type() != origType
}
// int64AssignTo stores srcVal into the location dst points to.
//
// When srcStatus is Present, dst may be a pointer to any built-in signed or
// unsigned integer type, a pointer to a named type of integer kind, or a
// pointer to such a pointer (which is allocated if nil). An error is returned
// if srcVal does not fit the destination type or dst is not a supported
// destination.
//
// When srcStatus is not Present, dst must be a pointer to a pointer; the inner
// pointer is set to nil. Any other dst yields an error.
func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error {
	if srcStatus == Present {
		switch v := dst.(type) {
		case *int:
			if srcVal < int64(minInt) {
				return errors.Errorf("%d is less than minimum value for int", srcVal)
			} else if srcVal > int64(maxInt) {
				return errors.Errorf("%d is greater than maximum value for int", srcVal)
			}
			*v = int(srcVal)
		case *int8:
			if srcVal < math.MinInt8 {
				return errors.Errorf("%d is less than minimum value for int8", srcVal)
			} else if srcVal > math.MaxInt8 {
				return errors.Errorf("%d is greater than maximum value for int8", srcVal)
			}
			*v = int8(srcVal)
		case *int16:
			if srcVal < math.MinInt16 {
				return errors.Errorf("%d is less than minimum value for int16", srcVal)
			} else if srcVal > math.MaxInt16 {
				return errors.Errorf("%d is greater than maximum value for int16", srcVal)
			}
			*v = int16(srcVal)
		case *int32:
			if srcVal < math.MinInt32 {
				return errors.Errorf("%d is less than minimum value for int32", srcVal)
			} else if srcVal > math.MaxInt32 {
				return errors.Errorf("%d is greater than maximum value for int32", srcVal)
			}
			*v = int32(srcVal)
		case *int64:
			// Every int64 fits in an int64; no range check is needed.
			*v = srcVal
		case *uint:
			if srcVal < 0 {
				return errors.Errorf("%d is less than zero for uint", srcVal)
			} else if uint64(srcVal) > uint64(maxUint) {
				return errors.Errorf("%d is greater than maximum value for uint", srcVal)
			}
			*v = uint(srcVal)
		case *uint8:
			if srcVal < 0 {
				return errors.Errorf("%d is less than zero for uint8", srcVal)
			} else if srcVal > math.MaxUint8 {
				return errors.Errorf("%d is greater than maximum value for uint8", srcVal)
			}
			*v = uint8(srcVal)
		case *uint16:
			if srcVal < 0 {
				return errors.Errorf("%d is less than zero for uint16", srcVal)
			} else if srcVal > math.MaxUint16 {
				return errors.Errorf("%d is greater than maximum value for uint16", srcVal)
			}
			*v = uint16(srcVal)
		case *uint32:
			if srcVal < 0 {
				return errors.Errorf("%d is less than zero for uint32", srcVal)
			} else if srcVal > math.MaxUint32 {
				return errors.Errorf("%d is greater than maximum value for uint32", srcVal)
			}
			*v = uint32(srcVal)
		case *uint64:
			if srcVal < 0 {
				return errors.Errorf("%d is less than zero for uint64", srcVal)
			}
			*v = uint64(srcVal)
		default:
			// dst is not a pointer to a built-in integer type; fall back to
			// reflection for pointer-to-pointer and named integer types.
			if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
				el := v.Elem()
				switch el.Kind() {
				// if dst is a pointer to pointer, strip the pointer and try again
				case reflect.Ptr:
					if el.IsNil() {
						// allocate destination
						el.Set(reflect.New(el.Type().Elem()))
					}
					return int64AssignTo(srcVal, srcStatus, el.Interface())
				case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
					if el.OverflowInt(srcVal) {
						return errors.Errorf("cannot put %d into %T", srcVal, dst)
					}
					el.SetInt(srcVal)
					return nil
				case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
					if srcVal < 0 {
						return errors.Errorf("%d is less than zero for %T", srcVal, dst)
					}
					if el.OverflowUint(uint64(srcVal)) {
						return errors.Errorf("cannot put %d into %T", srcVal, dst)
					}
					el.SetUint(uint64(srcVal))
					return nil
				}
			}
			return errors.Errorf("cannot assign %v into %T", srcVal, dst)
		}
		return nil
	}

	// if dst is a pointer to pointer and srcStatus is not Present, nil it out
	if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
		el := v.Elem()
		if el.Kind() == reflect.Ptr {
			el.Set(reflect.Zero(el.Type()))
			return nil
		}
	}

	return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst)
}
// float64AssignTo stores srcVal into the location dst points to.
//
// When srcStatus is Present, dst may be *float32, *float64, a pointer to such
// a pointer (allocated if nil), or a pointer to a value of integer kind. An
// integer destination is accepted only when srcVal is a whole number that
// fits in an int64; the assignment is then delegated to int64AssignTo. Any
// other dst yields an error.
//
// When srcStatus is not Present, dst must be a pointer to a pointer; the inner
// pointer is set to nil. Any other dst yields an error.
func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error {
	if srcStatus == Present {
		switch v := dst.(type) {
		case *float32:
			*v = float32(srcVal)
		case *float64:
			*v = srcVal
		default:
			if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
				el := v.Elem()
				switch el.Kind() {
				// if dst is a pointer to pointer, strip the pointer and try again
				case reflect.Ptr:
					if el.IsNil() {
						// allocate destination
						el.Set(reflect.New(el.Type().Elem()))
					}
					return float64AssignTo(srcVal, srcStatus, el.Interface())
				case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
					// Converting an out-of-range float to int64 gives an
					// implementation-dependent result, so the round-trip
					// comparison alone could wrongly succeed on platforms
					// where the conversion saturates. Check the range first.
					// NaN fails both comparisons and falls through to the
					// error below.
					if srcVal >= -(1<<63) && srcVal < (1<<63) {
						i64 := int64(srcVal)
						if float64(i64) == srcVal {
							return int64AssignTo(i64, srcStatus, dst)
						}
					}
				}
			}
			return errors.Errorf("cannot assign %v into %T", srcVal, dst)
		}
		return nil
	}

	// if dst is a pointer to pointer and srcStatus is not Present, nil it out
	if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
		el := v.Elem()
		if el.Kind() == reflect.Ptr {
			el.Set(reflect.Zero(el.Type()))
			return nil
		}
	}

	return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst)
}
// NullAssignTo assigns a NULL to dst by setting the value dst points to to
// its zero value (nil). dst must be a pointer to a pointer, slice or map;
// any other dst, including a non-pointer, results in an error.
func NullAssignTo(dst interface{}) error {
	ptr := reflect.ValueOf(dst)

	// AssignTo dst must always be a pointer
	if ptr.Kind() == reflect.Ptr {
		target := ptr.Elem()
		if k := target.Kind(); k == reflect.Ptr || k == reflect.Slice || k == reflect.Map {
			target.Set(reflect.Zero(target.Type()))
			return nil
		}
	}

	return errors.Errorf("cannot assign NULL to %T", dst)
}
// kindTypes maps a reflect.Kind to the reflect.Type of the built-in type of
// that kind. It is populated in init and consulted by GetAssignToDstType to
// find the base type behind a named type.
var kindTypes map[reflect.Kind]reflect.Type
// GetAssignToDstType attempts to convert dst to something AssignTo can assign
// to. If dst is a pointer to pointer it allocates a value and returns the
// dereferenced pointer. If dst is a named type such as *Foo where Foo is type
// Foo int16, it converts dst to *int16. Pointers to slices and arrays whose
// element type is a renamed base type are converted in the same way.
//
// GetAssignToDstType returns the converted dst and a bool representing if any
// change was made. It returns (nil, false) when dst is not a pointer or no
// conversion applies.
func GetAssignToDstType(dst interface{}) (interface{}, bool) {
	ptr := reflect.ValueOf(dst)

	// AssignTo dst must always be a pointer
	if ptr.Kind() != reflect.Ptr {
		return nil, false
	}
	target := ptr.Elem()

	// retype reinterprets dst as a pointer to pointee and reports whether the
	// resulting pointer type differs from dst's own type.
	retype := func(pointee reflect.Type) (interface{}, bool) {
		converted := ptr.Convert(reflect.PtrTo(pointee))
		return converted.Interface(), ptr.Type() != converted.Type()
	}

	// Pointer to pointer: allocate the inner value and hand back the inner pointer.
	if target.Kind() == reflect.Ptr {
		target.Set(reflect.New(target.Type().Elem()))
		return target.Interface(), true
	}

	// Pointer to a base type that has been renamed.
	if base, found := kindTypes[target.Kind()]; found {
		return retype(base)
	}

	// Pointer to a slice or array whose elements are of a renamed base type.
	switch target.Kind() {
	case reflect.Slice:
		if base, found := kindTypes[target.Type().Elem().Kind()]; found {
			return retype(reflect.SliceOf(base))
		}
	case reflect.Array:
		if base, found := kindTypes[target.Type().Elem().Kind()]; found {
			return retype(reflect.ArrayOf(target.Len(), base))
		}
	}

	return nil, false
}
// ScanRowValue assigns ROW()'s fields to destination Values.
// Argument types are checked and error is returned if SQL field value
// can't be assigned to corresponding destination Value without loss
// of information. Number of fields have to match number of destination values.
//
// Values must implement BinaryDecoder interface otherwise error is returned.
// ScanRowValue takes ownership of src, caller MUST not use it after call
func ScanRowValue(ci *ConnInfo, src []byte, dst ...Value) error {
	fieldIter, err := newFieldIterator(src)
	if err != nil {
		return err
	}

	if len(dst) != fieldIter.fieldCount {
		return errors.Errorf("can't scan row value, number of fields don't match: row fields count=%d desired fields count=%d", fieldIter.fieldCount, len(dst))
	}

	_, fieldBytes, eof, err := fieldIter.next()
	for i := 0; !eof; i++ {
		if err != nil {
			return err
		}

		// Defensive: never index past dst even if the iterator yields more
		// fields than the count it reported.
		if i >= len(dst) {
			return errors.Errorf("can't scan row value, row has more fields than declared: declared fields count=%d", fieldIter.fieldCount)
		}

		binaryDecoder, ok := dst[i].(BinaryDecoder)
		if !ok {
			// %T reports the full dynamic type, including pointer types (for
			// which reflect.Type.Name() is empty), and is safe for a nil Value.
			return errors.Errorf("record field doesn't implement binary decoding: %T", dst[i])
		}

		if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil {
			return err
		}

		_, fieldBytes, eof, err = fieldIter.next()
	}

	return nil
}
// EncodeRow builds a binary representation of row values (row(), composite types)
// and appends it to buf.
//
// The output starts with the number of fields as a uint32. Each field is then
// written as its OID (looked up through ci) followed by either the byte
// length and binary encoding of the value, or a length of -1 with no payload
// when the field's Get() returns nil.
//
// An error is returned when ci has no data type registered for a field, when
// a non-nil field does not implement BinaryEncoder, or when encoding a field
// fails. On error the returned buffer is nil.
func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err error) {
	// Scratch buffer reused for every field to avoid one allocation per field.
	fieldBytes := make([]byte, 0, 128)

	newBuf = pgio.AppendUint32(buf, uint32(len(fields)))
	for _, f := range fields {
		dt, ok := ci.DataTypeForValue(f)
		if !ok {
			// %T names the offending type; %s on an arbitrary Value would
			// print a malformed "%!s(...)" string.
			return nil, errors.Errorf("Unknown OID for %T", f)
		}
		newBuf = pgio.AppendUint32(newBuf, dt.OID)

		if f.Get() == nil {
			// No value: length -1 and no payload.
			newBuf = pgio.AppendInt32(newBuf, int32(-1))
			continue
		}

		binaryEncoder, ok := f.(BinaryEncoder)
		if !ok {
			// %T also covers pointer types, for which reflect.Type.Name() is empty.
			return nil, errors.Errorf("record field doesn't implement binary encoding: %T", f)
		}

		fieldBytes, err = binaryEncoder.EncodeBinary(ci, fieldBytes[:0])
		if err != nil {
			return nil, err
		}

		newBuf = pgio.AppendUint32(newBuf, uint32(len(fieldBytes)))
		newBuf = append(newBuf, fieldBytes...)
	}

	return newBuf, nil
}
// init populates kindTypes with one entry per supported built-in kind, keyed
// by the reflect.Kind of a sample value and holding that value's reflect.Type.
func init() {
	// One zero value of every built-in type that a named type may be based on.
	samples := []interface{}{
		false,
		float32(0),
		float64(0),
		int(0),
		int8(0),
		int16(0),
		int32(0),
		int64(0),
		uint(0),
		uint8(0),
		uint16(0),
		uint32(0),
		uint64(0),
		"",
	}

	kindTypes = make(map[reflect.Kind]reflect.Type, len(samples))
	for _, sample := range samples {
		typ := reflect.TypeOf(sample)
		kindTypes[typ.Kind()] = typ
	}
}