From 53266f029fbb23a31220663ac094869ed4701a0f Mon Sep 17 00:00:00 2001 From: Diego Becciolini Date: Mon, 25 Apr 2022 12:53:15 +0100 Subject: [PATCH 1/9] Hstore: fix AssignTo Hstore.AssignTo a map of string pointers takes the address of the loop variable, thus setting all the entries to the same string pointer. extend TestHstoreAssignToNullable assert fix --- hstore.go | 3 ++- hstore_test.go | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/hstore.go b/hstore.go index 706a3964..d21af7bc 100644 --- a/hstore.go +++ b/hstore.go @@ -90,7 +90,8 @@ func (src *Hstore) AssignTo(dst interface{}) error { case Null: (*v)[k] = nil case Present: - (*v)[k] = &val.String + str := val.String + (*v)[k] = &str default: return fmt.Errorf("cannot decode %#v into %T", src, dst) } diff --git a/hstore_test.go b/hstore_test.go index 73ee0612..32a8f015 100644 --- a/hstore_test.go +++ b/hstore_test.go @@ -181,13 +181,14 @@ func TestHstoreAssignTo(t *testing.T) { func TestHstoreAssignToNullable(t *testing.T) { var m map[string]*string + strPtr := func(str string) *string { return &str } simpleTests := []struct { src pgtype.Hstore dst *map[string]*string expected map[string]*string }{ - {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {Status: pgtype.Null}}, Status: pgtype.Present}, dst: &m, expected: map[string]*string{"foo": nil}}, + {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {Status: pgtype.Null}, "bar": {String: "1", Status: pgtype.Present}, "baz": {String: "2", Status: pgtype.Present}}, Status: pgtype.Present}, dst: &m, expected: map[string]*string{"foo": nil, "bar": strPtr("1"), "baz": strPtr("2")}}, {src: pgtype.Hstore{Status: pgtype.Null}, dst: &m, expected: ((map[string]*string)(nil))}, } From d846dbcb75b2ac38a1c6fd0390cd18472fe72dee Mon Sep 17 00:00:00 2001 From: Harmen Date: Sun, 24 Apr 2022 08:03:31 +0200 Subject: [PATCH 2/9] allow string values in timestamp[tz].Set() --- date.go | 4 ++-- timestamp.go | 8 ++++++++ timestamp_test.go | 1 + timestamptz.go | 8 ++++++++ timestamptz_test.go | 1 + 5 files changed, 20 insertions(+), 2 deletions(-) diff --git a/date.go b/date.go index e8d21a78..ca84970e 100644 --- a/date.go +++ b/date.go @@ -37,14 +37,14 @@ func (dst *Date) Set(src interface{}) error { switch value := src.(type) { case time.Time: *dst = Date{Time: value, Status: Present} - case string: - return dst.DecodeText(nil, []byte(value)) case *time.Time: if value == nil { *dst = Date{Status: Null} } else { return dst.Set(*value) } + case string: + return dst.DecodeText(nil, []byte(value)) case *string: if value == nil { *dst = Date{Status: Null} diff --git a/timestamp.go b/timestamp.go index 5517acb1..e043726d 100644 --- a/timestamp.go +++ b/timestamp.go @@ -46,6 +46,14 @@ func (dst *Timestamp) Set(src interface{}) error { } else { return dst.Set(*value) } + case string: + return dst.DecodeText(nil, []byte(value)) + case *string: + if value == nil { + *dst = Timestamp{Status: Null} + } else { + return dst.Set(*value) + } case InfinityModifier: *dst = Timestamp{InfinityModifier: value, Status: Present} default: diff --git a/timestamp_test.go b/timestamp_test.go index ea7ef57a..d818d4f6 100644 --- a/timestamp_test.go +++ b/timestamp_test.go @@ -123,6 +123,7 @@ func TestTimestampSet(t *testing.T) { {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: pgtype.Infinity, result: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, {source: pgtype.NegativeInfinity, result: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + {source: "2001-04-05 06:07:08", result: pgtype.Timestamp{Time: time.Date(2001, 4, 5, 6, 7, 8, 0, time.UTC), Status: pgtype.Present}}, } for i, tt := range successfulTests { diff --git a/timestamptz.go b/timestamptz.go index 58701970..72ae4991 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -48,6 +48,14 @@ func (dst *Timestamptz) Set(src interface{}) error { } else { return dst.Set(*value) } + case string: + return dst.DecodeText(nil, []byte(value)) + case *string: + if value == nil { + *dst = Timestamptz{Status: Null} + } else { + return dst.Set(*value) + } case InfinityModifier: *dst = Timestamptz{InfinityModifier: value, Status: Present} default: diff --git a/timestamptz_test.go b/timestamptz_test.go index 2ff326bb..d6a3f518 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -120,6 +120,7 @@ func TestTimestamptzSet(t *testing.T) { {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local)), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, {source: pgtype.Infinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, {source: pgtype.NegativeInfinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + {source: "2020-04-05 06:07:08Z", result: pgtype.Timestamptz{Time: time.Date(2020, 4, 5, 6, 7, 8, 0, time.UTC), Status: pgtype.Present}}, } for i, tt := range successfulTests { From 824d8ad40daa6ab015603df0dba250769d6c0653 Mon Sep 17 00:00:00 2001 From: James Hartig Date: Wed, 25 May 2022 09:16:02 -0400 Subject: [PATCH 3/9] support *sql.Scanner for null handling Fixes jackc/pgx#1211 --- pgtype.go | 54 +++++++++++++++++++++++++++++++++++++++++++++----- pgtype_test.go | 41 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 5 deletions(-) diff --git a/pgtype.go b/pgtype.go index eba09fa5..4078da7b 100644 --- a/pgtype.go +++ b/pgtype.go @@ -533,8 +533,22 @@ type scanPlanDataTypeSQLScanner DataType func (plan *scanPlanDataTypeSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { scanner, ok := dst.(sql.Scanner) if !ok { - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) + dv := reflect.ValueOf(dst) + if dv.Kind() != reflect.Ptr || !dv.Type().Elem().Implements(scannerType) { + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + if src == nil { + // Ensure the pointer points to a zero version of the value + dv.Elem().Set(reflect.Zero(dv.Type().Elem())) + return nil + } + dv = dv.Elem() + // If the pointer is to a nil pointer then set that before scanning + if dv.Kind() == reflect.Ptr && dv.IsNil() { + dv.Set(reflect.New(dv.Type().Elem())) + } + scanner = dv.Interface().(sql.Scanner) } dt := (*DataType)(plan) @@ -593,7 +607,25 @@ func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode type scanPlanSQLScanner struct{} func (scanPlanSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - scanner := dst.(sql.Scanner) + scanner, ok := dst.(sql.Scanner) + if !ok { + dv := reflect.ValueOf(dst) + if dv.Kind() != reflect.Ptr || !dv.Type().Elem().Implements(scannerType) { + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + if src == nil { + // Ensure the pointer points to a zero version of the value + dv.Elem().Set(reflect.Zero(dv.Type())) + return nil + } + dv = dv.Elem() + // If the pointer is to a nil pointer then set that before scanning + if dv.Kind() == reflect.Ptr && dv.IsNil() { + dv.Set(reflect.New(dv.Type().Elem())) + } + scanner = dv.Interface().(sql.Scanner) + } if src == nil { // This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the // text format path would be converted to empty string. @@ -761,6 +793,18 @@ func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byt return newPlan.Scan(ci, oid, formatCode, src, dst) } +var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +func isScanner(dst interface{}) bool { + if _, ok := dst.(sql.Scanner); ok { + return true + } + if t := reflect.TypeOf(dst); t.Kind() == reflect.Ptr && t.Elem().Implements(scannerType) { + return true + } + return false +} + // PlanScan prepares a plan to scan a value into dst. func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan { switch formatCode { @@ -825,13 +869,13 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan } if dt != nil { - if _, ok := dst.(sql.Scanner); ok { + if isScanner(dst) { return (*scanPlanDataTypeSQLScanner)(dt) } return (*scanPlanDataTypeAssignTo)(dt) } - if _, ok := dst.(sql.Scanner); ok { + if isScanner(dst) { return scanPlanSQLScanner{} } diff --git a/pgtype_test.go b/pgtype_test.go index 85ca55e9..9127766f 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -310,3 +310,44 @@ func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) { } } } + +type pgCustomInt int64 + +func (ci *pgCustomInt) Scan(src interface{}) error { + *ci = pgCustomInt(src.(int64)) + return nil +} + +func TestScanPlanBinaryInt32ScanScanner(t *testing.T) { + ci := pgtype.NewConnInfo() + src := []byte{0, 42} + var v pgCustomInt + + plan := ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &v) + err := plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, src, &v) + require.NoError(t, err) + require.EqualValues(t, 42, v) + + ptr := new(pgCustomInt) + plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, src, &ptr) + require.NoError(t, err) + require.EqualValues(t, 42, *ptr) + + ptr = new(pgCustomInt) + err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, nil, &ptr) + require.NoError(t, err) + assert.Nil(t, ptr) + + ptr = nil + plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, src, &ptr) + require.NoError(t, err) + require.EqualValues(t, 42, *ptr) + + ptr = nil + plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, nil, &ptr) + require.NoError(t, err) + assert.Nil(t, ptr) +} From 2afddedda837064a1c0a42998e27711181884217 Mon Sep 17 00:00:00 2001 From: James Hartig Date: Wed, 1 Jun 2022 10:57:42 -0400 Subject: [PATCH 4/9] protect against panic from PlanScan when interface{}(nil) is passed --- pgtype.go | 2 +- pgtype_test.go | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/pgtype.go b/pgtype.go index 4078da7b..08e7395e 100644 --- a/pgtype.go +++ b/pgtype.go @@ -799,7 +799,7 @@ func isScanner(dst interface{}) bool { if _, ok := dst.(sql.Scanner); ok { return true } - if t := reflect.TypeOf(dst); t.Kind() == reflect.Ptr && t.Elem().Implements(scannerType) { + if t := reflect.TypeOf(dst); t != nil && t.Kind() == reflect.Ptr && t.Elem().Implements(scannerType) { return true } return false diff --git a/pgtype_test.go b/pgtype_test.go index 9127766f..67f36373 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -351,3 +351,13 @@ func TestScanPlanBinaryInt32ScanScanner(t *testing.T) { require.NoError(t, err) assert.Nil(t, ptr) } + +// Test for https://github.com/jackc/pgtype/issues/164 +func TestScanPlanInterface(t *testing.T) { + ci := pgtype.NewConnInfo() + src := []byte{0, 42} + var v interface{} + plan := ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, v) + err := plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, src, v) + assert.Error(t, err) +} From 6fc738ea05eec3bec8b39e568b99c5ad52ce8073 Mon Sep 17 00:00:00 2001 From: William Storey Date: Fri, 3 Jun 2022 18:00:52 +0000 Subject: [PATCH 5/9] Use correct test description --- inet_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/inet_test.go b/inet_test.go index 09c6b21f..8d70c0d0 100644 --- a/inet_test.go +++ b/inet_test.go @@ -68,8 +68,8 @@ func TestInetSet(t *testing.T) { assert.Equalf(t, tt.result.Status, r.Status, "%d: Status", i) if tt.result.Status == pgtype.Present { - assert.Equalf(t, tt.result.IPNet.Mask, r.IPNet.Mask, "%d: IP", i) - assert.Truef(t, tt.result.IPNet.IP.Equal(r.IPNet.IP), "%d: Mask", i) + assert.Equalf(t, tt.result.IPNet.Mask, r.IPNet.Mask, "%d: Mask", i) + assert.Truef(t, tt.result.IPNet.IP.Equal(r.IPNet.IP), "%d: IP", i) } } } From 1e485c1c3b3a757f07124a6a7b71510dd0c008a2 Mon Sep 17 00:00:00 2001 From: William Storey Date: Fri, 3 Jun 2022 18:08:58 +0000 Subject: [PATCH 6/9] Do not send IPv4 networks as IPv4-mapped IPv6 Previously if we provided a parameter that was an array of strings such as []string{"0.0.0.0/8"}, we would encode this when sending to Postgres as ::ffff:0.0.0.0/8. From what I can tell, this is because when parsing the IP/network using net functions, we get a byte array that is 16 bytes long, even if it is an IPv4 network. In Inet.EncodeBinary(), we look at the length of the IP to determine what family the input is, and saw it as IPv6 because of this. We now always normalize IPv4 addresses using To4(). --- inet.go | 19 ++++++++++++++----- inet_test.go | 5 ++++- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/inet.go b/inet.go index f35f88ba..25e56170 100644 --- a/inet.go +++ b/inet.go @@ -47,17 +47,26 @@ func (dst *Inet) Set(src interface{}) error { case string: ip, ipnet, err := net.ParseCIDR(value) if err != nil { - ip = net.ParseIP(value) + ip := net.ParseIP(value) if ip == nil { return fmt.Errorf("unable to parse inet address: %s", value) } - ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} + if ipv4 := ip.To4(); ipv4 != nil { - ip = ipv4 - ipnet.Mask = net.CIDRMask(32, 32) + ipnet = &net.IPNet{IP: ipv4, Mask: net.CIDRMask(32, 32)} + } else { + ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} + } + } else { + ipnet.IP = ip + if ipv4 := ipnet.IP.To4(); ipv4 != nil { + ipnet.IP = ipv4 + if len(ipnet.Mask) == 16 { + ipnet.Mask = ipnet.Mask[12:] // Needed if input is IPv4-mapped IPv6. + } } } - ipnet.IP = ip + *dst = Inet{IPNet: ipnet, Status: Present} case *net.IPNet: if value == nil { diff --git a/inet_test.go b/inet_test.go index 8d70c0d0..badbf82e 100644 --- a/inet_test.go +++ b/inet_test.go @@ -52,10 +52,12 @@ func TestInetSet(t *testing.T) { {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: "1.2.3.4/24", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("1.2.3.4"), Mask: net.CIDRMask(24, 32)}, Status: pgtype.Present}}, + {source: "1.2.3.4/24", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("1.2.3.4").To4(), Mask: net.CIDRMask(24, 32)}, Status: pgtype.Present}}, {source: "10.0.0.1", result: pgtype.Inet{IPNet: mustParseInet(t, "10.0.0.1"), Status: pgtype.Present}}, {source: "2607:f8b0:4009:80b::200e", result: pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Status: pgtype.Present}}, {source: net.ParseIP(""), result: pgtype.Inet{Status: pgtype.Null}}, + {source: "0.0.0.0/8", result: pgtype.Inet{IPNet: mustParseInet(t, "0.0.0.0/8"), Status: pgtype.Present}}, + {source: "::ffff:0.0.0.0/104", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("0.0.0.0").To4(), Mask: net.CIDRMask(8, 32)}, Status: pgtype.Present}}, } for i, tt := range successfulTests { @@ -70,6 +72,7 @@ func TestInetSet(t *testing.T) { if tt.result.Status == pgtype.Present { assert.Equalf(t, tt.result.IPNet.Mask, r.IPNet.Mask, "%d: Mask", i) assert.Truef(t, tt.result.IPNet.IP.Equal(r.IPNet.IP), "%d: IP", i) + assert.Equalf(t, len(tt.result.IPNet.IP), len(r.IPNet.IP), "%d: IP length", i) } } } From 4db2a33562c6d2d38da9dbe9b8e29f2d4487cc5b Mon Sep 17 00:00:00 2001 From: William Storey Date: Mon, 6 Jun 2022 16:50:43 +0000 Subject: [PATCH 7/9] Do not convert IPv4-mapped IPv6 addresses to IPv4 These addresses behave differently in some cases, so assume if we're given them, we keep them as they are. --- inet.go | 26 +++++++++++++++++++++++--- inet_test.go | 2 +- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/inet.go b/inet.go index 25e56170..a343f5e2 100644 --- a/inet.go +++ b/inet.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "fmt" "net" + "strings" ) // Network address family is dependent on server socket.h value for AF_INET. @@ -52,17 +53,17 @@ func (dst *Inet) Set(src interface{}) error { return fmt.Errorf("unable to parse inet address: %s", value) } - if ipv4 := ip.To4(); ipv4 != nil { + if ipv4 := maybeGetIPv4(value, ip); ipv4 != nil { ipnet = &net.IPNet{IP: ipv4, Mask: net.CIDRMask(32, 32)} } else { ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} } } else { ipnet.IP = ip - if ipv4 := ipnet.IP.To4(); ipv4 != nil { + if ipv4 := maybeGetIPv4(value, ipnet.IP); ipv4 != nil { ipnet.IP = ipv4 if len(ipnet.Mask) == 16 { - ipnet.Mask = ipnet.Mask[12:] // Needed if input is IPv4-mapped IPv6. + ipnet.Mask = ipnet.Mask[12:] // Not sure this is ever needed. } } } @@ -96,6 +97,25 @@ func (dst *Inet) Set(src interface{}) error { return nil } +// Convert the net.IP to IPv4, if appropriate. +// +// When parsing a string to a net.IP using net.ParseIP() and the like, we get a +// 16 byte slice for IPv4 addresses as well as IPv6 addresses. This function +// calls To4() to convert them to a 4 byte slice. This is useful as it allows +// users of the net.IP check for IPv4 addresses based on the length and makes +// it clear we are handling IPv4 as opposed to IPv6 or IPv4-mapped IPv6 +// addresses. +func maybeGetIPv4(input string, ip net.IP) net.IP { + // Do not do this if the provided input looks like IPv6. This is because + // To4() on IPv4-mapped IPv6 addresses converts them to IPv4, which behave + // different in some cases. + if strings.Contains(input, ":") { + return nil + } + + return ip.To4() +} + func (dst Inet) Get() interface{} { switch dst.Status { case Present: diff --git a/inet_test.go b/inet_test.go index badbf82e..52759371 100644 --- a/inet_test.go +++ b/inet_test.go @@ -57,7 +57,7 @@ func TestInetSet(t *testing.T) { {source: "2607:f8b0:4009:80b::200e", result: pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Status: pgtype.Present}}, {source: net.ParseIP(""), result: pgtype.Inet{Status: pgtype.Null}}, {source: "0.0.0.0/8", result: pgtype.Inet{IPNet: mustParseInet(t, "0.0.0.0/8"), Status: pgtype.Present}}, - {source: "::ffff:0.0.0.0/104", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("0.0.0.0").To4(), Mask: net.CIDRMask(8, 32)}, Status: pgtype.Present}}, + {source: "::ffff:0.0.0.0/104", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("::ffff:0.0.0.0"), Mask: net.CIDRMask(104, 128)}, Status: pgtype.Present}}, } for i, tt := range successfulTests { From 6dd004c8b8f4f938a26778020882139b8f4de1c2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 20 Jun 2022 20:40:25 -0500 Subject: [PATCH 8/9] Backport numeric to string from v5 refs https://github.com/jackc/pgx/issues/1230 --- numeric.go | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/numeric.go b/numeric.go index cd057749..1f32b36b 100644 --- a/numeric.go +++ b/numeric.go @@ -1,6 +1,7 @@ package pgtype import ( + "bytes" "database/sql/driver" "encoding/binary" "fmt" @@ -375,6 +376,12 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } v.Set(rat) + case *string: + buf, err := encodeNumericText(*src, nil) + if err != nil { + return err + } + *v = string(buf) default: if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) @@ -792,3 +799,55 @@ func (src Numeric) Value() (driver.Value, error) { return nil, errUndefined } } + +func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { + // if !n.Valid { + // return nil, nil + // } + + if n.NaN { + buf = append(buf, "NaN"...) + return buf, nil + } else if n.InfinityModifier == Infinity { + buf = append(buf, "Infinity"...) + return buf, nil + } else if n.InfinityModifier == NegativeInfinity { + buf = append(buf, "-Infinity"...) + return buf, nil + } + + buf = append(buf, n.numberTextBytes()...) + + return buf, nil +} + +// numberString returns a string of the number. undefined if NaN, infinite, or NULL +func (n Numeric) numberTextBytes() []byte { + intStr := n.Int.String() + buf := &bytes.Buffer{} + exp := int(n.Exp) + if exp > 0 { + buf.WriteString(intStr) + for i := 0; i < exp; i++ { + buf.WriteByte('0') + } + } else if exp < 0 { + if len(intStr) <= -exp { + buf.WriteString("0.") + leadingZeros := -exp - len(intStr) + for i := 0; i < leadingZeros; i++ { + buf.WriteByte('0') + } + buf.WriteString(intStr) + } else if len(intStr) > -exp { + dpPos := len(intStr) + exp + buf.WriteString(intStr[:dpPos]) + buf.WriteByte('.') + buf.WriteString(intStr[dpPos:]) + } + } else { + buf.WriteString(intStr) + } + + return buf.Bytes() +} From 12c49ee213fabc092f24b92db6874ed0d319d7b3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 23 Jun 2022 21:01:56 -0500 Subject: [PATCH 9/9] shopspring-numeric extension does not panic on NaN https://github.com/jackc/pgtype/issues/169 --- ext/shopspring-numeric/decimal.go | 10 ++++++++++ ext/shopspring-numeric/decimal_test.go | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index ef3ce201..c75efa36 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -263,6 +263,16 @@ func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return err } + if num.NaN { + return errors.New("cannot decode 'NaN'") + } + if num.InfinityModifier == pgtype.Infinity { + return errors.New("cannot decode 'Infinity'") + } + if num.InfinityModifier == pgtype.NegativeInfinity { + return errors.New("cannot decode '-Infinity'") + } + *dst = Numeric{Decimal: decimal.NewFromBigInt(num.Int, num.Exp), Status: pgtype.Present} return nil diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go index e635da41..e3c6d59d 100644 --- a/ext/shopspring-numeric/decimal_test.go +++ b/ext/shopspring-numeric/decimal_test.go @@ -1,6 +1,7 @@ package numeric_test import ( + "context" "fmt" "math/big" "math/rand" @@ -93,6 +94,15 @@ func TestNumericNormalize(t *testing.T) { }) } +func TestNumericNaN(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + var n shopspring.Numeric + err := conn.QueryRow(context.Background(), `select 'NaN'::numeric`).Scan(&n) + require.EqualError(t, err, `can't scan into dest[0]: cannot decode 'NaN'`) +} + func TestNumericTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present},