diff --git a/message_reader.go b/message_reader.go index 85c18391..cc1eb87e 100644 --- a/message_reader.go +++ b/message_reader.go @@ -3,51 +3,113 @@ package pgx import ( "bytes" "encoding/binary" + "fmt" ) // MessageReader is a helper that reads values from a PostgreSQL message. +// To avoid verbose error handling it internally records errors and no-ops +// any calls that occur after an error. At the end of a sequence of reads +// the Err field should be checked to see if any errors occurred. type MessageReader struct { buf *bytes.Buffer + Err error } func newMessageReader(buf *bytes.Buffer) *MessageReader { return &MessageReader{buf: buf} } -func (r *MessageReader) ReadByte() byte { - b, err := r.buf.ReadByte() - if err != nil { - panic("Unable to read byte") +func (r *MessageReader) ReadByte() (b byte) { + if r.Err != nil { + return } - return b + + b, r.Err = r.buf.ReadByte() + return } -func (r *MessageReader) ReadInt16() int16 { - return int16(binary.BigEndian.Uint16(r.buf.Next(2))) +func (r *MessageReader) ReadInt16() (n int16) { + if r.Err != nil { + return + } + + size := 2 + b := r.buf.Next(size) + if len(b) != size { + r.Err = fmt.Errorf("Unable to read %d bytes, only read %d", size, len(b)) + } + + return int16(binary.BigEndian.Uint16(b)) } -func (r *MessageReader) ReadInt32() int32 { - return int32(binary.BigEndian.Uint32(r.buf.Next(4))) +func (r *MessageReader) ReadInt32() (n int32) { + if r.Err != nil { + return + } + + size := 4 + b := r.buf.Next(size) + if len(b) != size { + r.Err = fmt.Errorf("Unable to read %d bytes, only read %d", size, len(b)) + } + + return int32(binary.BigEndian.Uint32(b)) } -func (r *MessageReader) ReadInt64() int64 { - return int64(binary.BigEndian.Uint64(r.buf.Next(8))) +func (r *MessageReader) ReadInt64() (n int64) { + if r.Err != nil { + return + } + + size := 8 + b := r.buf.Next(size) + if len(b) != size { + r.Err = fmt.Errorf("Unable to read %d bytes, only read %d", size, len(b)) + } + + return int64(binary.BigEndian.Uint64(b)) } -func (r *MessageReader) ReadOid() Oid { - return Oid(binary.BigEndian.Uint32(r.buf.Next(4))) +func (r *MessageReader) ReadOid() (oid Oid) { + if r.Err != nil { + return + } + + size := 4 + b := r.buf.Next(size) + if len(b) != size { + r.Err = fmt.Errorf("Unable to read %d bytes, only read %d", size, len(b)) + } + + return Oid(binary.BigEndian.Uint32(b)) } // ReadString reads a null terminated string -func (r *MessageReader) ReadString() string { - b, err := r.buf.ReadBytes(0) - if err != nil { - panic("Unable to read string") +func (r *MessageReader) ReadString() (s string) { + if r.Err != nil { + return } + + var b []byte + b, r.Err = r.buf.ReadBytes(0) + if r.Err != nil { + return + } + return string(b[:len(b)-1]) } // ReadByteString reads count bytes and return as string -func (r *MessageReader) ReadByteString(count int32) string { - return string(r.buf.Next(int(count))) +func (r *MessageReader) ReadByteString(count int32) (s string) { + if r.Err != nil { + return + } + + size := int(count) + b := r.buf.Next(size) + if len(b) != size { + r.Err = fmt.Errorf("Unable to read %d bytes, only read %d", size, len(b)) + } + + return string(b) }