mirror of
https://github.com/jackc/pgx.git
synced 2025-05-31 11:42:24 +00:00
Composite() returns a private type that should be registered with ConnInfo.RegisterDataType under the composite type's OID. All subsequent interaction with composite types is done via the Row(...) function. Its return value can either be passed as a query argument, to build an SQL composite value out of individual fields, or passed to Scan, to read an SQL composite value back. When passed to Scan, Row() should take a *bool as its first argument to flag NULL values returned from the query.
525 lines
15 KiB
Go
525 lines
15 KiB
Go
package pgtype
|
|
|
|
import (
|
|
"math"
|
|
"reflect"
|
|
"time"
|
|
|
|
"github.com/jackc/pgio"
|
|
errors "golang.org/x/xerrors"
|
|
)
|
|
|
|
// Bounds of the platform-sized int and uint types. They are used to
// range-check assignments into *int and *uint destinations.
const (
	maxUint = ^uint(0)
	maxInt  = int(maxUint >> 1)
	minInt  = -maxInt - 1
)
|
|
|
|
// underlyingNumberType maps val onto the builtin type of its kind, so that it
// can be converted to Int2, Int4, Int8, Float4 or Float8.
//
// A non-nil pointer is dereferenced, and the pointee is returned with true so
// the caller can retry. For integer, unsigned, float and string kinds, the
// value is re-expressed as the builtin type of that kind. The bool is true
// only when that actually changed the type, i.e. val had a named type. Any
// other input (nil, a nil pointer, an unsupported kind) yields (nil, false).
func underlyingNumberType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)

	var base interface{}
	switch rv.Kind() {
	case reflect.Ptr:
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	case reflect.Int:
		base = int(rv.Int())
	case reflect.Int8:
		base = int8(rv.Int())
	case reflect.Int16:
		base = int16(rv.Int())
	case reflect.Int32:
		base = int32(rv.Int())
	case reflect.Int64:
		base = rv.Int()
	case reflect.Uint:
		base = uint(rv.Uint())
	case reflect.Uint8:
		base = uint8(rv.Uint())
	case reflect.Uint16:
		base = uint16(rv.Uint())
	case reflect.Uint32:
		base = uint32(rv.Uint())
	case reflect.Uint64:
		base = rv.Uint()
	case reflect.Float32:
		base = float32(rv.Float())
	case reflect.Float64:
		base = rv.Float()
	case reflect.String:
		base = rv.String()
	default:
		return nil, false
	}

	// base now carries a builtin type; a mismatch means val was of a named type.
	return base, reflect.TypeOf(base) != rv.Type()
}
|
|
|
|
// underlyingBoolType maps val onto a plain bool so that it can be converted
// to Bool.
//
// A value of bool kind is returned as a builtin bool. The flag is true only if
// val's own type was a named bool type. A non-nil pointer is dereferenced and
// returned with true. Everything else, including nil pointers, yields
// (nil, false).
func underlyingBoolType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)
	kind := rv.Kind()

	if kind == reflect.Bool {
		b := rv.Bool()
		return b, rv.Type() != reflect.TypeOf(b)
	}

	if kind == reflect.Ptr && !rv.IsNil() {
		return rv.Elem().Interface(), true
	}

	return nil, false
}
|
|
|
|
// underlyingBytesType maps val onto a plain []byte.
//
// A slice whose element kind is uint8 is returned as []byte. The flag is true
// only if val's own type differed (e.g. `type Blob []byte`). A non-nil pointer
// is dereferenced and returned with true. Anything else, including nil
// pointers and slices of other element kinds, yields (nil, false).
func underlyingBytesType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)

	switch {
	case rv.Kind() == reflect.Ptr:
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	case rv.Kind() == reflect.Slice && rv.Type().Elem().Kind() == reflect.Uint8:
		raw := rv.Bytes()
		return raw, rv.Type() != reflect.TypeOf(raw)
	default:
		return nil, false
	}
}
|
|
|
|
// underlyingStringType maps val onto a plain string so that it can be
// converted to String.
//
// A value of string kind is returned as a builtin string. The flag is true
// only when val's type was a named string type. A non-nil pointer is
// dereferenced and returned with true. All other inputs yield (nil, false).
func underlyingStringType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)

	switch {
	case rv.Kind() == reflect.String:
		str := rv.String()
		return str, rv.Type() != reflect.TypeOf(str)
	case rv.Kind() == reflect.Ptr && !rv.IsNil():
		return rv.Elem().Interface(), true
	}

	return nil, false
}
|
|
|
|
// underlyingPtrType dereferences val when it is a non-nil pointer, returning
// the pointee and true. Non-pointer values, nil pointers and nil itself yield
// (nil, false).
func underlyingPtrType(val interface{}) (interface{}, bool) {
	if rv := reflect.ValueOf(val); rv.Kind() == reflect.Ptr && !rv.IsNil() {
		return rv.Elem().Interface(), true
	}
	return nil, false
}
|
|
|
|
// underlyingTimeType maps val onto a plain time.Time.
//
// Results:
//   - A non-nil pointer is dereferenced, and the pointee is returned with
//     true so the caller can retry with it.
//   - A value whose type is convertible to time.Time (e.g.
//     `type MyTime time.Time`) is converted and returned with true.
//   - A value that is already a time.Time is returned with false. As with the
//     other underlying*Type helpers, true means "something changed". Reporting
//     true for an unchanged value would let a caller that retries on true
//     recurse with the same input.
//   - nil, nil pointers and non-convertible values yield (nil, false).
func underlyingTimeType(val interface{}) (interface{}, bool) {
	refVal := reflect.ValueOf(val)

	switch refVal.Kind() {
	case reflect.Invalid:
		// val is a nil interface; refVal.Type() below would panic.
		return nil, false
	case reflect.Ptr:
		if refVal.IsNil() {
			return nil, false
		}
		return refVal.Elem().Interface(), true
	}

	timeType := reflect.TypeOf(time.Time{})
	if refVal.Type().ConvertibleTo(timeType) {
		return refVal.Convert(timeType).Interface(), refVal.Type() != timeType
	}

	return nil, false
}
|
|
|
|
// underlyingUUIDType maps val onto a plain [16]byte.
//
// Results:
//   - A non-nil pointer is dereferenced, and the pointee is returned with
//     true so the caller can retry with it.
//   - A value whose type is convertible to [16]byte (e.g. `type MyUUID [16]byte`)
//     is converted and returned with true.
//   - A value that is already a [16]byte is returned with false, because
//     nothing changed.
//   - nil, nil pointers, non-convertible values, and slices whose length is
//     not exactly 16 yield (nil, false).
func underlyingUUIDType(val interface{}) (interface{}, bool) {
	refVal := reflect.ValueOf(val)

	switch refVal.Kind() {
	case reflect.Invalid:
		// val is a nil interface; refVal.Type() below would panic.
		return nil, false
	case reflect.Ptr:
		if refVal.IsNil() {
			// Return nil, like the other underlying*Type helpers, rather than an
			// unrelated zero value.
			return nil, false
		}
		return refVal.Elem().Interface(), true
	}

	uuidType := reflect.TypeOf([16]byte{})

	// Since Go 1.20 a byte slice is ConvertibleTo an array, but Convert panics
	// when the slice is shorter than the array. A longer slice would be
	// silently truncated. Only accept slices of exactly the UUID length.
	if refVal.Kind() == reflect.Slice && refVal.Len() != uuidType.Len() {
		return nil, false
	}

	if refVal.Type().ConvertibleTo(uuidType) {
		return refVal.Convert(uuidType).Interface(), refVal.Type() != uuidType
	}

	return nil, false
}
|
|
|
|
// underlyingSliceType maps val onto an unnamed slice type.
//
// A non-nil pointer is dereferenced and returned with true. A slice is
// converted to []E, where E is its element type. The flag is true only if
// val's type was a named slice type (e.g. `type Ints []int`). Anything else,
// including nil and nil pointers, yields (nil, false).
func underlyingSliceType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)
	kind := rv.Kind()

	if kind == reflect.Ptr {
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	}

	if kind != reflect.Slice {
		return nil, false
	}

	plain := reflect.SliceOf(rv.Type().Elem())
	if !rv.Type().ConvertibleTo(plain) {
		return nil, false
	}

	converted := rv.Convert(plain).Interface()
	return converted, reflect.TypeOf(converted) != rv.Type()
}
|
|
|
|
func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error {
|
|
if srcStatus == Present {
|
|
switch v := dst.(type) {
|
|
case *int:
|
|
if srcVal < int64(minInt) {
|
|
return errors.Errorf("%d is less than minimum value for int", srcVal)
|
|
} else if srcVal > int64(maxInt) {
|
|
return errors.Errorf("%d is greater than maximum value for int", srcVal)
|
|
}
|
|
*v = int(srcVal)
|
|
case *int8:
|
|
if srcVal < math.MinInt8 {
|
|
return errors.Errorf("%d is less than minimum value for int8", srcVal)
|
|
} else if srcVal > math.MaxInt8 {
|
|
return errors.Errorf("%d is greater than maximum value for int8", srcVal)
|
|
}
|
|
*v = int8(srcVal)
|
|
case *int16:
|
|
if srcVal < math.MinInt16 {
|
|
return errors.Errorf("%d is less than minimum value for int16", srcVal)
|
|
} else if srcVal > math.MaxInt16 {
|
|
return errors.Errorf("%d is greater than maximum value for int16", srcVal)
|
|
}
|
|
*v = int16(srcVal)
|
|
case *int32:
|
|
if srcVal < math.MinInt32 {
|
|
return errors.Errorf("%d is less than minimum value for int32", srcVal)
|
|
} else if srcVal > math.MaxInt32 {
|
|
return errors.Errorf("%d is greater than maximum value for int32", srcVal)
|
|
}
|
|
*v = int32(srcVal)
|
|
case *int64:
|
|
if srcVal < math.MinInt64 {
|
|
return errors.Errorf("%d is less than minimum value for int64", srcVal)
|
|
} else if srcVal > math.MaxInt64 {
|
|
return errors.Errorf("%d is greater than maximum value for int64", srcVal)
|
|
}
|
|
*v = int64(srcVal)
|
|
case *uint:
|
|
if srcVal < 0 {
|
|
return errors.Errorf("%d is less than zero for uint", srcVal)
|
|
} else if uint64(srcVal) > uint64(maxUint) {
|
|
return errors.Errorf("%d is greater than maximum value for uint", srcVal)
|
|
}
|
|
*v = uint(srcVal)
|
|
case *uint8:
|
|
if srcVal < 0 {
|
|
return errors.Errorf("%d is less than zero for uint8", srcVal)
|
|
} else if srcVal > math.MaxUint8 {
|
|
return errors.Errorf("%d is greater than maximum value for uint8", srcVal)
|
|
}
|
|
*v = uint8(srcVal)
|
|
case *uint16:
|
|
if srcVal < 0 {
|
|
return errors.Errorf("%d is less than zero for uint32", srcVal)
|
|
} else if srcVal > math.MaxUint16 {
|
|
return errors.Errorf("%d is greater than maximum value for uint16", srcVal)
|
|
}
|
|
*v = uint16(srcVal)
|
|
case *uint32:
|
|
if srcVal < 0 {
|
|
return errors.Errorf("%d is less than zero for uint32", srcVal)
|
|
} else if srcVal > math.MaxUint32 {
|
|
return errors.Errorf("%d is greater than maximum value for uint32", srcVal)
|
|
}
|
|
*v = uint32(srcVal)
|
|
case *uint64:
|
|
if srcVal < 0 {
|
|
return errors.Errorf("%d is less than zero for uint64", srcVal)
|
|
}
|
|
*v = uint64(srcVal)
|
|
default:
|
|
if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
|
|
el := v.Elem()
|
|
switch el.Kind() {
|
|
// if dst is a pointer to pointer, strip the pointer and try again
|
|
case reflect.Ptr:
|
|
if el.IsNil() {
|
|
// allocate destination
|
|
el.Set(reflect.New(el.Type().Elem()))
|
|
}
|
|
return int64AssignTo(srcVal, srcStatus, el.Interface())
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
if el.OverflowInt(int64(srcVal)) {
|
|
return errors.Errorf("cannot put %d into %T", srcVal, dst)
|
|
}
|
|
el.SetInt(int64(srcVal))
|
|
return nil
|
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
if srcVal < 0 {
|
|
return errors.Errorf("%d is less than zero for %T", srcVal, dst)
|
|
}
|
|
if el.OverflowUint(uint64(srcVal)) {
|
|
return errors.Errorf("cannot put %d into %T", srcVal, dst)
|
|
}
|
|
el.SetUint(uint64(srcVal))
|
|
return nil
|
|
}
|
|
}
|
|
return errors.Errorf("cannot assign %v into %T", srcVal, dst)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// if dst is a pointer to pointer and srcStatus is not Present, nil it out
|
|
if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
|
|
el := v.Elem()
|
|
if el.Kind() == reflect.Ptr {
|
|
el.Set(reflect.Zero(el.Type()))
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst)
|
|
}
|
|
|
|
func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error {
|
|
if srcStatus == Present {
|
|
switch v := dst.(type) {
|
|
case *float32:
|
|
*v = float32(srcVal)
|
|
case *float64:
|
|
*v = srcVal
|
|
default:
|
|
if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
|
|
el := v.Elem()
|
|
switch el.Kind() {
|
|
// if dst is a pointer to pointer, strip the pointer and try again
|
|
case reflect.Ptr:
|
|
if el.IsNil() {
|
|
// allocate destination
|
|
el.Set(reflect.New(el.Type().Elem()))
|
|
}
|
|
return float64AssignTo(srcVal, srcStatus, el.Interface())
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
i64 := int64(srcVal)
|
|
if float64(i64) == srcVal {
|
|
return int64AssignTo(i64, srcStatus, dst)
|
|
}
|
|
}
|
|
}
|
|
return errors.Errorf("cannot assign %v into %T", srcVal, dst)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// if dst is a pointer to pointer and srcStatus is not Present, nil it out
|
|
if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
|
|
el := v.Elem()
|
|
if el.Kind() == reflect.Ptr {
|
|
el.Set(reflect.Zero(el.Type()))
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst)
|
|
}
|
|
|
|
func NullAssignTo(dst interface{}) error {
|
|
dstPtr := reflect.ValueOf(dst)
|
|
|
|
// AssignTo dst must always be a pointer
|
|
if dstPtr.Kind() != reflect.Ptr {
|
|
return errors.Errorf("cannot assign NULL to %T", dst)
|
|
}
|
|
|
|
dstVal := dstPtr.Elem()
|
|
|
|
switch dstVal.Kind() {
|
|
case reflect.Ptr, reflect.Slice, reflect.Map:
|
|
dstVal.Set(reflect.Zero(dstVal.Type()))
|
|
return nil
|
|
}
|
|
|
|
return errors.Errorf("cannot assign NULL to %T", dst)
|
|
}
|
|
|
|
// kindTypes maps a reflect.Kind to the unnamed builtin Go type of that kind
// (e.g. reflect.Int16 -> int16). It is nil until this file's init function
// fills it in. GetAssignToDstType uses it to map named types onto their base
// types.
var (
	kindTypes map[reflect.Kind]reflect.Type
)
|
|
|
|
// GetAssignToDstType attempts to convert dst to something AssignTo can assign
|
|
// to. If dst is a pointer to pointer it allocates a value and returns the
|
|
// dereferences pointer. If dst is a named type such as *Foo where Foo is type
|
|
// Foo int16, it converts dst to *int16.
|
|
//
|
|
// GetAssignToDstType returns the converted dst and a bool representing if any
|
|
// change was made.
|
|
func GetAssignToDstType(dst interface{}) (interface{}, bool) {
|
|
dstPtr := reflect.ValueOf(dst)
|
|
|
|
// AssignTo dst must always be a pointer
|
|
if dstPtr.Kind() != reflect.Ptr {
|
|
return nil, false
|
|
}
|
|
|
|
dstVal := dstPtr.Elem()
|
|
|
|
// if dst is a pointer to pointer, allocate space try again with the dereferenced pointer
|
|
if dstVal.Kind() == reflect.Ptr {
|
|
dstVal.Set(reflect.New(dstVal.Type().Elem()))
|
|
return dstVal.Interface(), true
|
|
}
|
|
|
|
// if dst is pointer to a base type that has been renamed
|
|
if baseValType, ok := kindTypes[dstVal.Kind()]; ok {
|
|
nextDst := dstPtr.Convert(reflect.PtrTo(baseValType))
|
|
return nextDst.Interface(), dstPtr.Type() != nextDst.Type()
|
|
}
|
|
|
|
if dstVal.Kind() == reflect.Slice {
|
|
if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok {
|
|
baseSliceType := reflect.PtrTo(reflect.SliceOf(baseElemType))
|
|
nextDst := dstPtr.Convert(baseSliceType)
|
|
return nextDst.Interface(), dstPtr.Type() != nextDst.Type()
|
|
}
|
|
}
|
|
|
|
if dstVal.Kind() == reflect.Array {
|
|
if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok {
|
|
baseArrayType := reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType))
|
|
nextDst := dstPtr.Convert(baseArrayType)
|
|
return nextDst.Interface(), dstPtr.Type() != nextDst.Type()
|
|
}
|
|
}
|
|
|
|
return nil, false
|
|
}
|
|
|
|
// ScanRowValue assigns ROW()'s fields to destination Values.
|
|
// Argument types are checked and error is returned if SQL field value
|
|
// can't be assigned to corresponding destionation Value without loss
|
|
// of information. Number of fields have to match number of destination values.
|
|
//
|
|
// Values must implement BinaryDecoder interface otherwise error is returned.
|
|
// ScanRowValue takes ownership of src, caller MUST not use it after call
|
|
func ScanRowValue(ci *ConnInfo, src []byte, dst ...Value) error {
|
|
fieldIter, err := newFieldIterator(src)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(dst) != fieldIter.fieldCount {
|
|
return errors.Errorf("can't scan row value, number of fields don't match: row fields count=%d desired fields count=%d", fieldIter.fieldCount, len(dst))
|
|
}
|
|
|
|
_, fieldBytes, eof, err := fieldIter.next()
|
|
for i := 0; !eof; i++ {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
binaryDecoder, ok := dst[i].(BinaryDecoder)
|
|
if !ok {
|
|
return errors.Errorf("record field doesn't implement binary decoding: %s", reflect.TypeOf(dst[i]).Name())
|
|
}
|
|
|
|
if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil {
|
|
return err
|
|
}
|
|
|
|
_, fieldBytes, eof, err = fieldIter.next()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// EncodeRow builds a binary representation of row values (row(), composite types)
|
|
func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err error) {
|
|
fieldBytes := make([]byte, 0, 128)
|
|
|
|
newBuf = pgio.AppendUint32(buf, uint32(len(fields)))
|
|
for _, f := range fields {
|
|
dt, ok := ci.DataTypeForValue(f)
|
|
if !ok {
|
|
return nil, errors.Errorf("Unknown OID for %s", f)
|
|
}
|
|
newBuf = pgio.AppendUint32(newBuf, dt.OID)
|
|
|
|
if f.Get() != nil {
|
|
binaryEncoder, ok := f.(BinaryEncoder)
|
|
if !ok {
|
|
return nil, errors.Errorf("record field doesn't implement binary encoding: %s", reflect.TypeOf(f).Name())
|
|
}
|
|
fieldBytes, err = binaryEncoder.EncodeBinary(ci, fieldBytes[:0])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
newBuf = pgio.AppendUint32(newBuf, uint32(len(fieldBytes)))
|
|
newBuf = append(newBuf, fieldBytes...)
|
|
} else {
|
|
newBuf = pgio.AppendInt32(newBuf, int32(-1))
|
|
}
|
|
|
|
}
|
|
return
|
|
}
|
|
|
|
func init() {
|
|
kindTypes = map[reflect.Kind]reflect.Type{
|
|
reflect.Bool: reflect.TypeOf(false),
|
|
reflect.Float32: reflect.TypeOf(float32(0)),
|
|
reflect.Float64: reflect.TypeOf(float64(0)),
|
|
reflect.Int: reflect.TypeOf(int(0)),
|
|
reflect.Int8: reflect.TypeOf(int8(0)),
|
|
reflect.Int16: reflect.TypeOf(int16(0)),
|
|
reflect.Int32: reflect.TypeOf(int32(0)),
|
|
reflect.Int64: reflect.TypeOf(int64(0)),
|
|
reflect.Uint: reflect.TypeOf(uint(0)),
|
|
reflect.Uint8: reflect.TypeOf(uint8(0)),
|
|
reflect.Uint16: reflect.TypeOf(uint16(0)),
|
|
reflect.Uint32: reflect.TypeOf(uint32(0)),
|
|
reflect.Uint64: reflect.TypeOf(uint64(0)),
|
|
reflect.String: reflect.TypeOf(""),
|
|
}
|
|
}
|