mirror of https://github.com/jackc/pgx.git
Implement AssignTo for most pgtypes
parent
cc3d1e4af8
commit
7329933610
|
@ -3,6 +3,7 @@ package pgtype
|
|||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
|
@ -36,6 +37,41 @@ func (b *Bool) ConvertFrom(src interface{}) error {
|
|||
}
|
||||
|
||||
func (b *Bool) AssignTo(dst interface{}) error {
|
||||
switch v := dst.(type) {
|
||||
case *bool:
|
||||
if b.Status != Present {
|
||||
return fmt.Errorf("cannot assign %v to %T", b, dst)
|
||||
}
|
||||
*v = b.Bool
|
||||
default:
|
||||
if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
|
||||
el := v.Elem()
|
||||
switch el.Kind() {
|
||||
// if dst is a pointer to pointer, strip the pointer and try again
|
||||
case reflect.Ptr:
|
||||
if b.Status == Null {
|
||||
if !el.IsNil() {
|
||||
// if the destination pointer is not nil, nil it out
|
||||
el.Set(reflect.Zero(el.Type()))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if el.IsNil() {
|
||||
// allocate destination
|
||||
el.Set(reflect.New(el.Type().Elem()))
|
||||
}
|
||||
return b.AssignTo(el.Interface())
|
||||
case reflect.Bool:
|
||||
if b.Status != Present {
|
||||
return fmt.Errorf("cannot assign %v to %T", b, dst)
|
||||
}
|
||||
el.SetBool(b.Bool)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("cannot put decode %v into %T", b, dst)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,10 +1,16 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
const maxUint = ^uint(0)
|
||||
const maxInt = int(maxUint >> 1)
|
||||
const minInt = -maxInt - 1
|
||||
|
||||
// underlyingIntType gets the underlying type that can be converted to Int2, Int4, or Int8
|
||||
func underlyingIntType(val interface{}) (interface{}, bool) {
|
||||
refVal := reflect.ValueOf(val)
|
||||
|
@ -115,3 +121,119 @@ func underlyingSliceType(val interface{}) (interface{}, bool) {
|
|||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error {
|
||||
if srcStatus == Present {
|
||||
switch v := dst.(type) {
|
||||
case *int:
|
||||
if srcVal < int64(minInt) {
|
||||
return fmt.Errorf("%d is less than minimum value for int", srcVal)
|
||||
} else if srcVal > int64(maxInt) {
|
||||
return fmt.Errorf("%d is greater than maximum value for int", srcVal)
|
||||
}
|
||||
*v = int(srcVal)
|
||||
case *int8:
|
||||
if srcVal < math.MinInt8 {
|
||||
return fmt.Errorf("%d is less than minimum value for int8", srcVal)
|
||||
} else if srcVal > math.MaxInt8 {
|
||||
return fmt.Errorf("%d is greater than maximum value for int8", srcVal)
|
||||
}
|
||||
*v = int8(srcVal)
|
||||
case *int16:
|
||||
if srcVal < math.MinInt16 {
|
||||
return fmt.Errorf("%d is less than minimum value for int16", srcVal)
|
||||
} else if srcVal > math.MaxInt16 {
|
||||
return fmt.Errorf("%d is greater than maximum value for int16", srcVal)
|
||||
}
|
||||
*v = int16(srcVal)
|
||||
case *int32:
|
||||
if srcVal < math.MinInt32 {
|
||||
return fmt.Errorf("%d is less than minimum value for int32", srcVal)
|
||||
} else if srcVal > math.MaxInt32 {
|
||||
return fmt.Errorf("%d is greater than maximum value for int32", srcVal)
|
||||
}
|
||||
*v = int32(srcVal)
|
||||
case *int64:
|
||||
if srcVal < math.MinInt64 {
|
||||
return fmt.Errorf("%d is less than minimum value for int64", srcVal)
|
||||
} else if srcVal > math.MaxInt64 {
|
||||
return fmt.Errorf("%d is greater than maximum value for int64", srcVal)
|
||||
}
|
||||
*v = int64(srcVal)
|
||||
case *uint:
|
||||
if srcVal < 0 {
|
||||
return fmt.Errorf("%d is less than zero for uint", srcVal)
|
||||
} else if uint64(srcVal) > uint64(maxUint) {
|
||||
return fmt.Errorf("%d is greater than maximum value for uint", srcVal)
|
||||
}
|
||||
*v = uint(srcVal)
|
||||
case *uint8:
|
||||
if srcVal < 0 {
|
||||
return fmt.Errorf("%d is less than zero for uint8", srcVal)
|
||||
} else if srcVal > math.MaxUint8 {
|
||||
return fmt.Errorf("%d is greater than maximum value for uint8", srcVal)
|
||||
}
|
||||
*v = uint8(srcVal)
|
||||
case *uint16:
|
||||
if srcVal < 0 {
|
||||
return fmt.Errorf("%d is less than zero for uint32", srcVal)
|
||||
} else if srcVal > math.MaxUint16 {
|
||||
return fmt.Errorf("%d is greater than maximum value for uint16", srcVal)
|
||||
}
|
||||
*v = uint16(srcVal)
|
||||
case *uint32:
|
||||
if srcVal < 0 {
|
||||
return fmt.Errorf("%d is less than zero for uint32", srcVal)
|
||||
} else if srcVal > math.MaxUint32 {
|
||||
return fmt.Errorf("%d is greater than maximum value for uint32", srcVal)
|
||||
}
|
||||
*v = uint32(srcVal)
|
||||
case *uint64:
|
||||
if srcVal < 0 {
|
||||
return fmt.Errorf("%d is less than zero for uint64", srcVal)
|
||||
}
|
||||
*v = uint64(srcVal)
|
||||
default:
|
||||
if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
|
||||
el := v.Elem()
|
||||
switch el.Kind() {
|
||||
// if dst is a pointer to pointer, strip the pointer and try again
|
||||
case reflect.Ptr:
|
||||
if el.IsNil() {
|
||||
// allocate destination
|
||||
el.Set(reflect.New(el.Type().Elem()))
|
||||
}
|
||||
return int64AssignTo(srcVal, srcStatus, el.Interface())
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
if el.OverflowInt(int64(srcVal)) {
|
||||
return fmt.Errorf("cannot put %d into %T", srcVal, dst)
|
||||
}
|
||||
el.SetInt(int64(srcVal))
|
||||
return nil
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
if srcVal < 0 {
|
||||
return fmt.Errorf("%d is less than zero for %T", srcVal, dst)
|
||||
}
|
||||
if el.OverflowUint(uint64(srcVal)) {
|
||||
return fmt.Errorf("cannot put %d into %T", srcVal, dst)
|
||||
}
|
||||
el.SetUint(uint64(srcVal))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("cannot assign %v into %T", srcVal, dst)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// if dst is a pointer to pointer and srcStatus is not Present, nil it out
|
||||
if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
|
||||
el := v.Elem()
|
||||
if el.Kind() == reflect.Ptr {
|
||||
el.Set(reflect.Zero(el.Type()))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst)
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package pgtype
|
|||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
|
@ -36,6 +37,35 @@ func (d *Date) ConvertFrom(src interface{}) error {
|
|||
}
|
||||
|
||||
func (d *Date) AssignTo(dst interface{}) error {
|
||||
switch v := dst.(type) {
|
||||
case *time.Time:
|
||||
if d.Status != Present {
|
||||
return fmt.Errorf("cannot assign %v to %T", d, dst)
|
||||
}
|
||||
*v = d.Time
|
||||
default:
|
||||
if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
|
||||
el := v.Elem()
|
||||
switch el.Kind() {
|
||||
// if dst is a pointer to pointer, strip the pointer and try again
|
||||
case reflect.Ptr:
|
||||
if d.Status == Null {
|
||||
if !el.IsNil() {
|
||||
// if the destination pointer is not nil, nil it out
|
||||
el.Set(reflect.Zero(el.Type()))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if el.IsNil() {
|
||||
// allocate destination
|
||||
el.Set(reflect.New(el.Type().Elem()))
|
||||
}
|
||||
return d.AssignTo(el.Interface())
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("cannot decode %v into %T", d, dst)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -85,7 +85,7 @@ func (i *Int2) ConvertFrom(src interface{}) error {
|
|||
}
|
||||
|
||||
func (i *Int2) AssignTo(dst interface{}) error {
|
||||
return nil
|
||||
return int64AssignTo(int64(i.Int), i.Status, dst)
|
||||
}
|
||||
|
||||
func (i *Int2) DecodeText(r io.Reader) error {
|
||||
|
|
|
@ -65,6 +65,33 @@ func (a *Int2Array) ConvertFrom(src interface{}) error {
|
|||
}
|
||||
|
||||
func (a *Int2Array) AssignTo(dst interface{}) error {
|
||||
switch v := dst.(type) {
|
||||
case *[]int16:
|
||||
if a.Status == Present {
|
||||
*v = make([]int16, len(a.Elements))
|
||||
for i := range a.Elements {
|
||||
if err := a.Elements[i].AssignTo(&((*v)[i])); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
*v = nil
|
||||
}
|
||||
case *[]uint16:
|
||||
if a.Status == Present {
|
||||
*v = make([]uint16, len(a.Elements))
|
||||
for i := range a.Elements {
|
||||
if err := a.Elements[i].AssignTo(&((*v)[i])); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
*v = nil
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("cannot put decode %v into %T", a, dst)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -76,7 +76,7 @@ func (i *Int4) ConvertFrom(src interface{}) error {
|
|||
}
|
||||
|
||||
func (i *Int4) AssignTo(dst interface{}) error {
|
||||
return nil
|
||||
return int64AssignTo(int64(i.Int), i.Status, dst)
|
||||
}
|
||||
|
||||
func (i *Int4) DecodeText(r io.Reader) error {
|
||||
|
|
|
@ -67,7 +67,7 @@ func (i *Int8) ConvertFrom(src interface{}) error {
|
|||
}
|
||||
|
||||
func (i *Int8) AssignTo(dst interface{}) error {
|
||||
return nil
|
||||
return int64AssignTo(int64(i.Int), i.Status, dst)
|
||||
}
|
||||
|
||||
func (i *Int8) DecodeText(r io.Reader) error {
|
||||
|
|
38
query.go
38
query.go
|
@ -4,9 +4,10 @@ import (
|
|||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/net/context"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
|
@ -288,8 +289,39 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
|||
d2 := d
|
||||
decodeJSONB(vr, &d2)
|
||||
} else {
|
||||
if err := Decode(vr, d); err != nil {
|
||||
rows.Fatal(scanArgError{col: i, err: err})
|
||||
if pgVal, present := rows.conn.oidPgtypeValues[vr.Type().DataType]; present {
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
if textDecoder, ok := pgVal.(pgtype.TextDecoder); ok {
|
||||
vr.err = errRewoundLen
|
||||
err = textDecoder.DecodeText(&valueReader2{vr})
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
vr.Fatal(fmt.Errorf("%T is not a pgtype.TextDecoder", pgVal))
|
||||
}
|
||||
case BinaryFormatCode:
|
||||
if binaryDecoder, ok := pgVal.(pgtype.BinaryDecoder); ok {
|
||||
vr.err = errRewoundLen
|
||||
err = binaryDecoder.DecodeBinary(&valueReader2{vr})
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
vr.Fatal(fmt.Errorf("%T is not a pgtype.BinaryDecoder", pgVal))
|
||||
}
|
||||
default:
|
||||
vr.Fatal(fmt.Errorf("unknown format code: %v", vr.Type().FormatCode))
|
||||
}
|
||||
|
||||
if err := pgVal.AssignTo(d); err != nil {
|
||||
vr.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
if err := Decode(vr, d); err != nil {
|
||||
rows.Fatal(scanArgError{col: i, err: err})
|
||||
}
|
||||
}
|
||||
}
|
||||
if vr.Err() != nil {
|
||||
|
|
|
@ -111,7 +111,7 @@ func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T)
|
|||
var s string
|
||||
|
||||
err := conn.QueryRow("select 1").Scan(&s)
|
||||
if err == nil || !strings.Contains(err.Error(), "cannot decode binary value into string") {
|
||||
if err == nil || !(strings.Contains(err.Error(), "cannot decode binary value into string") || strings.Contains(err.Error(), "cannot assign")) {
|
||||
t.Fatalf("Expected Scan to fail to encode binary value into string but: %v", err)
|
||||
}
|
||||
|
||||
|
@ -200,7 +200,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) {
|
|||
t.Fatal("Expected Rows to have an error after an improper read but it didn't")
|
||||
}
|
||||
|
||||
if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" {
|
||||
if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" && !strings.Contains(rows.Err().Error(), "cannot assign") {
|
||||
t.Fatalf("Expected different Rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
|
@ -542,7 +542,7 @@ func TestQueryRowCoreTypes(t *testing.T) {
|
|||
if err == nil {
|
||||
t.Errorf("%d. Expected null to cause error, but it didn't (sql -> %v)", i, tt.sql)
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") && !strings.Contains(err.Error(), "cannot assign") {
|
||||
t.Errorf(`%d. Expected null to cause error "Cannot decode null..." but it was %v (sql -> %v)`, i, err, tt.sql)
|
||||
}
|
||||
|
||||
|
@ -1018,7 +1018,7 @@ func TestQueryRowCoreInt16Slice(t *testing.T) {
|
|||
if err == nil {
|
||||
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||
if err != nil && !(strings.Contains(err.Error(), "Cannot decode null") || strings.Contains(err.Error(), "cannot assign")) {
|
||||
t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue