From 81626342596a5655b2a2239817c45ee29357ebce Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 10 Mar 2017 16:08:47 -0600 Subject: [PATCH] Decode(Text|Binary) now accepts []byte instead of io.Reader --- pgtype/array.go | 46 ++++++++++++------------ pgtype/bool.go | 40 ++++++--------------- pgtype/boolarray.go | 52 +++++++++++++-------------- pgtype/bytea.go | 39 +++++--------------- pgtype/cid.go | 8 ++--- pgtype/cidrarray.go | 8 ++--- pgtype/date.go | 36 +++++-------------- pgtype/datearray.go | 52 +++++++++++++-------------- pgtype/float4.go | 36 +++++-------------- pgtype/float4array.go | 52 +++++++++++++-------------- pgtype/float8.go | 36 +++++-------------- pgtype/float8array.go | 52 +++++++++++++-------------- pgtype/inet.go | 60 +++++++------------------------ pgtype/inetarray.go | 52 +++++++++++++-------------- pgtype/int2.go | 39 ++++++-------------- pgtype/int2array.go | 52 +++++++++++++-------------- pgtype/int4.go | 37 +++++-------------- pgtype/int4array.go | 52 +++++++++++++-------------- pgtype/int8.go | 36 +++++-------------- pgtype/int8array.go | 52 +++++++++++++-------------- pgtype/name.go | 8 ++--- pgtype/oid.go | 8 ++--- pgtype/pgtype.go | 4 +-- pgtype/pguint32.go | 37 +++++-------------- pgtype/qchar.go | 20 +++-------- pgtype/text.go | 21 +++-------- pgtype/textarray.go | 53 ++++++++++++++------------- pgtype/timestamp.go | 36 +++++-------------- pgtype/timestamparray.go | 52 +++++++++++++-------------- pgtype/timestamptz.go | 36 +++++-------------- pgtype/timestamptzarray.go | 52 +++++++++++++-------------- pgtype/to-consider.txt | 9 +++++ pgtype/typed_array.go.erb | 58 ++++++++++++------------------ pgtype/varchararray.go | 8 ++--- pgtype/xid.go | 8 ++--- query.go | 12 +++---- value_reader.go | 29 +++------------ values.go | 73 +++++++++++--------------------------- 38 files changed, 506 insertions(+), 855 deletions(-) create mode 100644 pgtype/to-consider.txt diff --git a/pgtype/array.go b/pgtype/array.go index 76492c61..6b705103 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" "strconv" @@ -25,40 +26,37 @@ type ArrayDimension struct { LowerBound int32 } -func (dst *ArrayHeader) DecodeBinary(r io.Reader) error { - numDims, err := pgio.ReadInt32(r) - if err != nil { - return err +func (dst *ArrayHeader) DecodeBinary(src []byte) (int, error) { + if len(src) < 12 { + return 0, fmt.Errorf("array header too short: %d", len(src)) } + rp := 0 + + numDims := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1 + rp += 4 + + dst.ElementOID = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + if numDims > 0 { dst.Dimensions = make([]ArrayDimension, numDims) } - - containsNull, err := pgio.ReadInt32(r) - if err != nil { - return err + if len(src) < 12+numDims*8 { + return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) } - dst.ContainsNull = containsNull == 1 - - dst.ElementOID, err = pgio.ReadInt32(r) - if err != nil { - return err - } - for i := range dst.Dimensions { - dst.Dimensions[i].Length, err = pgio.ReadInt32(r) - if err != nil { - return err - } + dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 - dst.Dimensions[i].LowerBound, err = pgio.ReadInt32(r) - if err != nil { - return err - } + dst.Dimensions[i].LowerBound = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 } - return nil + return rp, nil } func (src *ArrayHeader) EncodeBinary(w io.Writer) error { diff --git a/pgtype/bool.go b/pgtype/bool.go index 076403f9..b7bc14d0 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -72,51 +72,31 @@ func (src *Bool) AssignTo(dst interface{}) error { return nil } -func (dst *Bool) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Bool) DecodeText(src []byte) error { + if src == nil { *dst = Bool{Status: Null} return nil } - if size != 1 { - return fmt.Errorf("invalid length for bool: %v", size) + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) } - byt, err := pgio.ReadByte(r) - if err != nil { - return err - } - - *dst = Bool{Bool: byt == 't', Status: Present} + *dst = Bool{Bool: src[0] == 't', Status: Present} return nil } -func (dst *Bool) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Bool) DecodeBinary(src []byte) error { + if src == nil { *dst = Bool{Status: Null} return nil } - if size != 1 { - return fmt.Errorf("invalid length for bool: %v", size) + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) } - byt, err := pgio.ReadByte(r) - if err != nil { - return err - } - - *dst = Bool{Bool: byt == 1, Status: Present} + *dst = Bool{Bool: src[0] == 1, Status: Present} return nil } diff --git a/pgtype/boolarray.go b/pgtype/boolarray.go index b6b5db02..a9b8bf50 100644 --- a/pgtype/boolarray.go +++ b/pgtype/boolarray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -73,29 +74,17 @@ func (src *BoolArray) AssignTo(dst interface{}) error { return nil } -func (dst *BoolArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *BoolArray) DecodeText(src []byte) error { + if src == nil { *dst = BoolArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Bool if len(uta.Elements) > 0 { @@ -103,8 +92,11 @@ func (dst *BoolArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Bool - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -118,19 +110,14 @@ func (dst *BoolArray) DecodeText(r io.Reader) error { return nil } -func (dst *BoolArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *BoolArray) DecodeBinary(src []byte) error { + if src == nil { *dst = BoolArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -148,7 +135,14 @@ func (dst *BoolArray) DecodeBinary(r io.Reader) error { elements := make([]Bool, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -236,6 +230,10 @@ func (src *BoolArray) EncodeText(w io.Writer) error { } func (src *BoolArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, BoolOID) +} + +func (src *BoolArray) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -256,7 +254,7 @@ func (src *BoolArray) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = BoolOID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 2532182f..db20482f 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -71,29 +71,18 @@ func (src *Bytea) AssignTo(dst interface{}) error { // DecodeText only supports the hex format. This has been the default since // PostgreSQL 9.0. -func (dst *Bytea) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Bytea) DecodeText(src []byte) error { + if src == nil { *dst = Bytea{Status: Null} return nil } - sbuf := make([]byte, int(size)) - _, err = io.ReadFull(r, sbuf) - if err != nil { - return err - } - - if len(sbuf) < 2 || sbuf[0] != '\\' || sbuf[1] != 'x' { + if len(src) < 2 || src[0] != '\\' || src[1] != 'x' { return fmt.Errorf("invalid hex format") } - buf := make([]byte, (len(sbuf)-2)/2) - _, err = hex.Decode(buf, sbuf[2:]) + buf := make([]byte, (len(src)-2)/2) + _, err := hex.Decode(buf, src[2:]) if err != nil { return err } @@ -102,25 +91,13 @@ func (dst *Bytea) DecodeText(r io.Reader) error { return nil } -func (dst *Bytea) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Bytea) DecodeBinary(src []byte) error { + if src == nil { *dst = Bytea{Status: Null} return nil } - buf := make([]byte, int(size)) - - _, err = io.ReadFull(r, buf) - if err != nil { - return err - } - - *dst = Bytea{Bytes: buf, Status: Present} + *dst = Bytea{Bytes: src, Status: Present} return nil } diff --git a/pgtype/cid.go b/pgtype/cid.go index 21d6fb80..f8d706d0 100644 --- a/pgtype/cid.go +++ b/pgtype/cid.go @@ -30,12 +30,12 @@ func (src *CID) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *CID) DecodeText(r io.Reader) error { - return (*pguint32)(dst).DecodeText(r) +func (dst *CID) DecodeText(src []byte) error { + return (*pguint32)(dst).DecodeText(src) } -func (dst *CID) DecodeBinary(r io.Reader) error { - return (*pguint32)(dst).DecodeBinary(r) +func (dst *CID) DecodeBinary(src []byte) error { + return (*pguint32)(dst).DecodeBinary(src) } func (src CID) EncodeText(w io.Writer) error { diff --git a/pgtype/cidrarray.go b/pgtype/cidrarray.go index 66dd20d0..d95eef4a 100644 --- a/pgtype/cidrarray.go +++ b/pgtype/cidrarray.go @@ -14,12 +14,12 @@ func (src *CidrArray) AssignTo(dst interface{}) error { return (*InetArray)(src).AssignTo(dst) } -func (dst *CidrArray) DecodeText(r io.Reader) error { - return (*InetArray)(dst).DecodeText(r) +func (dst *CidrArray) DecodeText(src []byte) error { + return (*InetArray)(dst).DecodeText(src) } -func (dst *CidrArray) DecodeBinary(r io.Reader) error { - return (*InetArray)(dst).DecodeBinary(r) +func (dst *CidrArray) DecodeBinary(src []byte) error { + return (*InetArray)(dst).DecodeBinary(src) } func (src *CidrArray) EncodeText(w io.Writer) error { diff --git a/pgtype/date.go b/pgtype/date.go index 307f1e59..1bb81d35 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "reflect" @@ -66,24 +67,13 @@ func (src *Date) AssignTo(dst interface{}) error { return nil } -func (dst *Date) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Date) DecodeText(src []byte) error { + if src == nil { *dst = Date{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - sbuf := string(buf) + sbuf := string(src) switch sbuf { case "infinity": *dst = Date{Status: Present, InfinityModifier: Infinity} @@ -101,25 +91,17 @@ func (dst *Date) DecodeText(r io.Reader) error { return nil } -func (dst *Date) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Date) DecodeBinary(src []byte) error { + if src == nil { *dst = Date{Status: Null} return nil } - if size != 4 { - return fmt.Errorf("invalid length for date: %v", size) + if len(src) != 4 { + return fmt.Errorf("invalid length for date: %v", len(src)) } - dayOffset, err := pgio.ReadInt32(r) - if err != nil { - return err - } + dayOffset := int32(binary.BigEndian.Uint32(src)) switch dayOffset { case infinityDayOffset: diff --git a/pgtype/datearray.go b/pgtype/datearray.go index 5e93501e..e9ad1f62 100644 --- a/pgtype/datearray.go +++ b/pgtype/datearray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" "time" @@ -74,29 +75,17 @@ func (src *DateArray) AssignTo(dst interface{}) error { return nil } -func (dst *DateArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *DateArray) DecodeText(src []byte) error { + if src == nil { *dst = DateArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Date if len(uta.Elements) > 0 { @@ -104,8 +93,11 @@ func (dst *DateArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Date - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -119,19 +111,14 @@ func (dst *DateArray) DecodeText(r io.Reader) error { return nil } -func (dst *DateArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *DateArray) DecodeBinary(src []byte) error { + if src == nil { *dst = DateArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -149,7 +136,14 @@ func (dst *DateArray) DecodeBinary(r io.Reader) error { elements := make([]Date, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -237,6 +231,10 @@ func (src *DateArray) EncodeText(w io.Writer) error { } func (src *DateArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, DateOID) +} + +func (src *DateArray) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -257,7 +255,7 @@ func (src *DateArray) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = DateOID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/float4.go b/pgtype/float4.go index a1e5aa18..fb0415e5 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "math" @@ -92,24 +93,13 @@ func (src *Float4) AssignTo(dst interface{}) error { return float64AssignTo(float64(src.Float), src.Status, dst) } -func (dst *Float4) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float4) DecodeText(src []byte) error { + if src == nil { *dst = Float4{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseFloat(string(buf), 32) + n, err := strconv.ParseFloat(string(src), 32) if err != nil { return err } @@ -118,25 +108,17 @@ func (dst *Float4) DecodeText(r io.Reader) error { return nil } -func (dst *Float4) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float4) DecodeBinary(src []byte) error { + if src == nil { *dst = Float4{Status: Null} return nil } - if size != 4 { - return fmt.Errorf("invalid length for float4: %v", size) + if len(src) != 4 { + return fmt.Errorf("invalid length for float4: %v", len(src)) } - n, err := pgio.ReadInt32(r) - if err != nil { - return err - } + n := int32(binary.BigEndian.Uint32(src)) *dst = Float4{Float: math.Float32frombits(uint32(n)), Status: Present} return nil diff --git a/pgtype/float4array.go b/pgtype/float4array.go index 8834d213..a4a72146 100644 --- a/pgtype/float4array.go +++ b/pgtype/float4array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -73,29 +74,17 @@ func (src *Float4Array) AssignTo(dst interface{}) error { return nil } -func (dst *Float4Array) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float4Array) DecodeText(src []byte) error { + if src == nil { *dst = Float4Array{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Float4 if len(uta.Elements) > 0 { @@ -103,8 +92,11 @@ func (dst *Float4Array) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Float4 - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -118,19 +110,14 @@ func (dst *Float4Array) DecodeText(r io.Reader) error { return nil } -func (dst *Float4Array) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float4Array) DecodeBinary(src []byte) error { + if src == nil { *dst = Float4Array{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -148,7 +135,14 @@ func (dst *Float4Array) DecodeBinary(r io.Reader) error { elements := make([]Float4, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -236,6 +230,10 @@ func (src *Float4Array) EncodeText(w io.Writer) error { } func (src *Float4Array) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, Float4OID) +} + +func (src *Float4Array) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -256,7 +254,7 @@ func (src *Float4Array) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = Float4OID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/float8.go b/pgtype/float8.go index c1347cb2..a53de5e3 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "math" @@ -82,24 +83,13 @@ func (src *Float8) AssignTo(dst interface{}) error { return float64AssignTo(src.Float, src.Status, dst) } -func (dst *Float8) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float8) DecodeText(src []byte) error { + if src == nil { *dst = Float8{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseFloat(string(buf), 64) + n, err := strconv.ParseFloat(string(src), 64) if err != nil { return err } @@ -108,25 +98,17 @@ func (dst *Float8) DecodeText(r io.Reader) error { return nil } -func (dst *Float8) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float8) DecodeBinary(src []byte) error { + if src == nil { *dst = Float8{Status: Null} return nil } - if size != 8 { - return fmt.Errorf("invalid length for float4: %v", size) + if len(src) != 8 { + return fmt.Errorf("invalid length for float4: %v", len(src)) } - n, err := pgio.ReadInt64(r) - if err != nil { - return err - } + n := int64(binary.BigEndian.Uint64(src)) *dst = Float8{Float: math.Float64frombits(uint64(n)), Status: Present} return nil diff --git a/pgtype/float8array.go b/pgtype/float8array.go index bad9ed9f..082e817d 100644 --- a/pgtype/float8array.go +++ b/pgtype/float8array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -73,29 +74,17 @@ func (src *Float8Array) AssignTo(dst interface{}) error { return nil } -func (dst *Float8Array) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float8Array) DecodeText(src []byte) error { + if src == nil { *dst = Float8Array{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Float8 if len(uta.Elements) > 0 { @@ -103,8 +92,11 @@ func (dst *Float8Array) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Float8 - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -118,19 +110,14 @@ func (dst *Float8Array) DecodeText(r io.Reader) error { return nil } -func (dst *Float8Array) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float8Array) DecodeBinary(src []byte) error { + if src == nil { *dst = Float8Array{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -148,7 +135,14 @@ func (dst *Float8Array) DecodeBinary(r io.Reader) error { elements := make([]Float8, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -236,6 +230,10 @@ func (src *Float8Array) EncodeText(w io.Writer) error { } func (src *Float8Array) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, Float8OID) +} + +func (src *Float8Array) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -256,7 +254,7 @@ func (src *Float8Array) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = Float8OID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/inet.go b/pgtype/inet.go index e47c64b0..132a876a 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -91,26 +91,16 @@ func (src *Inet) AssignTo(dst interface{}) error { return nil } -func (dst *Inet) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Inet) DecodeText(src []byte) error { + if src == nil { *dst = Inet{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) - if err != nil { - return err - } - var ipnet *net.IPNet + var err error - if ip := net.ParseIP(string(buf)); ip != nil { + if ip := net.ParseIP(string(src)); ip != nil { ipv4 := ip.To4() if ipv4 != nil { ip = ipv4 @@ -119,7 +109,7 @@ func (dst *Inet) DecodeText(r io.Reader) error { mask := net.CIDRMask(bitCount, bitCount) ipnet = &net.IPNet{Mask: mask, IP: ip} } else { - _, ipnet, err = net.ParseCIDR(string(buf)) + _, ipnet, err = net.ParseCIDR(string(src)) if err != nil { return err } @@ -129,50 +119,24 @@ func (dst *Inet) DecodeText(r io.Reader) error { return nil } -func (dst *Inet) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Inet) DecodeBinary(src []byte) error { + if src == nil { *dst = Inet{Status: Null} return nil } - if size != 8 && size != 20 { - return fmt.Errorf("Received an invalid size for a inet: %d", size) + if len(src) != 8 && len(src) != 20 { + return fmt.Errorf("Received an invalid size for a inet: %d", len(src)) } // ignore family - _, err = pgio.ReadByte(r) - if err != nil { - return err - } - - bits, err := pgio.ReadByte(r) - if err != nil { - return err - } - + bits := src[1] // ignore is_cidr - _, err = pgio.ReadByte(r) - if err != nil { - return err - } - - addressLength, err := pgio.ReadByte(r) - if err != nil { - return err - } + addressLength := src[3] var ipnet net.IPNet ipnet.IP = make(net.IP, int(addressLength)) - _, err = r.Read(ipnet.IP) - if err != nil { - return err - } - + copy(ipnet.IP, src[4:]) ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) *dst = Inet{IPNet: &ipnet, Status: Present} diff --git a/pgtype/inetarray.go b/pgtype/inetarray.go index cd12e917..28de736f 100644 --- a/pgtype/inetarray.go +++ b/pgtype/inetarray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" "net" @@ -19,8 +20,7 @@ func (dst *InetArray) ConvertFrom(src interface{}) error { switch value := src.(type) { case InetArray: *dst = value - case CidrArray: - *dst = InetArray(value) + case []*net.IPNet: if value == nil { *dst = InetArray{Status: Null} @@ -39,6 +39,7 @@ func (dst *InetArray) ConvertFrom(src interface{}) error { Status: Present, } } + case []net.IP: if value == nil { *dst = InetArray{Status: Null} @@ -57,6 +58,7 @@ func (dst *InetArray) ConvertFrom(src interface{}) error { Status: Present, } } + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -81,6 +83,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { } else { *v = nil } + case *[]net.IP: if src.Status == Present { *v = make([]net.IP, len(src.Elements)) @@ -103,29 +106,17 @@ func (src *InetArray) AssignTo(dst interface{}) error { return nil } -func (dst *InetArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *InetArray) DecodeText(src []byte) error { + if src == nil { *dst = InetArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Inet if len(uta.Elements) > 0 { @@ -133,8 +124,11 @@ func (dst *InetArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Inet - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -148,19 +142,14 @@ func (dst *InetArray) DecodeText(r io.Reader) error { return nil } -func (dst *InetArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *InetArray) DecodeBinary(src []byte) error { + if src == nil { *dst = InetArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -178,7 +167,14 @@ func (dst *InetArray) DecodeBinary(r io.Reader) error { elements := make([]Inet, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } diff --git a/pgtype/int2.go b/pgtype/int2.go index 8057550b..51346a43 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "math" @@ -88,24 +89,13 @@ func (src *Int2) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int2) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int2) DecodeText(src []byte) error { + if src == nil { *dst = Int2{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseInt(string(buf), 10, 16) + n, err := strconv.ParseInt(string(src), 10, 16) if err != nil { return err } @@ -114,27 +104,18 @@ func (dst *Int2) DecodeText(r io.Reader) error { return nil } -func (dst *Int2) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int2) DecodeBinary(src []byte) error { + if src == nil { *dst = Int2{Status: Null} return nil } - if size != 2 { - return fmt.Errorf("invalid length for int2: %v", size) + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) } - n, err := pgio.ReadInt16(r) - if err != nil { - return err - } - - *dst = Int2{Int: int16(n), Status: Present} + n := int16(binary.BigEndian.Uint16(src)) + *dst = Int2{Int: n, Status: Present} return nil } diff --git a/pgtype/int2array.go b/pgtype/int2array.go index a989347d..71760e1e 100644 --- a/pgtype/int2array.go +++ b/pgtype/int2array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -104,29 +105,17 @@ func (src *Int2Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int2Array) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int2Array) DecodeText(src []byte) error { + if src == nil { *dst = Int2Array{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Int2 if len(uta.Elements) > 0 { @@ -134,8 +123,11 @@ func (dst *Int2Array) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Int2 - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -149,19 +141,14 @@ func (dst *Int2Array) DecodeText(r io.Reader) error { return nil } -func (dst *Int2Array) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int2Array) DecodeBinary(src []byte) error { + if src == nil { *dst = Int2Array{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -179,7 +166,14 @@ func (dst *Int2Array) DecodeBinary(r io.Reader) error { elements := make([]Int2, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -267,6 +261,10 @@ func (src *Int2Array) EncodeText(w io.Writer) error { } func (src *Int2Array) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, Int2OID) +} + +func (src *Int2Array) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -287,7 +285,7 @@ func (src *Int2Array) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = Int2OID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/int4.go b/pgtype/int4.go index 43691bb6..8a53d454 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "math" @@ -79,24 +80,13 @@ func (src *Int4) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int4) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int4) DecodeText(src []byte) error { + if src == nil { *dst = Int4{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseInt(string(buf), 10, 32) + n, err := strconv.ParseInt(string(src), 10, 32) if err != nil { return err } @@ -105,26 +95,17 @@ func (dst *Int4) DecodeText(r io.Reader) error { return nil } -func (dst *Int4) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int4) DecodeBinary(src []byte) error { + if src == nil { *dst = Int4{Status: Null} return nil } - if size != 4 { - return fmt.Errorf("invalid length for int4: %v", size) - } - - n, err := pgio.ReadInt32(r) - if err != nil { - return err + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) } + n := int32(binary.BigEndian.Uint32(src)) *dst = Int4{Int: n, Status: Present} return nil } diff --git a/pgtype/int4array.go b/pgtype/int4array.go index 89caf263..6a202b08 100644 --- a/pgtype/int4array.go +++ b/pgtype/int4array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -104,29 +105,17 @@ func (src *Int4Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int4Array) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int4Array) DecodeText(src []byte) error { + if src == nil { *dst = Int4Array{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Int4 if len(uta.Elements) > 0 { @@ -134,8 +123,11 @@ func (dst *Int4Array) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Int4 - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -149,19 +141,14 @@ func (dst *Int4Array) DecodeText(r io.Reader) error { return nil } -func (dst *Int4Array) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int4Array) DecodeBinary(src []byte) error { + if src == nil { *dst = Int4Array{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -179,7 +166,14 @@ func (dst *Int4Array) DecodeBinary(r io.Reader) error { elements := make([]Int4, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -267,6 +261,10 @@ func (src *Int4Array) EncodeText(w io.Writer) error { } func (src *Int4Array) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, Int4OID) +} + +func (src *Int4Array) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -287,7 +285,7 @@ func (src *Int4Array) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = Int4OID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/int8.go b/pgtype/int8.go index b87bb85a..c6bedaa6 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "math" @@ -70,24 +71,13 @@ func (src *Int8) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int8) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int8) DecodeText(src []byte) error { + if src == nil { *dst = Int8{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseInt(string(buf), 10, 64) + n, err := strconv.ParseInt(string(src), 10, 64) if err != nil { return err } @@ -96,25 +86,17 @@ func (dst *Int8) DecodeText(r io.Reader) error { return nil } -func (dst *Int8) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int8) DecodeBinary(src []byte) error { + if src == nil { *dst = Int8{Status: Null} return nil } - if size != 8 { - return fmt.Errorf("invalid length for int8: %v", size) + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) } - n, err := pgio.ReadInt64(r) - if err != nil { - return err - } + n := int64(binary.BigEndian.Uint64(src)) *dst = Int8{Int: n, Status: Present} return nil diff --git a/pgtype/int8array.go b/pgtype/int8array.go index 003ed055..f621618e 100644 --- a/pgtype/int8array.go +++ b/pgtype/int8array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -104,29 +105,17 @@ func (src *Int8Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int8Array) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int8Array) DecodeText(src []byte) error { + if src == nil { *dst = Int8Array{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Int8 if len(uta.Elements) > 0 { @@ -134,8 +123,11 @@ func (dst *Int8Array) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Int8 - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -149,19 +141,14 @@ func (dst *Int8Array) DecodeText(r io.Reader) error { return nil } -func (dst *Int8Array) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int8Array) DecodeBinary(src []byte) error { + if src == nil { *dst = Int8Array{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -179,7 +166,14 @@ func (dst *Int8Array) DecodeBinary(r io.Reader) error { elements := make([]Int8, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -267,6 +261,10 @@ func (src *Int8Array) EncodeText(w io.Writer) error { } func (src *Int8Array) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, Int8OID) +} + +func (src *Int8Array) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -287,7 +285,7 @@ func (src *Int8Array) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = Int8OID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/name.go b/pgtype/name.go index 3ff81f12..4bbc43c1 100644 --- a/pgtype/name.go +++ b/pgtype/name.go @@ -27,12 +27,12 @@ func (src *Name) AssignTo(dst interface{}) error { return (*Text)(src).AssignTo(dst) } -func (dst *Name) DecodeText(r io.Reader) error { - return (*Text)(dst).DecodeText(r) +func (dst *Name) DecodeText(src []byte) error { + return (*Text)(dst).DecodeText(src) } -func (dst *Name) DecodeBinary(r io.Reader) error { - return (*Text)(dst).DecodeBinary(r) +func (dst *Name) DecodeBinary(src []byte) error { + return (*Text)(dst).DecodeBinary(src) } func (src Name) EncodeText(w io.Writer) error { diff --git a/pgtype/oid.go b/pgtype/oid.go index d137f352..2ea9c2d1 100644 --- a/pgtype/oid.go +++ b/pgtype/oid.go @@ -24,12 +24,12 @@ func (src *OID) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *OID) DecodeText(r io.Reader) error { - return (*pguint32)(dst).DecodeText(r) +func (dst *OID) DecodeText(src []byte) error { + return (*pguint32)(dst).DecodeText(src) } -func (dst *OID) DecodeBinary(r io.Reader) error { - return (*pguint32)(dst).DecodeBinary(r) +func (dst *OID) DecodeBinary(src []byte) error { + return (*pguint32)(dst).DecodeBinary(src) } func (src OID) EncodeText(w io.Writer) error { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 15c0cc76..7928e1cc 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -74,11 +74,11 @@ type Value interface { } type BinaryDecoder interface { - DecodeBinary(r io.Reader) error + DecodeBinary(src []byte) error } type TextDecoder interface { - DecodeText(r io.Reader) error + DecodeText(src []byte) error } type BinaryEncoder interface { diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index 9c1ccd6c..9bf1eef6 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "strconv" @@ -51,24 +52,13 @@ func (src *pguint32) AssignTo(dst interface{}) error { return nil } -func (dst *pguint32) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *pguint32) DecodeText(src []byte) error { + if src == nil { *dst = pguint32{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseUint(string(buf), 10, 32) + n, err := strconv.ParseUint(string(src), 10, 32) if err != nil { return err } @@ -77,26 +67,17 @@ func (dst *pguint32) DecodeText(r io.Reader) error { return nil } -func (dst *pguint32) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *pguint32) DecodeBinary(src []byte) error { + if src == nil { *dst = pguint32{Status: Null} return nil } - if size != 4 { - return fmt.Errorf("invalid length: %v", size) - } - - n, err := pgio.ReadUint32(r) - if err != nil { - return err + if len(src) != 4 { + return fmt.Errorf("invalid length: %v", len(src)) } + n := binary.BigEndian.Uint32(src) *dst = pguint32{Uint: n, Status: Present} return nil } diff --git a/pgtype/qchar.go b/pgtype/qchar.go index 6dd14625..8abec935 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -106,27 +106,17 @@ func (src *QChar) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *QChar) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *QChar) DecodeBinary(src []byte) error { + if src == nil { *dst = QChar{Status: Null} return nil } - if size != 1 { - return fmt.Errorf(`invalid length for "char": %v`, size) + if len(src) != 1 { + return fmt.Errorf(`invalid length for "char": %v`, len(src)) } - byt, err := pgio.ReadByte(r) - if err != nil { - return err - } - - *dst = QChar{Int: int8(byt), Status: Present} + *dst = QChar{Int: int8(src[0]), Status: Present} return nil } diff --git a/pgtype/text.go b/pgtype/text.go index c9054468..2951b5ad 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -71,29 +71,18 @@ func (src *Text) AssignTo(dst interface{}) error { return nil } -func (dst *Text) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Text) DecodeText(src []byte) error { + if src == nil { *dst = Text{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - *dst = Text{String: string(buf), Status: Present} + *dst = Text{String: string(src), Status: Present} return nil } -func (dst *Text) DecodeBinary(r io.Reader) error { - return dst.DecodeText(r) +func (dst *Text) DecodeBinary(src []byte) error { + return dst.DecodeText(src) } func (src Text) EncodeText(w io.Writer) error { diff --git a/pgtype/textarray.go b/pgtype/textarray.go index c420e5c9..e7ca3578 100644 --- a/pgtype/textarray.go +++ b/pgtype/textarray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -73,29 +74,17 @@ func (src *TextArray) AssignTo(dst interface{}) error { return nil } -func (dst *TextArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TextArray) DecodeText(src []byte) error { + if src == nil { *dst = TextArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Text if len(uta.Elements) > 0 { @@ -103,8 +92,11 @@ func (dst *TextArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Text - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -118,19 +110,14 @@ func (dst *TextArray) DecodeText(r io.Reader) error { return nil } -func (dst *TextArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TextArray) DecodeBinary(src []byte) error { + if src == nil { *dst = TextArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -148,7 +135,14 @@ func (dst *TextArray) DecodeBinary(r io.Reader) error { elements := make([]Text, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -211,7 +205,12 @@ func (src *TextArray) EncodeText(w io.Writer) error { } textElementWriter.Reset() - if elem.String == "" && elem.Status == Present { + if elem.Status == Null { + _, err := io.WriteString(buf, `"NULL"`) + if err != nil { + return err + } + } else if elem.String == "" { _, err := io.WriteString(buf, `""`) if err != nil { return err diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index c6933988..ca5eb738 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "reflect" @@ -72,24 +73,13 @@ func (src *Timestamp) AssignTo(dst interface{}) error { // DecodeText decodes from src into dst. The decoded time is considered to // be in UTC. -func (dst *Timestamp) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Timestamp) DecodeText(src []byte) error { + if src == nil { *dst = Timestamp{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - sbuf := string(buf) + sbuf := string(src) switch sbuf { case "infinity": *dst = Timestamp{Status: Present, InfinityModifier: Infinity} @@ -109,25 +99,17 @@ func (dst *Timestamp) DecodeText(r io.Reader) error { // DecodeBinary decodes from src into dst. The decoded time is considered to // be in UTC. -func (dst *Timestamp) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Timestamp) DecodeBinary(src []byte) error { + if src == nil { *dst = Timestamp{Status: Null} return nil } - if size != 8 { - return fmt.Errorf("invalid length for timestamp: %v", size) + if len(src) != 8 { + return fmt.Errorf("invalid length for timestamp: %v", len(src)) } - microsecSinceY2K, err := pgio.ReadInt64(r) - if err != nil { - return err - } + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) switch microsecSinceY2K { case infinityMicrosecondOffset: diff --git a/pgtype/timestamparray.go b/pgtype/timestamparray.go index 3acbb35f..695559ac 100644 --- a/pgtype/timestamparray.go +++ b/pgtype/timestamparray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" "time" @@ -74,29 +75,17 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { return nil } -func (dst *TimestampArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TimestampArray) DecodeText(src []byte) error { + if src == nil { *dst = TimestampArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Timestamp if len(uta.Elements) > 0 { @@ -104,8 +93,11 @@ func (dst *TimestampArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Timestamp - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -119,19 +111,14 @@ func (dst *TimestampArray) DecodeText(r io.Reader) error { return nil } -func (dst *TimestampArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TimestampArray) DecodeBinary(src []byte) error { + if src == nil { *dst = TimestampArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -149,7 +136,14 @@ func (dst *TimestampArray) DecodeBinary(r io.Reader) error { elements := make([]Timestamp, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -237,6 +231,10 @@ func (src *TimestampArray) EncodeText(w io.Writer) error { } func (src *TimestampArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, TimestampOID) +} + +func (src *TimestampArray) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -257,7 +255,7 @@ func (src *TimestampArray) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = TimestampOID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 721c8084..7255bb06 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "reflect" @@ -71,24 +72,13 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { return nil } -func (dst *Timestamptz) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Timestamptz) DecodeText(src []byte) error { + if src == nil { *dst = Timestamptz{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - sbuf := string(buf) + sbuf := string(src) switch sbuf { case "infinity": *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} @@ -115,25 +105,17 @@ func (dst *Timestamptz) DecodeText(r io.Reader) error { return nil } -func (dst *Timestamptz) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Timestamptz) DecodeBinary(src []byte) error { + if src == nil { *dst = Timestamptz{Status: Null} return nil } - if size != 8 { - return fmt.Errorf("invalid length for timestamptz: %v", size) + if len(src) != 8 { + return fmt.Errorf("invalid length for timestamptz: %v", len(src)) } - microsecSinceY2K, err := pgio.ReadInt64(r) - if err != nil { - return err - } + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) switch microsecSinceY2K { case infinityMicrosecondOffset: diff --git a/pgtype/timestamptzarray.go b/pgtype/timestamptzarray.go index 9df746e6..ca416c97 100644 --- a/pgtype/timestamptzarray.go +++ b/pgtype/timestamptzarray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" "time" @@ -74,29 +75,17 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { return nil } -func (dst *TimestamptzArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TimestamptzArray) DecodeText(src []byte) error { + if src == nil { *dst = TimestamptzArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Timestamptz if len(uta.Elements) > 0 { @@ -104,8 +93,11 @@ func (dst *TimestamptzArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Timestamptz - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -119,19 +111,14 @@ func (dst *TimestamptzArray) DecodeText(r io.Reader) error { return nil } -func (dst *TimestamptzArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TimestamptzArray) DecodeBinary(src []byte) error { + if src == nil { *dst = TimestamptzArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -149,7 +136,14 @@ func (dst *TimestamptzArray) DecodeBinary(r io.Reader) error { elements := make([]Timestamptz, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -237,6 +231,10 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) error { } func (src *TimestamptzArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, TimestamptzOID) +} + +func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -257,7 +255,7 @@ func (src *TimestamptzArray) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = TimestamptzOID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/to-consider.txt b/pgtype/to-consider.txt new file mode 100644 index 00000000..ba4f3511 --- /dev/null +++ b/pgtype/to-consider.txt @@ -0,0 +1,9 @@ +DecodeText and DecodeBinary take []byte instead of io.Reader +EncodeText and EncodeBinary do not write size +Add Nullable interface with IsNull() and SetNull() + +The above would keep types from needing to worry about writing their own size. Could make EncodeText and DecodeText easier to use with sql.Scanner and driver.Valuer. SetNull() could be removed as DecodeText and DecodeBinary could interpret a nil slice as null. + +EncodeText and EncodeBinary could return (null bool, err error). That would finish removing Nullable interface. + +Also, consider whether arrays and ranges could be represented as generic data types or more common code could be extracted instead of using code generation. diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 8c18073b..316439ef 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -73,29 +73,17 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { return nil } -func (dst *<%= pgtype_array_type %>) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { + if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []<%= pgtype_element_type %> if len(uta.Elements) > 0 { @@ -103,8 +91,11 @@ func (dst *<%= pgtype_array_type %>) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem <%= pgtype_element_type %> - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -118,19 +109,14 @@ func (dst *<%= pgtype_array_type %>) DecodeText(r io.Reader) error { return nil } -func (dst *<%= pgtype_array_type %>) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { + if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -148,7 +134,14 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(r io.Reader) error { elements := make([]<%= pgtype_element_type %>, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp:rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -211,16 +204,9 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) error { } textElementWriter.Reset() - if elem.String == "" && elem.Status == Present { - _, err := io.WriteString(buf, `""`) - if err != nil { - return err - } - } else { - err = elem.EncodeText(textElementWriter) - if err != nil { - return err - } + err = elem.EncodeText(textElementWriter) + if err != nil { + return err } for _, dec := range dimElemCounts { diff --git a/pgtype/varchararray.go b/pgtype/varchararray.go index 13d94bc0..3a5d8536 100644 --- a/pgtype/varchararray.go +++ b/pgtype/varchararray.go @@ -14,12 +14,12 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { return (*TextArray)(src).AssignTo(dst) } -func (dst *VarcharArray) DecodeText(r io.Reader) error { - return (*TextArray)(dst).DecodeText(r) +func (dst *VarcharArray) DecodeText(src []byte) error { + return (*TextArray)(dst).DecodeText(src) } -func (dst *VarcharArray) DecodeBinary(r io.Reader) error { - return (*TextArray)(dst).DecodeBinary(r) +func (dst *VarcharArray) DecodeBinary(src []byte) error { + return (*TextArray)(dst).DecodeBinary(src) } func (src *VarcharArray) EncodeText(w io.Writer) error { diff --git a/pgtype/xid.go b/pgtype/xid.go index d4003b5d..389f93bc 100644 --- a/pgtype/xid.go +++ b/pgtype/xid.go @@ -33,12 +33,12 @@ func (src *XID) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *XID) DecodeText(r io.Reader) error { - return (*pguint32)(dst).DecodeText(r) +func (dst *XID) DecodeText(src []byte) error { + return (*pguint32)(dst).DecodeText(src) } -func (dst *XID) DecodeBinary(r io.Reader) error { - return (*pguint32)(dst).DecodeBinary(r) +func (dst *XID) DecodeBinary(src []byte) error { + return (*pguint32)(dst).DecodeBinary(src) } func (src XID) EncodeText(w io.Writer) error { diff --git a/query.go b/query.go index 965f3913..71d1ba9e 100644 --- a/query.go +++ b/query.go @@ -231,14 +231,12 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { rows.Fatal(scanArgError{col: i, err: err}) } } else if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode { - vr.err = errRewoundLen - err = s.DecodeBinary(&valueReader2{vr}) + err = s.DecodeBinary(vr.bytes()) if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } } else if s, ok := d.(pgtype.TextDecoder); ok && vr.Type().FormatCode == TextFormatCode { - vr.err = errRewoundLen - err = s.DecodeText(&valueReader2{vr}) + err = s.DecodeText(vr.bytes()) if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } @@ -290,8 +288,7 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { switch vr.Type().FormatCode { case TextFormatCode: if textDecoder, ok := pgVal.(pgtype.TextDecoder); ok { - vr.err = errRewoundLen - err = textDecoder.DecodeText(&valueReader2{vr}) + err = textDecoder.DecodeText(vr.bytes()) if err != nil { vr.Fatal(err) } @@ -300,8 +297,7 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { } case BinaryFormatCode: if binaryDecoder, ok := pgVal.(pgtype.BinaryDecoder); ok { - vr.err = errRewoundLen - err = binaryDecoder.DecodeBinary(&valueReader2{vr}) + err = binaryDecoder.DecodeBinary(vr.bytes()) if err != nil { vr.Fatal(err) } diff --git a/value_reader.go b/value_reader.go index c91a21af..85932a7d 100644 --- a/value_reader.go +++ b/value_reader.go @@ -4,8 +4,6 @@ import ( "errors" ) -var errRewoundLen = errors.New("len was rewound") - // ValueReader is used by the Scanner interface to decode values. type ValueReader struct { mr *msgReader @@ -157,27 +155,10 @@ func (r *ValueReader) ReadBytes(count int32) []byte { return r.mr.readBytes(count) } -type valueReader2 struct { - *ValueReader -} - -func (r *valueReader2) Read(dst []byte) (int, error) { - if r.err != nil { - return 0, r.err +// bytes is a compatibility function for pgtype.TextDecoder and pgtype.BinaryDecoder +func (r *ValueReader) bytes() []byte { + if r.Len() >= 0 { + return r.ReadBytes(r.Len()) } - - src := r.ReadBytes(int32(len(dst))) - - copy(dst, src) - - return len(dst), nil -} - -func (r *valueReader2) ReadUint32() (uint32, error) { - if r.err == errRewoundLen { - r.err = nil - return uint32(r.Len()), nil - } - - return r.ValueReader.ReadUint32(), nil + return nil } diff --git a/values.go b/values.go index c724aa39..796f2f3d 100644 --- a/values.go +++ b/values.go @@ -3,6 +3,7 @@ package pgx import ( "bytes" "database/sql/driver" + "encoding/binary" "encoding/json" "fmt" "io" @@ -455,23 +456,12 @@ func (n NullInt32) Encode(w *WriteBuf, oid OID) error { // in the PostgreSQL sources. OID cannot be NULL. To allow for NULL OIDs use pgtype.OID. type OID uint32 -func (dst *OID) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *OID) DecodeText(src []byte) error { + if src == nil { return fmt.Errorf("cannot decode nil into OID") } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseUint(string(buf), 10, 32) + n, err := strconv.ParseUint(string(src), 10, 32) if err != nil { return err } @@ -480,25 +470,16 @@ func (dst *OID) DecodeText(r io.Reader) error { return nil } -func (dst *OID) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *OID) DecodeBinary(src []byte) error { + if src == nil { return fmt.Errorf("cannot decode nil into OID") } - if size != 4 { - return fmt.Errorf("invalid length for OID: %v", size) - } - - n, err := pgio.ReadUint32(r) - if err != nil { - return err + if len(src) != 4 { + return fmt.Errorf("invalid length: %v", len(src)) } + n := binary.BigEndian.Uint32(src) *dst = OID(n) return nil } @@ -1020,15 +1001,13 @@ func decodeBool(vr *ValueReader) bool { return false } - vr.err = errRewoundLen - var b pgtype.Bool var err error switch vr.Type().FormatCode { case TextFormatCode: - err = b.DecodeText(&valueReader2{vr}) + err = b.DecodeText(vr.bytes()) case BinaryFormatCode: - err = b.DecodeBinary(&valueReader2{vr}) + err = b.DecodeBinary(vr.bytes()) default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return false @@ -1081,15 +1060,13 @@ func decodeInt8(vr *ValueReader) int64 { return 0 } - vr.err = errRewoundLen - var n pgtype.Int8 var err error switch vr.Type().FormatCode { case TextFormatCode: - err = n.DecodeText(&valueReader2{vr}) + err = n.DecodeText(vr.bytes()) case BinaryFormatCode: - err = n.DecodeBinary(&valueReader2{vr}) + err = n.DecodeBinary(vr.bytes()) default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 @@ -1115,15 +1092,13 @@ func decodeInt2(vr *ValueReader) int16 { return 0 } - vr.err = errRewoundLen - var n pgtype.Int2 var err error switch vr.Type().FormatCode { case TextFormatCode: - err = n.DecodeText(&valueReader2{vr}) + err = n.DecodeText(vr.bytes()) case BinaryFormatCode: - err = n.DecodeBinary(&valueReader2{vr}) + err = n.DecodeBinary(vr.bytes()) default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 @@ -1153,15 +1128,13 @@ func decodeInt4(vr *ValueReader) int32 { return 0 } - vr.err = errRewoundLen - var n pgtype.Int4 var err error switch vr.Type().FormatCode { case TextFormatCode: - err = n.DecodeText(&valueReader2{vr}) + err = n.DecodeText(vr.bytes()) case BinaryFormatCode: - err = n.DecodeBinary(&valueReader2{vr}) + err = n.DecodeBinary(vr.bytes()) default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 @@ -1455,15 +1428,13 @@ func decodeDate(vr *ValueReader) time.Time { return time.Time{} } - vr.err = errRewoundLen - var d pgtype.Date var err error switch vr.Type().FormatCode { case TextFormatCode: - err = d.DecodeText(&valueReader2{vr}) + err = d.DecodeText(vr.bytes()) case BinaryFormatCode: - err = d.DecodeBinary(&valueReader2{vr}) + err = d.DecodeBinary(vr.bytes()) default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return time.Time{} @@ -1518,15 +1489,13 @@ func decodeTimestampTz(vr *ValueReader) time.Time { return zeroTime } - vr.err = errRewoundLen - var t pgtype.Timestamptz var err error switch vr.Type().FormatCode { case TextFormatCode: - err = t.DecodeText(&valueReader2{vr}) + err = t.DecodeText(vr.bytes()) case BinaryFormatCode: - err = t.DecodeBinary(&valueReader2{vr}) + err = t.DecodeBinary(vr.bytes()) default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return time.Time{}