Add shopspring.Numeric

This adds PostgreSQL numeric mapping to and from
github.com/shopspring/decimal.

Makes pgtype.NullAssignTo public as external types need this functionality.

Begin extraction of pgtype testing functionality so it can easily be used by
external types.
batch-wip
Jack Christensen 2017-04-14 12:18:49 -05:00
parent fe7d9d3462
commit e4451b47b2
37 changed files with 934 additions and 33 deletions

View File

@ -67,7 +67,7 @@ func (src *Aclitem) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -78,7 +78,7 @@ func (src *AclitemArray) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -56,7 +56,7 @@ func (src *Bool) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -79,7 +79,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -61,7 +61,7 @@ func (src *Bytea) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -79,7 +79,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -108,7 +108,7 @@ func (src *CidrArray) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -342,7 +342,7 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error {
return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst)
} }
func nullAssignTo(dst interface{}) error { func NullAssignTo(dst interface{}) error {
dstPtr := reflect.ValueOf(dst) dstPtr := reflect.ValueOf(dst)
// AssignTo dst must always be a pointer // AssignTo dst must always be a pointer

View File

@ -70,7 +70,7 @@ func (src *Date) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -80,7 +80,7 @@ func (src *DateArray) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -0,0 +1,320 @@
package numeric
import (
"bytes"
"database/sql/driver"
"errors"
"fmt"
"io"
"strconv"
"github.com/jackc/pgx/pgtype"
"github.com/shopspring/decimal"
)
var errUndefined = errors.New("cannot encode status undefined")
type Numeric struct {
Decimal decimal.Decimal
Status pgtype.Status
}
func (dst *Numeric) Set(src interface{}) error {
if src == nil {
*dst = Numeric{Status: pgtype.Null}
return nil
}
switch value := src.(type) {
case decimal.Decimal:
*dst = Numeric{Decimal: value, Status: pgtype.Present}
case float32:
*dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present}
case float64:
*dst = Numeric{Decimal: decimal.NewFromFloat(value), Status: pgtype.Present}
case int8:
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
case uint8:
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
case int16:
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
case uint16:
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
case int32:
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
case uint32:
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
case int64:
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
case uint64:
// uint64 could be greater than int64 so convert to string then to decimal
dec, err := decimal.NewFromString(strconv.FormatUint(value, 10))
if err != nil {
return err
}
*dst = Numeric{Decimal: dec, Status: pgtype.Present}
case int:
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
case uint:
// uint could be greater than int64 so convert to string then to decimal
dec, err := decimal.NewFromString(strconv.FormatUint(uint64(value), 10))
if err != nil {
return err
}
*dst = Numeric{Decimal: dec, Status: pgtype.Present}
case string:
dec, err := decimal.NewFromString(value)
if err != nil {
return err
}
*dst = Numeric{Decimal: dec, Status: pgtype.Present}
default:
// If all else fails see if pgtype.Numeric can handle it. If so, translate through that.
num := &pgtype.Numeric{}
if err := num.Set(value); err != nil {
return fmt.Errorf("cannot convert %v to Numeric", value)
}
buf := &bytes.Buffer{}
if _, err := num.EncodeText(nil, buf); err != nil {
return fmt.Errorf("cannot convert %v to Numeric", value)
}
dec, err := decimal.NewFromString(buf.String())
if err != nil {
return fmt.Errorf("cannot convert %v to Numeric", value)
}
*dst = Numeric{Decimal: dec, Status: pgtype.Present}
}
return nil
}
func (dst *Numeric) Get() interface{} {
switch dst.Status {
case pgtype.Present:
return dst.Decimal
case pgtype.Null:
return nil
default:
return dst.Status
}
}
func (src *Numeric) AssignTo(dst interface{}) error {
switch src.Status {
case pgtype.Present:
switch v := dst.(type) {
case *decimal.Decimal:
*v = src.Decimal
case *float32:
f, _ := src.Decimal.Float64()
*v = float32(f)
case *float64:
f, _ := src.Decimal.Float64()
*v = f
case *int:
if src.Decimal.Exponent() < 0 {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
n, err := strconv.ParseInt(src.Decimal.String(), 10, strconv.IntSize)
if err != nil {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
*v = int(n)
case *int8:
if src.Decimal.Exponent() < 0 {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
n, err := strconv.ParseInt(src.Decimal.String(), 10, 8)
if err != nil {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
*v = int8(n)
case *int16:
if src.Decimal.Exponent() < 0 {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
n, err := strconv.ParseInt(src.Decimal.String(), 10, 16)
if err != nil {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
*v = int16(n)
case *int32:
if src.Decimal.Exponent() < 0 {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
n, err := strconv.ParseInt(src.Decimal.String(), 10, 32)
if err != nil {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
*v = int32(n)
case *int64:
if src.Decimal.Exponent() < 0 {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
n, err := strconv.ParseInt(src.Decimal.String(), 10, 64)
if err != nil {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
*v = int64(n)
case *uint:
if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
n, err := strconv.ParseUint(src.Decimal.String(), 10, strconv.IntSize)
if err != nil {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
*v = uint(n)
case *uint8:
if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
n, err := strconv.ParseUint(src.Decimal.String(), 10, 8)
if err != nil {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
*v = uint8(n)
case *uint16:
if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
n, err := strconv.ParseUint(src.Decimal.String(), 10, 16)
if err != nil {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
*v = uint16(n)
case *uint32:
if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
n, err := strconv.ParseUint(src.Decimal.String(), 10, 32)
if err != nil {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
*v = uint32(n)
case *uint64:
if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
n, err := strconv.ParseUint(src.Decimal.String(), 10, 64)
if err != nil {
return fmt.Errorf("cannot convert %v to %T", dst, *v)
}
*v = uint64(n)
default:
if nextDst, retry := pgtype.GetAssignToDstType(dst); retry {
return src.AssignTo(nextDst)
}
}
case pgtype.Null:
return pgtype.NullAssignTo(dst)
}
return nil
}
func (dst *Numeric) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
if src == nil {
*dst = Numeric{Status: pgtype.Null}
return nil
}
dec, err := decimal.NewFromString(string(src))
if err != nil {
return err
}
*dst = Numeric{Decimal: dec, Status: pgtype.Present}
return nil
}
func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
if src == nil {
*dst = Numeric{Status: pgtype.Null}
return nil
}
// For now at least, implement this in terms of pgtype.Numeric
num := &pgtype.Numeric{}
if err := num.DecodeBinary(ci, src); err != nil {
return err
}
buf := &bytes.Buffer{}
if _, err := num.EncodeText(ci, buf); err != nil {
return err
}
dec, err := decimal.NewFromString(buf.String())
if err != nil {
return err
}
*dst = Numeric{Decimal: dec, Status: pgtype.Present}
return nil
}
func (src *Numeric) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) {
switch src.Status {
case pgtype.Null:
return true, nil
case pgtype.Undefined:
return false, errUndefined
}
_, err := io.WriteString(w, src.Decimal.String())
return false, err
}
func (src *Numeric) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) {
switch src.Status {
case pgtype.Null:
return true, nil
case pgtype.Undefined:
return false, errUndefined
}
// For now at least, implement this in terms of pgtype.Numeric
num := &pgtype.Numeric{}
if err := num.DecodeText(ci, []byte(src.Decimal.String())); err != nil {
return false, err
}
return num.EncodeBinary(ci, w)
}
// Scan implements the database/sql Scanner interface.
func (dst *Numeric) Scan(src interface{}) error {
if src == nil {
*dst = Numeric{Status: pgtype.Null}
return nil
}
switch src := src.(type) {
case float64:
*dst = Numeric{Decimal: decimal.NewFromFloat(src), Status: pgtype.Present}
return nil
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *Numeric) Value() (driver.Value, error) {
switch src.Status {
case pgtype.Present:
return src.Decimal.Value()
case pgtype.Null:
return nil, nil
default:
return nil, errUndefined
}
}

View File

@ -0,0 +1,281 @@
package numeric_test
import (
"fmt"
"math/big"
"math/rand"
"reflect"
"testing"
"github.com/jackc/pgx/pgtype"
shopspring "github.com/jackc/pgx/pgtype/ext/shopspring-numeric"
"github.com/jackc/pgx/pgtype/testutil"
"github.com/shopspring/decimal"
)
func mustParseDecimal(t *testing.T, src string) decimal.Decimal {
dec, err := decimal.NewFromString(src)
if err != nil {
t.Fatal(err)
}
return dec
}
func TestNumericNormalize(t *testing.T) {
testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{
{
SQL: "select '0'::numeric",
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present},
},
{
SQL: "select '1'::numeric",
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present},
},
{
SQL: "select '10.00'::numeric",
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Status: pgtype.Present},
},
{
SQL: "select '1e-3'::numeric",
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present},
},
{
SQL: "select '-1'::numeric",
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present},
},
{
SQL: "select '10000'::numeric",
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Status: pgtype.Present},
},
{
SQL: "select '3.14'::numeric",
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present},
},
{
SQL: "select '1.1'::numeric",
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Status: pgtype.Present},
},
{
SQL: "select '100010001'::numeric",
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Status: pgtype.Present},
},
{
SQL: "select '100010001.0001'::numeric",
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Status: pgtype.Present},
},
{
SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric",
Value: shopspring.Numeric{
Decimal: mustParseDecimal(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"),
Status: pgtype.Present,
},
},
{
SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric",
Value: shopspring.Numeric{
Decimal: mustParseDecimal(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"),
Status: pgtype.Present,
},
},
{
SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric",
Value: shopspring.Numeric{
Decimal: mustParseDecimal(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"),
Status: pgtype.Present,
},
},
})
}
func TestNumericTranscode(t *testing.T) {
testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present},
&shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present},
&shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present},
&shopspring.Numeric{Decimal: mustParseDecimal(t, "100000"), Status: pgtype.Present},
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.1"), Status: pgtype.Present},
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.01"), Status: pgtype.Present},
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present},
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0001"), Status: pgtype.Present},
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00001"), Status: pgtype.Present},
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000001"), Status: pgtype.Present},
&shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present},
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000123"), Status: pgtype.Present},
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000000123"), Status: pgtype.Present},
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0000000123"), Status: pgtype.Present},
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000123"), Status: pgtype.Present},
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001234567890123456789"), Status: pgtype.Present},
&shopspring.Numeric{Decimal: mustParseDecimal(t, "4309132809320932980457137401234890237489238912983572189348951289375283573984571892758234678903467889512893489128589347891272139.8489235871258912789347891235879148795891238915678189467128957812395781238579189025891238901583915890128973578957912385798125789012378905238905471598123758923478294374327894237892234"), Status: pgtype.Present},
&shopspring.Numeric{Status: pgtype.Null},
}, func(aa, bb interface{}) bool {
a := aa.(shopspring.Numeric)
b := bb.(shopspring.Numeric)
return a.Status == b.Status && a.Decimal.Equal(b.Decimal)
})
}
func TestNumericTranscodeFuzz(t *testing.T) {
r := rand.New(rand.NewSource(0))
max := &big.Int{}
max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10)
values := make([]interface{}, 0, 2000)
for i := 0; i < 500; i++ {
num := fmt.Sprintf("%s.%s", (&big.Int{}).Rand(r, max).String(), (&big.Int{}).Rand(r, max).String())
negNum := "-" + num
values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, num), Status: pgtype.Present})
values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, negNum), Status: pgtype.Present})
}
testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values,
func(aa, bb interface{}) bool {
a := aa.(shopspring.Numeric)
b := bb.(shopspring.Numeric)
return a.Status == b.Status && a.Decimal.Equal(b.Decimal)
})
}
func TestNumericSet(t *testing.T) {
type _int8 int8
successfulTests := []struct {
source interface{}
result *shopspring.Numeric
}{
{source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: int16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: int32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: int64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: int8(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}},
{source: int16(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}},
{source: int32(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}},
{source: int64(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}},
{source: uint8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: uint16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: uint32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: uint64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: "1", result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: _int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: float64(1000), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1000"), Status: pgtype.Present}},
{source: float64(1234), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1234"), Status: pgtype.Present}},
{source: float64(12345678900), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345678900"), Status: pgtype.Present}},
{source: float64(12345.678901), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345.678901"), Status: pgtype.Present}},
}
for i, tt := range successfulTests {
r := &shopspring.Numeric{}
err := r.Set(tt.source)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if !(r.Status == tt.result.Status && r.Decimal.Equal(tt.result.Decimal)) {
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
}
}
}
func TestNumericAssignTo(t *testing.T) {
type _int8 int8
var i8 int8
var i16 int16
var i32 int32
var i64 int64
var i int
var ui8 uint8
var ui16 uint16
var ui32 uint32
var ui64 uint64
var ui uint
var pi8 *int8
var _i8 _int8
var _pi8 *_int8
var f32 float32
var f64 float64
var pf32 *float32
var pf64 *float64
simpleTests := []struct {
src *shopspring.Numeric
dst interface{}
expected interface{}
}{
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f32, expected: float32(42)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f64, expected: float64(42)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f32, expected: float32(4.2)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f64, expected: float64(4.2)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i16, expected: int16(42)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i32, expected: int32(42)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i64, expected: int64(42)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: &i64, expected: int64(42000)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i, expected: int(42)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui, expected: uint(42)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)},
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))},
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))},
}
for i, tt := range simpleTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
pointerAllocTests := []struct {
src *shopspring.Numeric
dst interface{}
expected interface{}
}{
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf32, expected: float32(42)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf64, expected: float64(42)},
}
for i, tt := range pointerAllocTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
errorTests := []struct {
src *shopspring.Numeric
dst interface{}
}{
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "150"), Status: pgtype.Present}, dst: &i8},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "40000"), Status: pgtype.Present}, dst: &i16},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui8},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui16},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui32},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui64},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui},
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &i32},
}
for i, tt := range errorTests {
err := tt.src.AssignTo(tt.dst)
if err == nil {
t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst)
}
}
}

View File

@ -79,7 +79,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -79,7 +79,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -71,7 +71,7 @@ func (src *Hstore) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -79,7 +79,7 @@ func (src *HstoreArray) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -90,7 +90,7 @@ func (src *Inet) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -108,7 +108,7 @@ func (src *InetArray) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -107,7 +107,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -107,7 +107,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -107,7 +107,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -71,7 +71,7 @@ func (src *Interval) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -67,7 +67,7 @@ func (src *Macaddr) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -253,7 +253,7 @@ func (src *Numeric) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return nil return nil

View File

@ -107,7 +107,7 @@ func (src *NumericArray) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -62,7 +62,7 @@ func (src *Record) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

298
pgtype/testutil/testutil.go Normal file
View File

@ -0,0 +1,298 @@
package testutil
import (
"context"
"database/sql"
"fmt"
"io"
"os"
"reflect"
"testing"
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgtype"
_ "github.com/jackc/pgx/stdlib"
_ "github.com/lib/pq"
)
func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB {
var sqlDriverName string
switch driverName {
case "github.com/lib/pq":
sqlDriverName = "postgres"
case "github.com/jackc/pgx/stdlib":
sqlDriverName = "pgx"
default:
t.Fatalf("Unknown driver %v", driverName)
}
db, err := sql.Open(sqlDriverName, os.Getenv("DATABASE_URL"))
if err != nil {
t.Fatal(err)
}
return db
}
func mustConnectPgx(t testing.TB) *pgx.Conn {
config, err := pgx.ParseURI(os.Getenv("DATABASE_URL"))
if err != nil {
t.Fatal(err)
}
conn, err := pgx.Connect(config)
if err != nil {
t.Fatal(err)
}
return conn
}
func mustClose(t testing.TB, conn interface {
Close() error
}) {
err := conn.Close()
if err != nil {
t.Fatal(err)
}
}
type forceTextEncoder struct {
e pgtype.TextEncoder
}
func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) {
return f.e.EncodeText(ci, w)
}
type forceBinaryEncoder struct {
e pgtype.BinaryEncoder
}
func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) {
return f.e.EncodeBinary(ci, w)
}
func forceEncoder(e interface{}, formatCode int16) interface{} {
switch formatCode {
case pgx.TextFormatCode:
if e, ok := e.(pgtype.TextEncoder); ok {
return forceTextEncoder{e: e}
}
case pgx.BinaryFormatCode:
if e, ok := e.(pgtype.BinaryEncoder); ok {
return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)}
}
}
return nil
}
func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) {
TestSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool {
return reflect.DeepEqual(a, b)
})
}
func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
}
}
func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
conn := mustConnectPgx(t)
defer mustClose(t, conn)
ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName))
if err != nil {
t.Fatal(err)
}
formats := []struct {
name string
formatCode int16
}{
{name: "TextFormat", formatCode: pgx.TextFormatCode},
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
}
for i, v := range values {
for _, fc := range formats {
ps.FieldDescriptions[0].FormatCode = fc.formatCode
vEncoder := forceEncoder(v, fc.formatCode)
if vEncoder == nil {
t.Logf("Skipping: %#v does not implement %v", v, fc.name)
continue
}
// Derefence value if it is a pointer
derefV := v
refVal := reflect.ValueOf(v)
if refVal.Kind() == reflect.Ptr {
derefV = refVal.Elem().Interface()
}
result := reflect.New(reflect.TypeOf(derefV))
err := conn.QueryRow("test", forceEncoder(v, fc.formatCode)).Scan(result.Interface())
if err != nil {
t.Errorf("%v %d: %v", fc.name, i, err)
}
if !eqFunc(result.Elem().Interface(), derefV) {
t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface())
}
}
}
}
func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
conn := mustConnectPgx(t)
defer mustClose(t, conn)
for i, v := range values {
// Derefence value if it is a pointer
derefV := v
refVal := reflect.ValueOf(v)
if refVal.Kind() == reflect.Ptr {
derefV = refVal.Elem().Interface()
}
result := reflect.New(reflect.TypeOf(derefV))
err := conn.QueryRowEx(
context.Background(),
fmt.Sprintf("select ($1)::%s", pgTypeName),
&pgx.QueryExOptions{SimpleProtocol: true},
v,
).Scan(result.Interface())
if err != nil {
t.Errorf("Simple protocol %d: %v", i, err)
}
if !eqFunc(result.Elem().Interface(), derefV) {
t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface())
}
}
}
func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
conn := mustConnectDatabaseSQL(t, driverName)
defer mustClose(t, conn)
ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName))
if err != nil {
t.Fatal(err)
}
for i, v := range values {
// Derefence value if it is a pointer
derefV := v
refVal := reflect.ValueOf(v)
if refVal.Kind() == reflect.Ptr {
derefV = refVal.Elem().Interface()
}
result := reflect.New(reflect.TypeOf(derefV))
err := ps.QueryRow(v).Scan(result.Interface())
if err != nil {
t.Errorf("%v %d: %v", driverName, i, err)
}
if !eqFunc(result.Elem().Interface(), derefV) {
t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
}
}
}
type NormalizeTest struct {
SQL string
Value interface{}
}
func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) {
TestSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool {
return reflect.DeepEqual(a, b)
})
}
func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc)
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
TestDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc)
}
}
func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
conn := mustConnectPgx(t)
defer mustClose(t, conn)
formats := []struct {
name string
formatCode int16
}{
{name: "TextFormat", formatCode: pgx.TextFormatCode},
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
}
for i, tt := range tests {
for _, fc := range formats {
psName := fmt.Sprintf("test%d", i)
ps, err := conn.Prepare(psName, tt.SQL)
if err != nil {
t.Fatal(err)
}
ps.FieldDescriptions[0].FormatCode = fc.formatCode
if forceEncoder(tt.Value, fc.formatCode) == nil {
t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name)
continue
}
// Derefence value if it is a pointer
derefV := tt.Value
refVal := reflect.ValueOf(tt.Value)
if refVal.Kind() == reflect.Ptr {
derefV = refVal.Elem().Interface()
}
result := reflect.New(reflect.TypeOf(derefV))
err = conn.QueryRow(psName).Scan(result.Interface())
if err != nil {
t.Errorf("%v %d: %v", fc.name, i, err)
}
if !eqFunc(result.Elem().Interface(), derefV) {
t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface())
}
}
}
}
func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
conn := mustConnectDatabaseSQL(t, driverName)
defer mustClose(t, conn)
for i, tt := range tests {
ps, err := conn.Prepare(tt.SQL)
if err != nil {
t.Errorf("%d. %v", i, err)
continue
}
// Derefence value if it is a pointer
derefV := tt.Value
refVal := reflect.ValueOf(tt.Value)
if refVal.Kind() == reflect.Ptr {
derefV = refVal.Elem().Interface()
}
result := reflect.New(reflect.TypeOf(derefV))
err = ps.QueryRow().Scan(result.Interface())
if err != nil {
t.Errorf("%v %d: %v", driverName, i, err)
}
if !eqFunc(result.Elem().Interface(), derefV) {
t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
}
}
}

View File

@ -71,7 +71,7 @@ func (src *Text) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -79,7 +79,7 @@ func (src *TextArray) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -74,7 +74,7 @@ func (src *Timestamp) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -80,7 +80,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -75,7 +75,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -80,7 +80,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -77,7 +77,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

View File

@ -69,7 +69,7 @@ func (src *Uuid) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot assign %v into %T", src, dst) return fmt.Errorf("cannot assign %v into %T", src, dst)

View File

@ -79,7 +79,7 @@ func (src *VarcharArray) AssignTo(dst interface{}) error {
} }
} }
case Null: case Null:
return nullAssignTo(dst) return NullAssignTo(dst)
} }
return fmt.Errorf("cannot decode %v into %T", src, dst) return fmt.Errorf("cannot decode %v into %T", src, dst)

2
v3.md
View File

@ -66,3 +66,5 @@ Keep ability to change logging while running
consider test to ensure that AssignTo makes copy of reference types consider test to ensure that AssignTo makes copy of reference types
something like: something like:
select array[1,2,3], array[4,5,6,7] select array[1,2,3], array[4,5,6,7]
Reconsider synonym types like varchar/text and numeric/decimal.