mirror of https://github.com/jackc/pgx.git
Add shopspring.Numeric
This adds PostgreSQL numeric mapping to and from github.com/shopspring/decimal. Makes pgtype.NullAssignTo public as external types need this functionality. Begin extraction of pgtype testing functionality so it can easily be used by external types.batch-wip
parent
fe7d9d3462
commit
e4451b47b2
|
@ -67,7 +67,7 @@ func (src *Aclitem) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -78,7 +78,7 @@ func (src *AclitemArray) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -56,7 +56,7 @@ func (src *Bool) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -79,7 +79,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -61,7 +61,7 @@ func (src *Bytea) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -79,7 +79,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -108,7 +108,7 @@ func (src *CidrArray) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -342,7 +342,7 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error {
|
|||
return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst)
|
||||
}
|
||||
|
||||
func nullAssignTo(dst interface{}) error {
|
||||
func NullAssignTo(dst interface{}) error {
|
||||
dstPtr := reflect.ValueOf(dst)
|
||||
|
||||
// AssignTo dst must always be a pointer
|
||||
|
|
|
@ -70,7 +70,7 @@ func (src *Date) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -80,7 +80,7 @@ func (src *DateArray) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -0,0 +1,320 @@
|
|||
package numeric
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
var errUndefined = errors.New("cannot encode status undefined")
|
||||
|
||||
type Numeric struct {
|
||||
Decimal decimal.Decimal
|
||||
Status pgtype.Status
|
||||
}
|
||||
|
||||
func (dst *Numeric) Set(src interface{}) error {
|
||||
if src == nil {
|
||||
*dst = Numeric{Status: pgtype.Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch value := src.(type) {
|
||||
case decimal.Decimal:
|
||||
*dst = Numeric{Decimal: value, Status: pgtype.Present}
|
||||
case float32:
|
||||
*dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present}
|
||||
case float64:
|
||||
*dst = Numeric{Decimal: decimal.NewFromFloat(value), Status: pgtype.Present}
|
||||
case int8:
|
||||
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
|
||||
case uint8:
|
||||
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
|
||||
case int16:
|
||||
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
|
||||
case uint16:
|
||||
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
|
||||
case int32:
|
||||
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
|
||||
case uint32:
|
||||
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
|
||||
case int64:
|
||||
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
|
||||
case uint64:
|
||||
// uint64 could be greater than int64 so convert to string then to decimal
|
||||
dec, err := decimal.NewFromString(strconv.FormatUint(value, 10))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*dst = Numeric{Decimal: dec, Status: pgtype.Present}
|
||||
case int:
|
||||
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
|
||||
case uint:
|
||||
// uint could be greater than int64 so convert to string then to decimal
|
||||
dec, err := decimal.NewFromString(strconv.FormatUint(uint64(value), 10))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*dst = Numeric{Decimal: dec, Status: pgtype.Present}
|
||||
case string:
|
||||
dec, err := decimal.NewFromString(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*dst = Numeric{Decimal: dec, Status: pgtype.Present}
|
||||
default:
|
||||
// If all else fails see if pgtype.Numeric can handle it. If so, translate through that.
|
||||
num := &pgtype.Numeric{}
|
||||
if err := num.Set(value); err != nil {
|
||||
return fmt.Errorf("cannot convert %v to Numeric", value)
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
if _, err := num.EncodeText(nil, buf); err != nil {
|
||||
return fmt.Errorf("cannot convert %v to Numeric", value)
|
||||
}
|
||||
|
||||
dec, err := decimal.NewFromString(buf.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot convert %v to Numeric", value)
|
||||
}
|
||||
*dst = Numeric{Decimal: dec, Status: pgtype.Present}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *Numeric) Get() interface{} {
|
||||
switch dst.Status {
|
||||
case pgtype.Present:
|
||||
return dst.Decimal
|
||||
case pgtype.Null:
|
||||
return nil
|
||||
default:
|
||||
return dst.Status
|
||||
}
|
||||
}
|
||||
|
||||
func (src *Numeric) AssignTo(dst interface{}) error {
|
||||
switch src.Status {
|
||||
case pgtype.Present:
|
||||
switch v := dst.(type) {
|
||||
case *decimal.Decimal:
|
||||
*v = src.Decimal
|
||||
case *float32:
|
||||
f, _ := src.Decimal.Float64()
|
||||
*v = float32(f)
|
||||
case *float64:
|
||||
f, _ := src.Decimal.Float64()
|
||||
*v = f
|
||||
case *int:
|
||||
if src.Decimal.Exponent() < 0 {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
n, err := strconv.ParseInt(src.Decimal.String(), 10, strconv.IntSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
*v = int(n)
|
||||
case *int8:
|
||||
if src.Decimal.Exponent() < 0 {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
n, err := strconv.ParseInt(src.Decimal.String(), 10, 8)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
*v = int8(n)
|
||||
case *int16:
|
||||
if src.Decimal.Exponent() < 0 {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
n, err := strconv.ParseInt(src.Decimal.String(), 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
*v = int16(n)
|
||||
case *int32:
|
||||
if src.Decimal.Exponent() < 0 {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
n, err := strconv.ParseInt(src.Decimal.String(), 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
*v = int32(n)
|
||||
case *int64:
|
||||
if src.Decimal.Exponent() < 0 {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
n, err := strconv.ParseInt(src.Decimal.String(), 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
*v = int64(n)
|
||||
case *uint:
|
||||
if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
n, err := strconv.ParseUint(src.Decimal.String(), 10, strconv.IntSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
*v = uint(n)
|
||||
case *uint8:
|
||||
if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
n, err := strconv.ParseUint(src.Decimal.String(), 10, 8)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
*v = uint8(n)
|
||||
case *uint16:
|
||||
if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
n, err := strconv.ParseUint(src.Decimal.String(), 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
*v = uint16(n)
|
||||
case *uint32:
|
||||
if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
n, err := strconv.ParseUint(src.Decimal.String(), 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
*v = uint32(n)
|
||||
case *uint64:
|
||||
if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
n, err := strconv.ParseUint(src.Decimal.String(), 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
||||
}
|
||||
*v = uint64(n)
|
||||
default:
|
||||
if nextDst, retry := pgtype.GetAssignToDstType(dst); retry {
|
||||
return src.AssignTo(nextDst)
|
||||
}
|
||||
}
|
||||
case pgtype.Null:
|
||||
return pgtype.NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *Numeric) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Numeric{Status: pgtype.Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
dec, err := decimal.NewFromString(string(src))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*dst = Numeric{Decimal: dec, Status: pgtype.Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Numeric{Status: pgtype.Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
// For now at least, implement this in terms of pgtype.Numeric
|
||||
|
||||
num := &pgtype.Numeric{}
|
||||
if err := num.DecodeBinary(ci, src); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
if _, err := num.EncodeText(ci, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dec, err := decimal.NewFromString(buf.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*dst = Numeric{Decimal: dec, Status: pgtype.Present}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Numeric) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) {
|
||||
switch src.Status {
|
||||
case pgtype.Null:
|
||||
return true, nil
|
||||
case pgtype.Undefined:
|
||||
return false, errUndefined
|
||||
}
|
||||
|
||||
_, err := io.WriteString(w, src.Decimal.String())
|
||||
return false, err
|
||||
}
|
||||
|
||||
func (src *Numeric) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) {
|
||||
switch src.Status {
|
||||
case pgtype.Null:
|
||||
return true, nil
|
||||
case pgtype.Undefined:
|
||||
return false, errUndefined
|
||||
}
|
||||
|
||||
// For now at least, implement this in terms of pgtype.Numeric
|
||||
num := &pgtype.Numeric{}
|
||||
if err := num.DecodeText(ci, []byte(src.Decimal.String())); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return num.EncodeBinary(ci, w)
|
||||
}
|
||||
|
||||
// Scan implements the database/sql Scanner interface.
|
||||
func (dst *Numeric) Scan(src interface{}) error {
|
||||
if src == nil {
|
||||
*dst = Numeric{Status: pgtype.Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch src := src.(type) {
|
||||
case float64:
|
||||
*dst = Numeric{Decimal: decimal.NewFromFloat(src), Status: pgtype.Present}
|
||||
return nil
|
||||
case string:
|
||||
return dst.DecodeText(nil, []byte(src))
|
||||
case []byte:
|
||||
return dst.DecodeText(nil, src)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot scan %T", src)
|
||||
}
|
||||
|
||||
// Value implements the database/sql/driver Valuer interface.
|
||||
func (src *Numeric) Value() (driver.Value, error) {
|
||||
switch src.Status {
|
||||
case pgtype.Present:
|
||||
return src.Decimal.Value()
|
||||
case pgtype.Null:
|
||||
return nil, nil
|
||||
default:
|
||||
return nil, errUndefined
|
||||
}
|
||||
}
|
|
@ -0,0 +1,281 @@
|
|||
package numeric_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
shopspring "github.com/jackc/pgx/pgtype/ext/shopspring-numeric"
|
||||
"github.com/jackc/pgx/pgtype/testutil"
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
func mustParseDecimal(t *testing.T, src string) decimal.Decimal {
|
||||
dec, err := decimal.NewFromString(src)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return dec
|
||||
}
|
||||
|
||||
func TestNumericNormalize(t *testing.T) {
|
||||
testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{
|
||||
{
|
||||
SQL: "select '0'::numeric",
|
||||
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
SQL: "select '1'::numeric",
|
||||
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
SQL: "select '10.00'::numeric",
|
||||
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
SQL: "select '1e-3'::numeric",
|
||||
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
SQL: "select '-1'::numeric",
|
||||
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
SQL: "select '10000'::numeric",
|
||||
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
SQL: "select '3.14'::numeric",
|
||||
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
SQL: "select '1.1'::numeric",
|
||||
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
SQL: "select '100010001'::numeric",
|
||||
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
SQL: "select '100010001.0001'::numeric",
|
||||
Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric",
|
||||
Value: shopspring.Numeric{
|
||||
Decimal: mustParseDecimal(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"),
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
},
|
||||
{
|
||||
SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric",
|
||||
Value: shopspring.Numeric{
|
||||
Decimal: mustParseDecimal(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"),
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
},
|
||||
{
|
||||
SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric",
|
||||
Value: shopspring.Numeric{
|
||||
Decimal: mustParseDecimal(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"),
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestNumericTranscode(t *testing.T) {
|
||||
testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{
|
||||
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present},
|
||||
&shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present},
|
||||
&shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present},
|
||||
&shopspring.Numeric{Decimal: mustParseDecimal(t, "100000"), Status: pgtype.Present},
|
||||
|
||||
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.1"), Status: pgtype.Present},
|
||||
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.01"), Status: pgtype.Present},
|
||||
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present},
|
||||
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0001"), Status: pgtype.Present},
|
||||
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00001"), Status: pgtype.Present},
|
||||
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000001"), Status: pgtype.Present},
|
||||
|
||||
&shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present},
|
||||
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000123"), Status: pgtype.Present},
|
||||
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000000123"), Status: pgtype.Present},
|
||||
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0000000123"), Status: pgtype.Present},
|
||||
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000123"), Status: pgtype.Present},
|
||||
&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001234567890123456789"), Status: pgtype.Present},
|
||||
&shopspring.Numeric{Decimal: mustParseDecimal(t, "4309132809320932980457137401234890237489238912983572189348951289375283573984571892758234678903467889512893489128589347891272139.8489235871258912789347891235879148795891238915678189467128957812395781238579189025891238901583915890128973578957912385798125789012378905238905471598123758923478294374327894237892234"), Status: pgtype.Present},
|
||||
&shopspring.Numeric{Status: pgtype.Null},
|
||||
}, func(aa, bb interface{}) bool {
|
||||
a := aa.(shopspring.Numeric)
|
||||
b := bb.(shopspring.Numeric)
|
||||
|
||||
return a.Status == b.Status && a.Decimal.Equal(b.Decimal)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestNumericTranscodeFuzz(t *testing.T) {
|
||||
r := rand.New(rand.NewSource(0))
|
||||
max := &big.Int{}
|
||||
max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10)
|
||||
|
||||
values := make([]interface{}, 0, 2000)
|
||||
for i := 0; i < 500; i++ {
|
||||
num := fmt.Sprintf("%s.%s", (&big.Int{}).Rand(r, max).String(), (&big.Int{}).Rand(r, max).String())
|
||||
negNum := "-" + num
|
||||
values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, num), Status: pgtype.Present})
|
||||
values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, negNum), Status: pgtype.Present})
|
||||
}
|
||||
|
||||
testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values,
|
||||
func(aa, bb interface{}) bool {
|
||||
a := aa.(shopspring.Numeric)
|
||||
b := bb.(shopspring.Numeric)
|
||||
|
||||
return a.Status == b.Status && a.Decimal.Equal(b.Decimal)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNumericSet(t *testing.T) {
|
||||
type _int8 int8
|
||||
|
||||
successfulTests := []struct {
|
||||
source interface{}
|
||||
result *shopspring.Numeric
|
||||
}{
|
||||
{source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||
{source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||
{source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||
{source: int16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||
{source: int32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||
{source: int64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||
{source: int8(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}},
|
||||
{source: int16(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}},
|
||||
{source: int32(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}},
|
||||
{source: int64(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}},
|
||||
{source: uint8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||
{source: uint16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||
{source: uint32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||
{source: uint64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||
{source: "1", result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||
{source: _int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||
{source: float64(1000), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1000"), Status: pgtype.Present}},
|
||||
{source: float64(1234), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1234"), Status: pgtype.Present}},
|
||||
{source: float64(12345678900), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345678900"), Status: pgtype.Present}},
|
||||
{source: float64(12345.678901), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345.678901"), Status: pgtype.Present}},
|
||||
}
|
||||
|
||||
for i, tt := range successfulTests {
|
||||
r := &shopspring.Numeric{}
|
||||
err := r.Set(tt.source)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if !(r.Status == tt.result.Status && r.Decimal.Equal(tt.result.Decimal)) {
|
||||
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNumericAssignTo(t *testing.T) {
|
||||
type _int8 int8
|
||||
|
||||
var i8 int8
|
||||
var i16 int16
|
||||
var i32 int32
|
||||
var i64 int64
|
||||
var i int
|
||||
var ui8 uint8
|
||||
var ui16 uint16
|
||||
var ui32 uint32
|
||||
var ui64 uint64
|
||||
var ui uint
|
||||
var pi8 *int8
|
||||
var _i8 _int8
|
||||
var _pi8 *_int8
|
||||
var f32 float32
|
||||
var f64 float64
|
||||
var pf32 *float32
|
||||
var pf64 *float64
|
||||
|
||||
simpleTests := []struct {
|
||||
src *shopspring.Numeric
|
||||
dst interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f32, expected: float32(42)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f64, expected: float64(42)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f32, expected: float32(4.2)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f64, expected: float64(4.2)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i16, expected: int16(42)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i32, expected: int32(42)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i64, expected: int64(42)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: &i64, expected: int64(42000)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i, expected: int(42)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui, expected: uint(42)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)},
|
||||
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))},
|
||||
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))},
|
||||
}
|
||||
|
||||
for i, tt := range simpleTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
|
||||
pointerAllocTests := []struct {
|
||||
src *shopspring.Numeric
|
||||
dst interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf32, expected: float32(42)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf64, expected: float64(42)},
|
||||
}
|
||||
|
||||
for i, tt := range pointerAllocTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
|
||||
errorTests := []struct {
|
||||
src *shopspring.Numeric
|
||||
dst interface{}
|
||||
}{
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "150"), Status: pgtype.Present}, dst: &i8},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "40000"), Status: pgtype.Present}, dst: &i16},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui8},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui16},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui32},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui64},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui},
|
||||
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &i32},
|
||||
}
|
||||
|
||||
for i, tt := range errorTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err == nil {
|
||||
t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -79,7 +79,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -79,7 +79,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -71,7 +71,7 @@ func (src *Hstore) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -79,7 +79,7 @@ func (src *HstoreArray) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -90,7 +90,7 @@ func (src *Inet) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -108,7 +108,7 @@ func (src *InetArray) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -107,7 +107,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -107,7 +107,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -107,7 +107,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -71,7 +71,7 @@ func (src *Interval) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -67,7 +67,7 @@ func (src *Macaddr) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -253,7 +253,7 @@ func (src *Numeric) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -107,7 +107,7 @@ func (src *NumericArray) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -62,7 +62,7 @@ func (src *Record) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -0,0 +1,298 @@
|
|||
package testutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
_ "github.com/jackc/pgx/stdlib"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB {
|
||||
var sqlDriverName string
|
||||
switch driverName {
|
||||
case "github.com/lib/pq":
|
||||
sqlDriverName = "postgres"
|
||||
case "github.com/jackc/pgx/stdlib":
|
||||
sqlDriverName = "pgx"
|
||||
default:
|
||||
t.Fatalf("Unknown driver %v", driverName)
|
||||
}
|
||||
|
||||
db, err := sql.Open(sqlDriverName, os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func mustConnectPgx(t testing.TB) *pgx.Conn {
|
||||
config, err := pgx.ParseURI(os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conn, err := pgx.Connect(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
func mustClose(t testing.TB, conn interface {
|
||||
Close() error
|
||||
}) {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
type forceTextEncoder struct {
|
||||
e pgtype.TextEncoder
|
||||
}
|
||||
|
||||
func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) {
|
||||
return f.e.EncodeText(ci, w)
|
||||
}
|
||||
|
||||
type forceBinaryEncoder struct {
|
||||
e pgtype.BinaryEncoder
|
||||
}
|
||||
|
||||
func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) {
|
||||
return f.e.EncodeBinary(ci, w)
|
||||
}
|
||||
|
||||
func forceEncoder(e interface{}, formatCode int16) interface{} {
|
||||
switch formatCode {
|
||||
case pgx.TextFormatCode:
|
||||
if e, ok := e.(pgtype.TextEncoder); ok {
|
||||
return forceTextEncoder{e: e}
|
||||
}
|
||||
case pgx.BinaryFormatCode:
|
||||
if e, ok := e.(pgtype.BinaryEncoder); ok {
|
||||
return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) {
|
||||
TestSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool {
|
||||
return reflect.DeepEqual(a, b)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
||||
TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
|
||||
TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
|
||||
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
||||
TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
||||
conn := mustConnectPgx(t)
|
||||
defer mustClose(t, conn)
|
||||
|
||||
ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
formats := []struct {
|
||||
name string
|
||||
formatCode int16
|
||||
}{
|
||||
{name: "TextFormat", formatCode: pgx.TextFormatCode},
|
||||
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
|
||||
}
|
||||
|
||||
for i, v := range values {
|
||||
for _, fc := range formats {
|
||||
ps.FieldDescriptions[0].FormatCode = fc.formatCode
|
||||
vEncoder := forceEncoder(v, fc.formatCode)
|
||||
if vEncoder == nil {
|
||||
t.Logf("Skipping: %#v does not implement %v", v, fc.name)
|
||||
continue
|
||||
}
|
||||
// Derefence value if it is a pointer
|
||||
derefV := v
|
||||
refVal := reflect.ValueOf(v)
|
||||
if refVal.Kind() == reflect.Ptr {
|
||||
derefV = refVal.Elem().Interface()
|
||||
}
|
||||
|
||||
result := reflect.New(reflect.TypeOf(derefV))
|
||||
err := conn.QueryRow("test", forceEncoder(v, fc.formatCode)).Scan(result.Interface())
|
||||
if err != nil {
|
||||
t.Errorf("%v %d: %v", fc.name, i, err)
|
||||
}
|
||||
|
||||
if !eqFunc(result.Elem().Interface(), derefV) {
|
||||
t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
||||
conn := mustConnectPgx(t)
|
||||
defer mustClose(t, conn)
|
||||
|
||||
for i, v := range values {
|
||||
// Derefence value if it is a pointer
|
||||
derefV := v
|
||||
refVal := reflect.ValueOf(v)
|
||||
if refVal.Kind() == reflect.Ptr {
|
||||
derefV = refVal.Elem().Interface()
|
||||
}
|
||||
|
||||
result := reflect.New(reflect.TypeOf(derefV))
|
||||
err := conn.QueryRowEx(
|
||||
context.Background(),
|
||||
fmt.Sprintf("select ($1)::%s", pgTypeName),
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
v,
|
||||
).Scan(result.Interface())
|
||||
if err != nil {
|
||||
t.Errorf("Simple protocol %d: %v", i, err)
|
||||
}
|
||||
|
||||
if !eqFunc(result.Elem().Interface(), derefV) {
|
||||
t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
||||
conn := mustConnectDatabaseSQL(t, driverName)
|
||||
defer mustClose(t, conn)
|
||||
|
||||
ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for i, v := range values {
|
||||
// Derefence value if it is a pointer
|
||||
derefV := v
|
||||
refVal := reflect.ValueOf(v)
|
||||
if refVal.Kind() == reflect.Ptr {
|
||||
derefV = refVal.Elem().Interface()
|
||||
}
|
||||
|
||||
result := reflect.New(reflect.TypeOf(derefV))
|
||||
err := ps.QueryRow(v).Scan(result.Interface())
|
||||
if err != nil {
|
||||
t.Errorf("%v %d: %v", driverName, i, err)
|
||||
}
|
||||
|
||||
if !eqFunc(result.Elem().Interface(), derefV) {
|
||||
t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type NormalizeTest struct {
|
||||
SQL string
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) {
|
||||
TestSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool {
|
||||
return reflect.DeepEqual(a, b)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
|
||||
TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc)
|
||||
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
||||
TestDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
|
||||
conn := mustConnectPgx(t)
|
||||
defer mustClose(t, conn)
|
||||
|
||||
formats := []struct {
|
||||
name string
|
||||
formatCode int16
|
||||
}{
|
||||
{name: "TextFormat", formatCode: pgx.TextFormatCode},
|
||||
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
for _, fc := range formats {
|
||||
psName := fmt.Sprintf("test%d", i)
|
||||
ps, err := conn.Prepare(psName, tt.SQL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ps.FieldDescriptions[0].FormatCode = fc.formatCode
|
||||
if forceEncoder(tt.Value, fc.formatCode) == nil {
|
||||
t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name)
|
||||
continue
|
||||
}
|
||||
// Derefence value if it is a pointer
|
||||
derefV := tt.Value
|
||||
refVal := reflect.ValueOf(tt.Value)
|
||||
if refVal.Kind() == reflect.Ptr {
|
||||
derefV = refVal.Elem().Interface()
|
||||
}
|
||||
|
||||
result := reflect.New(reflect.TypeOf(derefV))
|
||||
err = conn.QueryRow(psName).Scan(result.Interface())
|
||||
if err != nil {
|
||||
t.Errorf("%v %d: %v", fc.name, i, err)
|
||||
}
|
||||
|
||||
if !eqFunc(result.Elem().Interface(), derefV) {
|
||||
t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
|
||||
conn := mustConnectDatabaseSQL(t, driverName)
|
||||
defer mustClose(t, conn)
|
||||
|
||||
for i, tt := range tests {
|
||||
ps, err := conn.Prepare(tt.SQL)
|
||||
if err != nil {
|
||||
t.Errorf("%d. %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Derefence value if it is a pointer
|
||||
derefV := tt.Value
|
||||
refVal := reflect.ValueOf(tt.Value)
|
||||
if refVal.Kind() == reflect.Ptr {
|
||||
derefV = refVal.Elem().Interface()
|
||||
}
|
||||
|
||||
result := reflect.New(reflect.TypeOf(derefV))
|
||||
err = ps.QueryRow().Scan(result.Interface())
|
||||
if err != nil {
|
||||
t.Errorf("%v %d: %v", driverName, i, err)
|
||||
}
|
||||
|
||||
if !eqFunc(result.Elem().Interface(), derefV) {
|
||||
t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
|
||||
}
|
||||
}
|
||||
}
|
|
@ -71,7 +71,7 @@ func (src *Text) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -79,7 +79,7 @@ func (src *TextArray) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -74,7 +74,7 @@ func (src *Timestamp) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -80,7 +80,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -75,7 +75,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -80,7 +80,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -77,7 +77,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
|
@ -69,7 +69,7 @@ func (src *Uuid) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot assign %v into %T", src, dst)
|
||||
|
|
|
@ -79,7 +79,7 @@ func (src *VarcharArray) AssignTo(dst interface{}) error {
|
|||
}
|
||||
}
|
||||
case Null:
|
||||
return nullAssignTo(dst)
|
||||
return NullAssignTo(dst)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
|
|
Loading…
Reference in New Issue