diff --git a/messages.go b/messages.go index 2320bf7f..f6be9ff9 100644 --- a/messages.go +++ b/messages.go @@ -101,6 +101,7 @@ func newWriteBuf(c *Conn, t byte) *WriteBuf { // by the Encoder interface when implementing custom encoders. type WriteBuf struct { buf []byte + convBuf [8]byte sizeIdx int conn *Conn } @@ -125,41 +126,32 @@ func (wb *WriteBuf) WriteCString(s string) { } func (wb *WriteBuf) WriteInt16(n int16) { - b := make([]byte, 2) - binary.BigEndian.PutUint16(b, uint16(n)) - wb.buf = append(wb.buf, b...) + wb.WriteUint16(uint16(n)) } func (wb *WriteBuf) WriteUint16(n uint16) (int, error) { - b := make([]byte, 2) - binary.BigEndian.PutUint16(b, n) - wb.buf = append(wb.buf, b...) + binary.BigEndian.PutUint16(wb.convBuf[:2], n) + wb.buf = append(wb.buf, wb.convBuf[:2]...) return 2, nil } func (wb *WriteBuf) WriteInt32(n int32) { - b := make([]byte, 4) - binary.BigEndian.PutUint32(b, uint32(n)) - wb.buf = append(wb.buf, b...) + wb.WriteUint32(uint32(n)) } func (wb *WriteBuf) WriteUint32(n uint32) (int, error) { - b := make([]byte, 4) - binary.BigEndian.PutUint32(b, n) - wb.buf = append(wb.buf, b...) + binary.BigEndian.PutUint32(wb.convBuf[:4], n) + wb.buf = append(wb.buf, wb.convBuf[:4]...) return 4, nil } func (wb *WriteBuf) WriteInt64(n int64) { - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, uint64(n)) - wb.buf = append(wb.buf, b...) + wb.WriteUint64(uint64(n)) } func (wb *WriteBuf) WriteUint64(n uint64) (int, error) { - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, n) - wb.buf = append(wb.buf, b...) + binary.BigEndian.PutUint64(wb.convBuf[:8], n) + wb.buf = append(wb.buf, wb.convBuf[:8]...) return 8, nil } diff --git a/pgtype/int4.go b/pgtype/int4.go index ab993476..29b5dd1b 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -1,7 +1,6 @@ package pgtype import ( - "encoding/binary" "fmt" "io" "strconv" @@ -9,8 +8,23 @@ import ( type Int4 int32 -func (i *Int4) ParseText(src string) error { - n, err := strconv.ParseInt(src, 10, 32) +func (i *Int4) DecodeText(r io.Reader) error { + size, err := ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + return fmt.Errorf("invalid length for int4: %v", size) + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseInt(string(buf), 10, 32) if err != nil { return err } @@ -19,12 +33,22 @@ func (i *Int4) ParseText(src string) error { return nil } -func (i *Int4) ParseBinary(src []byte) error { - if len(src) != 4 { - return fmt.Errorf("invalid length for int4: %v", len(src)) +func (i *Int4) DecodeBinary(r io.Reader) error { + size, err := ReadInt32(r) + if err != nil { + return err } - *i = Int4(binary.BigEndian.Uint32(src)) + if size != 4 { + return fmt.Errorf("invalid length for int4: %v", size) + } + + n, err := ReadInt32(r) + if err != nil { + return err + } + + *i = Int4(n) return nil } diff --git a/pgtype/typed_reader.go b/pgtype/typed_reader.go new file mode 100644 index 00000000..29997338 --- /dev/null +++ b/pgtype/typed_reader.go @@ -0,0 +1,104 @@ +package pgtype + +import ( + "encoding/binary" + "io" +) + +type uint16Reader interface { + ReadUint16() (n uint16, err error) +} + +type uint32Reader interface { + ReadUint32() (n uint32, err error) +} + +type uint64Reader interface { + ReadUint64() (n uint64, err error) +} + +// ReadByte reads a byte from r. +func ReadByte(r io.Reader) (byte, error) { + if r, ok := r.(io.ByteReader); ok { + return r.ReadByte() + } + + buf := make([]byte, 1) + _, err := r.Read(buf) + return buf[0], err +} + +// ReadUint16 reads an uint16 from r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint16 +// method. +func ReadUint16(r io.Reader) (uint16, error) { + if r, ok := r.(uint16Reader); ok { + return r.ReadUint16() + } + + buf := make([]byte, 2) + _, err := io.ReadFull(r, buf) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint16(buf), nil +} + +// ReadInt16 reads an int16 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint16 +// method. +func ReadInt16(r io.Reader) (int16, error) { + n, err := ReadUint16(r) + return int16(n), err +} + +// ReadUint32 reads an uint32 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint32 +// method. +func ReadUint32(r io.Reader) (uint32, error) { + if r, ok := r.(uint32Reader); ok { + return r.ReadUint32() + } + + buf := make([]byte, 4) + _, err := io.ReadFull(r, buf) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint32(buf), nil +} + +// ReadInt32 reads an int32 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint32 +// method. +func ReadInt32(r io.Reader) (int32, error) { + n, err := ReadUint32(r) + return int32(n), err +} + +// ReadUint64 reads an uint64 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint64 +// method. +func ReadUint64(r io.Reader) (uint64, error) { + if r, ok := r.(uint64Reader); ok { + return r.ReadUint64() + } + + buf := make([]byte, 8) + _, err := io.ReadFull(r, buf) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint64(buf), nil +} + +// ReadInt64 reads an int64 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint64 +// method. +func ReadInt64(r io.Reader) (int64, error) { + n, err := ReadUint64(r) + return int64(n), err +} diff --git a/value_reader.go b/value_reader.go index 249b8ba3..c91a21af 100644 --- a/value_reader.go +++ b/value_reader.go @@ -4,6 +4,8 @@ import ( "errors" ) +var errRewoundLen = errors.New("len was rewound") + // ValueReader is used by the Scanner interface to decode values. type ValueReader struct { mr *msgReader @@ -154,3 +156,28 @@ func (r *ValueReader) ReadBytes(count int32) []byte { return r.mr.readBytes(count) } + +type valueReader2 struct { + *ValueReader +} + +func (r *valueReader2) Read(dst []byte) (int, error) { + if r.err != nil { + return 0, r.err + } + + src := r.ReadBytes(int32(len(dst))) + + copy(dst, src) + + return len(dst), nil +} + +func (r *valueReader2) ReadUint32() (uint32, error) { + if r.err == errRewoundLen { + r.err = nil + return uint32(r.Len()), nil + } + + return r.ValueReader.ReadUint32(), nil +} diff --git a/values.go b/values.go index ee53edfc..85c7ad3d 100644 --- a/values.go +++ b/values.go @@ -1789,13 +1789,15 @@ func decodeInt4(vr *ValueReader) int32 { return 0 } + vr.err = errRewoundLen + var n pgtype.Int4 var err error switch vr.Type().FormatCode { case TextFormatCode: - err = n.ParseText(vr.ReadString(vr.Len())) + err = n.DecodeText(&valueReader2{vr}) case BinaryFormatCode: - err = n.ParseBinary(vr.ReadBytes(vr.Len())) + err = n.DecodeBinary(&valueReader2{vr}) default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0