diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go index 77e385e6..3ccf8318 100644 --- a/pgtype/aclitem.go +++ b/pgtype/aclitem.go @@ -67,7 +67,7 @@ func (src *Aclitem) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go index 20a7636a..7ef76573 100644 --- a/pgtype/aclitem_array.go +++ b/pgtype/aclitem_array.go @@ -78,7 +78,7 @@ func (src *AclitemArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/bool.go b/pgtype/bool.go index 736d19cf..1ebf590b 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -56,7 +56,7 @@ func (src *Bool) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index 4705d734..468f6816 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -79,7 +79,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 9f0266e7..8bf5de2b 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -61,7 +61,7 @@ func (src *Bytea) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index 268364c1..4aa2b862 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -79,7 +79,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index 6643bb47..96d912ae 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -108,7 +108,7 @@ func (src *CidrArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/convert.go b/pgtype/convert.go index 4fba8430..2b406426 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -342,7 +342,7 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) } -func nullAssignTo(dst interface{}) error { +func NullAssignTo(dst interface{}) error { dstPtr := reflect.ValueOf(dst) // AssignTo dst must always be a pointer diff --git a/pgtype/date.go b/pgtype/date.go index 7dd2c4f0..34753f05 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -70,7 +70,7 @@ func (src *Date) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/date_array.go b/pgtype/date_array.go index f58de011..f24bf6b9 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -80,7 +80,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/ext/shopspring-numeric/decimal.go b/pgtype/ext/shopspring-numeric/decimal.go new file mode 100644 index 00000000..9c7e316b --- /dev/null +++ b/pgtype/ext/shopspring-numeric/decimal.go @@ -0,0 +1,320 @@ +package numeric + +import ( + "bytes" + "database/sql/driver" + "errors" + "fmt" + "io" + "strconv" + + "github.com/jackc/pgx/pgtype" + "github.com/shopspring/decimal" +) + +var errUndefined = errors.New("cannot encode status undefined") + +type Numeric struct { + Decimal decimal.Decimal + Status pgtype.Status +} + +func (dst *Numeric) Set(src interface{}) error { + if src == nil { + *dst = Numeric{Status: pgtype.Null} + return nil + } + + switch value := src.(type) { + case decimal.Decimal: + *dst = Numeric{Decimal: value, Status: pgtype.Present} + case float32: + *dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present} + case float64: + *dst = Numeric{Decimal: decimal.NewFromFloat(value), Status: pgtype.Present} + case int8: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint8: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case int16: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint16: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case int32: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint32: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case int64: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint64: + // uint64 could be greater than int64 so convert to string then to decimal + dec, err := decimal.NewFromString(strconv.FormatUint(value, 10)) + if err != nil { + return err + } + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + case int: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint: + // uint could be greater than int64 so convert to string then to decimal + dec, err := decimal.NewFromString(strconv.FormatUint(uint64(value), 10)) + if err != nil { + return err + } + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + case string: + dec, err := decimal.NewFromString(value) + if err != nil { + return err + } + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + default: + // If all else fails see if pgtype.Numeric can handle it. If so, translate through that. + num := &pgtype.Numeric{} + if err := num.Set(value); err != nil { + return fmt.Errorf("cannot convert %v to Numeric", value) + } + + buf := &bytes.Buffer{} + if _, err := num.EncodeText(nil, buf); err != nil { + return fmt.Errorf("cannot convert %v to Numeric", value) + } + + dec, err := decimal.NewFromString(buf.String()) + if err != nil { + return fmt.Errorf("cannot convert %v to Numeric", value) + } + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + } + + return nil +} + +func (dst *Numeric) Get() interface{} { + switch dst.Status { + case pgtype.Present: + return dst.Decimal + case pgtype.Null: + return nil + default: + return dst.Status + } +} + +func (src *Numeric) AssignTo(dst interface{}) error { + switch src.Status { + case pgtype.Present: + switch v := dst.(type) { + case *decimal.Decimal: + *v = src.Decimal + case *float32: + f, _ := src.Decimal.Float64() + *v = float32(f) + case *float64: + f, _ := src.Decimal.Float64() + *v = f + case *int: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, strconv.IntSize) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int(n) + case *int8: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 8) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int8(n) + case *int16: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 16) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int16(n) + case *int32: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 32) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int32(n) + case *int64: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 64) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int64(n) + case *uint: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, strconv.IntSize) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint(n) + case *uint8: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 8) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint8(n) + case *uint16: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 16) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint16(n) + case *uint32: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 32) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint32(n) + case *uint64: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 64) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint64(n) + default: + if nextDst, retry := pgtype.GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case pgtype.Null: + return pgtype.NullAssignTo(dst) + } + + return nil +} + +func (dst *Numeric) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: pgtype.Null} + return nil + } + + dec, err := decimal.NewFromString(string(src)) + if err != nil { + return err + } + + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + return nil +} + +func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: pgtype.Null} + return nil + } + + // For now at least, implement this in terms of pgtype.Numeric + + num := &pgtype.Numeric{} + if err := num.DecodeBinary(ci, src); err != nil { + return err + } + + buf := &bytes.Buffer{} + if _, err := num.EncodeText(ci, buf); err != nil { + return err + } + + dec, err := decimal.NewFromString(buf.String()) + if err != nil { + return err + } + + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + + return nil +} + +func (src *Numeric) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case pgtype.Null: + return true, nil + case pgtype.Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, src.Decimal.String()) + return false, err +} + +func (src *Numeric) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case pgtype.Null: + return true, nil + case pgtype.Undefined: + return false, errUndefined + } + + // For now at least, implement this in terms of pgtype.Numeric + num := &pgtype.Numeric{} + if err := num.DecodeText(ci, []byte(src.Decimal.String())); err != nil { + return false, err + } + + return num.EncodeBinary(ci, w) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Numeric) Scan(src interface{}) error { + if src == nil { + *dst = Numeric{Status: pgtype.Null} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Numeric{Decimal: decimal.NewFromFloat(src), Status: pgtype.Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Numeric) Value() (driver.Value, error) { + switch src.Status { + case pgtype.Present: + return src.Decimal.Value() + case pgtype.Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/ext/shopspring-numeric/decimal_test.go b/pgtype/ext/shopspring-numeric/decimal_test.go new file mode 100644 index 00000000..50c0fb8b --- /dev/null +++ b/pgtype/ext/shopspring-numeric/decimal_test.go @@ -0,0 +1,281 @@ +package numeric_test + +import ( + "fmt" + "math/big" + "math/rand" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + shopspring "github.com/jackc/pgx/pgtype/ext/shopspring-numeric" + "github.com/jackc/pgx/pgtype/testutil" + "github.com/shopspring/decimal" +) + +func mustParseDecimal(t *testing.T, src string) decimal.Decimal { + dec, err := decimal.NewFromString(src) + if err != nil { + t.Fatal(err) + } + return dec +} + +func TestNumericNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select '0'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, + }, + { + SQL: "select '1'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, + }, + { + SQL: "select '10.00'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Status: pgtype.Present}, + }, + { + SQL: "select '1e-3'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, + }, + { + SQL: "select '-1'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, + }, + { + SQL: "select '10000'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Status: pgtype.Present}, + }, + { + SQL: "select '3.14'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, + }, + { + SQL: "select '1.1'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Status: pgtype.Present}, + }, + { + SQL: "select '100010001'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Status: pgtype.Present}, + }, + { + SQL: "select '100010001.0001'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Status: pgtype.Present}, + }, + { + SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", + Value: shopspring.Numeric{ + Decimal: mustParseDecimal(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"), + Status: pgtype.Present, + }, + }, + { + SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", + Value: shopspring.Numeric{ + Decimal: mustParseDecimal(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), + Status: pgtype.Present, + }, + }, + { + SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", + Value: shopspring.Numeric{ + Decimal: mustParseDecimal(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"), + Status: pgtype.Present, + }, + }, + }) +} + +func TestNumericTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "100000"), Status: pgtype.Present}, + + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.1"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.01"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0001"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00001"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000001"), Status: pgtype.Present}, + + &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0000000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001234567890123456789"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "4309132809320932980457137401234890237489238912983572189348951289375283573984571892758234678903467889512893489128589347891272139.8489235871258912789347891235879148795891238915678189467128957812395781238579189025891238901583915890128973578957912385798125789012378905238905471598123758923478294374327894237892234"), Status: pgtype.Present}, + &shopspring.Numeric{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(shopspring.Numeric) + b := bb.(shopspring.Numeric) + + return a.Status == b.Status && a.Decimal.Equal(b.Decimal) + }) + +} + +func TestNumericTranscodeFuzz(t *testing.T) { + r := rand.New(rand.NewSource(0)) + max := &big.Int{} + max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) + + values := make([]interface{}, 0, 2000) + for i := 0; i < 500; i++ { + num := fmt.Sprintf("%s.%s", (&big.Int{}).Rand(r, max).String(), (&big.Int{}).Rand(r, max).String()) + negNum := "-" + num + values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, num), Status: pgtype.Present}) + values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, negNum), Status: pgtype.Present}) + } + + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, + func(aa, bb interface{}) bool { + a := aa.(shopspring.Numeric) + b := bb.(shopspring.Numeric) + + return a.Status == b.Status && a.Decimal.Equal(b.Decimal) + }) +} + +func TestNumericSet(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result *shopspring.Numeric + }{ + {source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int8(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: int16(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: int32(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: int64(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: uint8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: uint16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: uint32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: uint64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: "1", result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: _int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: float64(1000), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1000"), Status: pgtype.Present}}, + {source: float64(1234), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1234"), Status: pgtype.Present}}, + {source: float64(12345678900), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345678900"), Status: pgtype.Present}}, + {source: float64(12345.678901), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345.678901"), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + r := &shopspring.Numeric{} + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !(r.Status == tt.result.Status && r.Decimal.Equal(tt.result.Decimal)) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestNumericAssignTo(t *testing.T) { + type _int8 int8 + + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src *shopspring.Numeric + dst interface{} + expected interface{} + }{ + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f32, expected: float32(4.2)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f64, expected: float64(4.2)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: &i64, expected: int64(42000)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src *shopspring.Numeric + dst interface{} + expected interface{} + }{ + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src *shopspring.Numeric + dst interface{} + }{ + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "150"), Status: pgtype.Present}, dst: &i8}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "40000"), Status: pgtype.Present}, dst: &i16}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui8}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui16}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui32}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui64}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index b9ee4b9e..db1523f0 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -79,7 +79,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index d49f18a7..19878bbb 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -79,7 +79,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/hstore.go b/pgtype/hstore.go index b8b0c6f3..5dc78671 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -71,7 +71,7 @@ func (src *Hstore) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go index 097fec7b..e4263f20 100644 --- a/pgtype/hstore_array.go +++ b/pgtype/hstore_array.go @@ -79,7 +79,7 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/inet.go b/pgtype/inet.go index 3e00e2fa..09fce04d 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -90,7 +90,7 @@ func (src *Inet) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index a108d75b..4687b145 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -108,7 +108,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index bddb5ac2..3506370e 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -107,7 +107,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index d5c8f911..e4ec6455 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -107,7 +107,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index ae2521fa..6c0dab65 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -107,7 +107,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/interval.go b/pgtype/interval.go index 7eddb10f..20a4a419 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -71,7 +71,7 @@ func (src *Interval) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go index 2d09ff8c..2834d69f 100644 --- a/pgtype/macaddr.go +++ b/pgtype/macaddr.go @@ -67,7 +67,7 @@ func (src *Macaddr) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/numeric.go b/pgtype/numeric.go index a26e8c89..63f99c06 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -253,7 +253,7 @@ func (src *Numeric) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return nil diff --git a/pgtype/numeric_array.go b/pgtype/numeric_array.go index b147e6a2..3d59a6b0 100644 --- a/pgtype/numeric_array.go +++ b/pgtype/numeric_array.go @@ -107,7 +107,7 @@ func (src *NumericArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/record.go b/pgtype/record.go index 9c42c907..3b315d40 100644 --- a/pgtype/record.go +++ b/pgtype/record.go @@ -62,7 +62,7 @@ func (src *Record) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/testutil/testutil.go b/pgtype/testutil/testutil.go new file mode 100644 index 00000000..610f0710 --- /dev/null +++ b/pgtype/testutil/testutil.go @@ -0,0 +1,298 @@ +package testutil + +import ( + "context" + "database/sql" + "fmt" + "io" + "os" + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" + _ "github.com/jackc/pgx/stdlib" + _ "github.com/lib/pq" +) + +func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { + var sqlDriverName string + switch driverName { + case "github.com/lib/pq": + sqlDriverName = "postgres" + case "github.com/jackc/pgx/stdlib": + sqlDriverName = "pgx" + default: + t.Fatalf("Unknown driver %v", driverName) + } + + db, err := sql.Open(sqlDriverName, os.Getenv("DATABASE_URL")) + if err != nil { + t.Fatal(err) + } + + return db +} + +func mustConnectPgx(t testing.TB) *pgx.Conn { + config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) + if err != nil { + t.Fatal(err) + } + + conn, err := pgx.Connect(config) + if err != nil { + t.Fatal(err) + } + + return conn +} + +func mustClose(t testing.TB, conn interface { + Close() error +}) { + err := conn.Close() + if err != nil { + t.Fatal(err) + } +} + +type forceTextEncoder struct { + e pgtype.TextEncoder +} + +func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + return f.e.EncodeText(ci, w) +} + +type forceBinaryEncoder struct { + e pgtype.BinaryEncoder +} + +func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + return f.e.EncodeBinary(ci, w) +} + +func forceEncoder(e interface{}, formatCode int16) interface{} { + switch formatCode { + case pgx.TextFormatCode: + if e, ok := e.(pgtype.TextEncoder); ok { + return forceTextEncoder{e: e} + } + case pgx.BinaryFormatCode: + if e, ok := e.(pgtype.BinaryEncoder); ok { + return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} + } + } + return nil +} + +func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { + TestSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } +} + +func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for i, v := range values { + for _, fc := range formats { + ps.FieldDescriptions[0].FormatCode = fc.formatCode + vEncoder := forceEncoder(v, fc.formatCode) + if vEncoder == nil { + t.Logf("Skipping: %#v does not implement %v", v, fc.name) + continue + } + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := conn.QueryRow("test", forceEncoder(v, fc.formatCode)).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", fc.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + } + } + } +} + +func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := conn.QueryRowEx( + context.Background(), + fmt.Sprintf("select ($1)::%s", pgTypeName), + &pgx.QueryExOptions{SimpleProtocol: true}, + v, + ).Scan(result.Interface()) + if err != nil { + t.Errorf("Simple protocol %d: %v", i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface()) + } + } +} + +func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectDatabaseSQL(t, driverName) + defer mustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := ps.QueryRow(v).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } +} + +type NormalizeTest struct { + SQL string + Value interface{} +} + +func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) { + TestSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + TestDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc) + } +} + +func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for i, tt := range tests { + for _, fc := range formats { + psName := fmt.Sprintf("test%d", i) + ps, err := conn.Prepare(psName, tt.SQL) + if err != nil { + t.Fatal(err) + } + + ps.FieldDescriptions[0].FormatCode = fc.formatCode + if forceEncoder(tt.Value, fc.formatCode) == nil { + t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name) + continue + } + // Derefence value if it is a pointer + derefV := tt.Value + refVal := reflect.ValueOf(tt.Value) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err = conn.QueryRow(psName).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", fc.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + } + } + } +} + +func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + conn := mustConnectDatabaseSQL(t, driverName) + defer mustClose(t, conn) + + for i, tt := range tests { + ps, err := conn.Prepare(tt.SQL) + if err != nil { + t.Errorf("%d. %v", i, err) + continue + } + + // Derefence value if it is a pointer + derefV := tt.Value + refVal := reflect.ValueOf(tt.Value) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err = ps.QueryRow().Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } +} diff --git a/pgtype/text.go b/pgtype/text.go index 62158b09..de80dd08 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -71,7 +71,7 @@ func (src *Text) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/text_array.go b/pgtype/text_array.go index 64728048..a6bd4724 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -79,7 +79,7 @@ func (src *TextArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 78c6355e..e7bc1c7d 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -74,7 +74,7 @@ func (src *Timestamp) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index 5d08f9cc..2046c387 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -80,7 +80,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 50370335..ef2d7498 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -75,7 +75,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index 107be06a..fd58d3be 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -80,7 +80,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 4b8f1a28..2a38ed82 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -77,7 +77,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/uuid.go b/pgtype/uuid.go index 111bed35..88d2195b 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -69,7 +69,7 @@ func (src *Uuid) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot assign %v into %T", src, dst) diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index 2712b4d2..9ca16d7e 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -79,7 +79,7 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/v3.md b/v3.md index d9017890..2946bcf0 100644 --- a/v3.md +++ b/v3.md @@ -66,3 +66,5 @@ Keep ability to change logging while running consider test to ensure that AssignTo makes copy of reference types something like: select array[1,2,3], array[4,5,6,7] + +Reconsider synonym types like varchar/text and numeric/decimal.