diff --git a/pgtype/bool.go b/pgtype/bool.go index d645780d..81c72472 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -3,6 +3,7 @@ package pgtype import ( "fmt" "io" + "reflect" "strconv" "github.com/jackc/pgx/pgio" @@ -36,6 +37,41 @@ func (b *Bool) ConvertFrom(src interface{}) error { } func (b *Bool) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *bool: + if b.Status != Present { + return fmt.Errorf("cannot assign %v to %T", b, dst) + } + *v = b.Bool + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if b.Status == Null { + if !el.IsNil() { + // if the destination pointer is not nil, nil it out + el.Set(reflect.Zero(el.Type())) + } + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return b.AssignTo(el.Interface()) + case reflect.Bool: + if b.Status != Present { + return fmt.Errorf("cannot assign %v to %T", b, dst) + } + el.SetBool(b.Bool) + return nil + } + } + return fmt.Errorf("cannot put decode %v into %T", b, dst) + } + return nil } diff --git a/pgtype/convert.go b/pgtype/convert.go index 5d26669d..3f3d9e5f 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -1,10 +1,16 @@ package pgtype import ( + "fmt" + "math" "reflect" "time" ) +const maxUint = ^uint(0) +const maxInt = int(maxUint >> 1) +const minInt = -maxInt - 1 + // underlyingIntType gets the underlying type that can be converted to Int2, Int4, or Int8 func underlyingIntType(val interface{}) (interface{}, bool) { refVal := reflect.ValueOf(val) @@ -115,3 +121,119 @@ func underlyingSliceType(val interface{}) (interface{}, bool) { return nil, false } + +func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { + if srcStatus == Present { + switch v := dst.(type) { + case *int: + if srcVal < int64(minInt) { + return fmt.Errorf("%d is less than minimum value for int", srcVal) + } else if srcVal > int64(maxInt) { + return fmt.Errorf("%d is greater than maximum value for int", srcVal) + } + *v = int(srcVal) + case *int8: + if srcVal < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", srcVal) + } else if srcVal > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", srcVal) + } + *v = int8(srcVal) + case *int16: + if srcVal < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", srcVal) + } else if srcVal > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", srcVal) + } + *v = int16(srcVal) + case *int32: + if srcVal < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for int32", srcVal) + } else if srcVal > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for int32", srcVal) + } + *v = int32(srcVal) + case *int64: + if srcVal < math.MinInt64 { + return fmt.Errorf("%d is less than minimum value for int64", srcVal) + } else if srcVal > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for int64", srcVal) + } + *v = int64(srcVal) + case *uint: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint", srcVal) + } else if uint64(srcVal) > uint64(maxUint) { + return fmt.Errorf("%d is greater than maximum value for uint", srcVal) + } + *v = uint(srcVal) + case *uint8: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint8", srcVal) + } else if srcVal > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", srcVal) + } + *v = uint8(srcVal) + case *uint16: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", srcVal) + } + *v = uint16(srcVal) + case *uint32: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint32", srcVal) + } + *v = uint32(srcVal) + case *uint64: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint64", srcVal) + } + *v = uint64(srcVal) + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return int64AssignTo(srcVal, srcStatus, el.Interface()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if el.OverflowInt(int64(srcVal)) { + return fmt.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetInt(int64(srcVal)) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for %T", srcVal, dst) + } + if el.OverflowUint(uint64(srcVal)) { + return fmt.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetUint(uint64(srcVal)) + return nil + } + } + return fmt.Errorf("cannot assign %v into %T", srcVal, dst) + } + return nil + } + + // if dst is a pointer to pointer and srcStatus is not Present, nil it out + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + if el.Kind() == reflect.Ptr { + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) +} diff --git a/pgtype/date.go b/pgtype/date.go index e38b2137..3bbcff4d 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -3,6 +3,7 @@ package pgtype import ( "fmt" "io" + "reflect" "time" "github.com/jackc/pgx/pgio" @@ -36,6 +37,35 @@ func (d *Date) ConvertFrom(src interface{}) error { } func (d *Date) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *time.Time: + if d.Status != Present { + return fmt.Errorf("cannot assign %v to %T", d, dst) + } + *v = d.Time + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if d.Status == Null { + if !el.IsNil() { + // if the destination pointer is not nil, nil it out + el.Set(reflect.Zero(el.Type())) + } + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return d.AssignTo(el.Interface()) + } + } + return fmt.Errorf("cannot decode %v into %T", d, dst) + } + return nil } diff --git a/pgtype/int2.go b/pgtype/int2.go index 636ea1f1..2da8a96d 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -85,7 +85,7 @@ func (i *Int2) ConvertFrom(src interface{}) error { } func (i *Int2) AssignTo(dst interface{}) error { - return nil + return int64AssignTo(int64(i.Int), i.Status, dst) } func (i *Int2) DecodeText(r io.Reader) error { diff --git a/pgtype/int2array.go b/pgtype/int2array.go index 43bbccbd..86375516 100644 --- a/pgtype/int2array.go +++ b/pgtype/int2array.go @@ -65,6 +65,33 @@ func (a *Int2Array) ConvertFrom(src interface{}) error { } func (a *Int2Array) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *[]int16: + if a.Status == Present { + *v = make([]int16, len(a.Elements)) + for i := range a.Elements { + if err := a.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + case *[]uint16: + if a.Status == Present { + *v = make([]uint16, len(a.Elements)) + for i := range a.Elements { + if err := a.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + default: + return fmt.Errorf("cannot put decode %v into %T", a, dst) + } + return nil } diff --git a/pgtype/int4.go b/pgtype/int4.go index 7f797e0f..84c45522 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -76,7 +76,7 @@ func (i *Int4) ConvertFrom(src interface{}) error { } func (i *Int4) AssignTo(dst interface{}) error { - return nil + return int64AssignTo(int64(i.Int), i.Status, dst) } func (i *Int4) DecodeText(r io.Reader) error { diff --git a/pgtype/int8.go b/pgtype/int8.go index 5cabb163..c0e14e44 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -67,7 +67,7 @@ func (i *Int8) ConvertFrom(src interface{}) error { } func (i *Int8) AssignTo(dst interface{}) error { - return nil + return int64AssignTo(int64(i.Int), i.Status, dst) } func (i *Int8) DecodeText(r io.Reader) error { diff --git a/query.go b/query.go index db99cddd..4af1de10 100644 --- a/query.go +++ b/query.go @@ -4,9 +4,10 @@ import ( "database/sql" "errors" "fmt" - "golang.org/x/net/context" "time" + "golang.org/x/net/context" + "github.com/jackc/pgx/pgtype" ) @@ -288,8 +289,39 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { d2 := d decodeJSONB(vr, &d2) } else { - if err := Decode(vr, d); err != nil { - rows.Fatal(scanArgError{col: i, err: err}) + if pgVal, present := rows.conn.oidPgtypeValues[vr.Type().DataType]; present { + switch vr.Type().FormatCode { + case TextFormatCode: + if textDecoder, ok := pgVal.(pgtype.TextDecoder); ok { + vr.err = errRewoundLen + err = textDecoder.DecodeText(&valueReader2{vr}) + if err != nil { + vr.Fatal(err) + } + } else { + vr.Fatal(fmt.Errorf("%T is not a pgtype.TextDecoder", pgVal)) + } + case BinaryFormatCode: + if binaryDecoder, ok := pgVal.(pgtype.BinaryDecoder); ok { + vr.err = errRewoundLen + err = binaryDecoder.DecodeBinary(&valueReader2{vr}) + if err != nil { + vr.Fatal(err) + } + } else { + vr.Fatal(fmt.Errorf("%T is not a pgtype.BinaryDecoder", pgVal)) + } + default: + vr.Fatal(fmt.Errorf("unknown format code: %v", vr.Type().FormatCode)) + } + + if err := pgVal.AssignTo(d); err != nil { + vr.Fatal(err) + } + } else { + if err := Decode(vr, d); err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } } } if vr.Err() != nil { diff --git a/query_test.go b/query_test.go index 1bec4f37..fd5d2e5b 100644 --- a/query_test.go +++ b/query_test.go @@ -111,7 +111,7 @@ func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T) var s string err := conn.QueryRow("select 1").Scan(&s) - if err == nil || !strings.Contains(err.Error(), "cannot decode binary value into string") { + if err == nil || !(strings.Contains(err.Error(), "cannot decode binary value into string") || strings.Contains(err.Error(), "cannot assign")) { t.Fatalf("Expected Scan to fail to encode binary value into string but: %v", err) } @@ -200,7 +200,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) { t.Fatal("Expected Rows to have an error after an improper read but it didn't") } - if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" { + if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" && !strings.Contains(rows.Err().Error(), "cannot assign") { t.Fatalf("Expected different Rows.Err(): %v", rows.Err()) } @@ -542,7 +542,7 @@ func TestQueryRowCoreTypes(t *testing.T) { if err == nil { t.Errorf("%d. Expected null to cause error, but it didn't (sql -> %v)", i, tt.sql) } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + if err != nil && !strings.Contains(err.Error(), "Cannot decode null") && !strings.Contains(err.Error(), "cannot assign") { t.Errorf(`%d. Expected null to cause error "Cannot decode null..." but it was %v (sql -> %v)`, i, err, tt.sql) } @@ -1018,7 +1018,7 @@ func TestQueryRowCoreInt16Slice(t *testing.T) { if err == nil { t.Error("Expected null to cause error when scanned into slice, but it didn't") } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + if err != nil && !(strings.Contains(err.Error(), "Cannot decode null") || strings.Contains(err.Error(), "cannot assign")) { t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) }