mirror of https://github.com/jackc/pgx.git
MsgReader no longer uses double buffering
parent
78b8e0b6f2
commit
b25aea5c52
167
conn.go
167
conn.go
|
@ -14,7 +14,6 @@ import (
|
|||
"fmt"
|
||||
log "gopkg.in/inconshreveable/log15.v2"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
|
@ -52,8 +51,8 @@ type Conn struct {
|
|||
conn net.Conn // the underlying TCP or unix domain socket connection
|
||||
reader *bufio.Reader // buffered reader to improve read performance
|
||||
wbuf [1024]byte
|
||||
buf *bytes.Buffer // work buffer to avoid constant alloc and dealloc
|
||||
bufSize int // desired size of buf
|
||||
buf *bytes.Buffer
|
||||
bufSize int
|
||||
Pid int32 // backend pid
|
||||
SecretKey int32 // key to use to send a cancel query message to the server
|
||||
RuntimeParams map[string]string // parameters that have been reported by the server
|
||||
|
@ -65,6 +64,7 @@ type Conn struct {
|
|||
causeOfDeath error
|
||||
logger log.Logger
|
||||
qr QueryResult
|
||||
mr MsgReader
|
||||
}
|
||||
|
||||
type PreparedStatement struct {
|
||||
|
@ -197,6 +197,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||
}
|
||||
|
||||
c.reader = bufio.NewReader(c.conn)
|
||||
c.mr.reader = c.reader
|
||||
|
||||
msg := newStartupMessage()
|
||||
msg.options["user"] = c.config.User
|
||||
|
@ -209,7 +210,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||
|
||||
for {
|
||||
var t byte
|
||||
var r *MessageReader
|
||||
var r *MsgReader
|
||||
t, r, err = c.rxMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -348,9 +349,9 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{})
|
|||
|
||||
for {
|
||||
var t byte
|
||||
var bodySize int32
|
||||
var r *MsgReader
|
||||
|
||||
t, bodySize, err = c.rxMsgHeader()
|
||||
t, r, err = c.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -363,19 +364,12 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{})
|
|||
}
|
||||
|
||||
if softErr != nil {
|
||||
c.rxMsgBody(bodySize) // Read and discard rest of message
|
||||
// Read and discard rest of message
|
||||
continue
|
||||
}
|
||||
|
||||
softErr = c.rxDataRowValueTo(w, bodySize)
|
||||
softErr = c.rxDataRowValueTo(w, r)
|
||||
} else {
|
||||
var body *bytes.Buffer
|
||||
body, err = c.rxMsgBody(bodySize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r := (*MessageReader)(body)
|
||||
switch t {
|
||||
case readyForQuery:
|
||||
c.rxReadyForQuery(r)
|
||||
|
@ -392,44 +386,20 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{})
|
|||
}
|
||||
}
|
||||
|
||||
func (c *Conn) rxDataRowValueTo(w io.Writer, bodySize int32) (err error) {
|
||||
b := make([]byte, 2)
|
||||
_, err = io.ReadFull(c.reader, b)
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
return
|
||||
}
|
||||
columnCount := int16(binary.BigEndian.Uint16(b))
|
||||
|
||||
func (c *Conn) rxDataRowValueTo(w io.Writer, r *MsgReader) error {
|
||||
columnCount := r.ReadInt16()
|
||||
if columnCount != 1 {
|
||||
// Read the rest of the data row so it can be discarded
|
||||
if _, err = io.CopyN(ioutil.Discard, c.reader, int64(bodySize-2)); err != nil {
|
||||
c.die(err)
|
||||
return
|
||||
}
|
||||
err = UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: columnCount}
|
||||
return
|
||||
return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: columnCount}
|
||||
}
|
||||
|
||||
b = make([]byte, 4)
|
||||
_, err = io.ReadFull(c.reader, b)
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
return
|
||||
}
|
||||
valueSize := int32(binary.BigEndian.Uint32(b))
|
||||
|
||||
valueSize := r.ReadInt32()
|
||||
if valueSize == -1 {
|
||||
err = errors.New("SelectValueTo cannot handle null")
|
||||
return
|
||||
return errors.New("SelectValueTo cannot handle null")
|
||||
}
|
||||
|
||||
_, err = io.CopyN(w, c.reader, int64(valueSize))
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
}
|
||||
r.CopyN(w, valueSize)
|
||||
|
||||
return
|
||||
return r.Err()
|
||||
}
|
||||
|
||||
// Prepare creates a prepared statement with name and sql. sql can contain placeholders
|
||||
|
@ -475,7 +445,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
|
|||
|
||||
for {
|
||||
var t byte
|
||||
var r *MessageReader
|
||||
var r *MsgReader
|
||||
t, r, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -563,7 +533,7 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error)
|
|||
}
|
||||
|
||||
var t byte
|
||||
var r *MessageReader
|
||||
var r *MsgReader
|
||||
if t, r, err = c.rxMsg(); err == nil {
|
||||
if err = c.processContextFreeMsg(t, r); err != nil {
|
||||
return nil, err
|
||||
|
@ -653,12 +623,12 @@ func (rr *RowReader) ReadDate(qr *QueryResult) time.Time {
|
|||
}
|
||||
|
||||
func (rr *RowReader) ReadString(qr *QueryResult) string {
|
||||
_, size, ok := qr.NextColumn()
|
||||
fd, size, ok := qr.NextColumn()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return qr.mr.ReadString(size)
|
||||
return decodeText(qr, fd, size)
|
||||
}
|
||||
|
||||
func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
|
||||
|
@ -671,7 +641,7 @@ func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
|
|||
if vt, present := ValueTranscoders[fd.DataType]; present && vt.Decode != nil {
|
||||
return vt.Decode(qr, fd, size)
|
||||
} else {
|
||||
return qr.mr.ReadString(size)
|
||||
return decodeText(qr, fd, size)
|
||||
}
|
||||
} else {
|
||||
return nil
|
||||
|
@ -681,7 +651,7 @@ func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
|
|||
type QueryResult struct {
|
||||
pool *ConnPool
|
||||
conn *Conn
|
||||
mr *MessageReader
|
||||
mr *MsgReader
|
||||
fields []FieldDescription
|
||||
rowCount int
|
||||
columnIdx int
|
||||
|
@ -693,7 +663,7 @@ func (qr *QueryResult) FieldDescriptions() []FieldDescription {
|
|||
return qr.fields
|
||||
}
|
||||
|
||||
func (qr *QueryResult) MessageReader() *MessageReader {
|
||||
func (qr *QueryResult) MsgReader() *MsgReader {
|
||||
return qr.mr
|
||||
}
|
||||
|
||||
|
@ -771,8 +741,8 @@ func (qr *QueryResult) NextRow() bool {
|
|||
qr.close()
|
||||
return false
|
||||
case dataRow:
|
||||
fieldCount := int(r.ReadInt16())
|
||||
if fieldCount != len(qr.fields) {
|
||||
fieldCount := r.ReadInt16()
|
||||
if int(fieldCount) != len(qr.fields) {
|
||||
qr.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(qr.fields), fieldCount)))
|
||||
return false
|
||||
}
|
||||
|
@ -803,7 +773,6 @@ func (qr *QueryResult) NextColumn() (*FieldDescription, int32, bool) {
|
|||
fd := &qr.fields[qr.columnIdx]
|
||||
qr.columnIdx++
|
||||
size := qr.mr.ReadInt32()
|
||||
|
||||
return fd, size, true
|
||||
}
|
||||
|
||||
|
@ -965,7 +934,7 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag
|
|||
|
||||
for {
|
||||
var t byte
|
||||
var r *MessageReader
|
||||
var r *MsgReader
|
||||
t, r, err = c.rxMsg()
|
||||
if err != nil {
|
||||
return commandTag, err
|
||||
|
@ -1040,7 +1009,7 @@ func (c *Conn) transaction(isoLevel string, f func() bool) (committed bool, err
|
|||
// Processes messages that are not exclusive to one context such as
|
||||
// authentication or query response. The response to these messages
|
||||
// is the same regardless of when they occur.
|
||||
func (c *Conn) processContextFreeMsg(t byte, r *MessageReader) (err error) {
|
||||
func (c *Conn) processContextFreeMsg(t byte, r *MsgReader) (err error) {
|
||||
switch t {
|
||||
case 'S':
|
||||
c.rxParameterStatus(r)
|
||||
|
@ -1050,69 +1019,28 @@ func (c *Conn) processContextFreeMsg(t byte, r *MessageReader) (err error) {
|
|||
case noticeResponse:
|
||||
return nil
|
||||
case notificationResponse:
|
||||
return c.rxNotificationResponse(r)
|
||||
c.rxNotificationResponse(r)
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("Received unknown message type: %c", t)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) rxMsg() (t byte, r *MessageReader, err error) {
|
||||
var bodySize int32
|
||||
t, bodySize, err = c.rxMsgHeader()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var body *bytes.Buffer
|
||||
if body, err = c.rxMsgBody(bodySize); err != nil {
|
||||
return
|
||||
}
|
||||
r = (*MessageReader)(body)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) rxMsgHeader() (t byte, bodySize int32, err error) {
|
||||
func (c *Conn) rxMsg() (t byte, r *MsgReader, err error) {
|
||||
if !c.alive {
|
||||
return 0, 0, DeadConnError
|
||||
return 0, nil, DeadConnError
|
||||
}
|
||||
|
||||
t, err = c.reader.ReadByte()
|
||||
t, err = c.mr.rxMsg()
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
b := make([]byte, 4)
|
||||
_, err = io.ReadFull(c.reader, b)
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
return 0, 0, err
|
||||
return t, &c.mr, err
|
||||
}
|
||||
|
||||
bodySize = int32(binary.BigEndian.Uint32(b))
|
||||
|
||||
bodySize -= 4 // remove self from size
|
||||
return t, bodySize, err
|
||||
}
|
||||
|
||||
func (c *Conn) rxMsgBody(bodySize int32) (*bytes.Buffer, error) {
|
||||
if !c.alive {
|
||||
return nil, DeadConnError
|
||||
}
|
||||
|
||||
buf := c.getBuf()
|
||||
_, err := io.CopyN(buf, c.reader, int64(bodySize))
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (c *Conn) rxAuthenticationX(r *MessageReader) (err error) {
|
||||
code := r.ReadInt32()
|
||||
switch code {
|
||||
func (c *Conn) rxAuthenticationX(r *MsgReader) (err error) {
|
||||
switch r.ReadInt32() {
|
||||
case 0: // AuthenticationOk
|
||||
case 3: // AuthenticationCleartextPassword
|
||||
err = c.txPasswordMessage(c.config.Password)
|
||||
|
@ -1133,13 +1061,13 @@ func hexMD5(s string) string {
|
|||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
func (c *Conn) rxParameterStatus(r *MessageReader) {
|
||||
func (c *Conn) rxParameterStatus(r *MsgReader) {
|
||||
key := r.ReadCString()
|
||||
value := r.ReadCString()
|
||||
c.RuntimeParams[key] = value
|
||||
}
|
||||
|
||||
func (c *Conn) rxErrorResponse(r *MessageReader) (err PgError) {
|
||||
func (c *Conn) rxErrorResponse(r *MsgReader) (err PgError) {
|
||||
for {
|
||||
switch r.ReadByte() {
|
||||
case 'S':
|
||||
|
@ -1159,16 +1087,16 @@ func (c *Conn) rxErrorResponse(r *MessageReader) (err PgError) {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *Conn) rxBackendKeyData(r *MessageReader) {
|
||||
func (c *Conn) rxBackendKeyData(r *MsgReader) {
|
||||
c.Pid = r.ReadInt32()
|
||||
c.SecretKey = r.ReadInt32()
|
||||
}
|
||||
|
||||
func (c *Conn) rxReadyForQuery(r *MessageReader) {
|
||||
func (c *Conn) rxReadyForQuery(r *MsgReader) {
|
||||
c.TxStatus = r.ReadByte()
|
||||
}
|
||||
|
||||
func (c *Conn) rxRowDescription(r *MessageReader) (fields []FieldDescription) {
|
||||
func (c *Conn) rxRowDescription(r *MsgReader) (fields []FieldDescription) {
|
||||
fieldCount := r.ReadInt16()
|
||||
fields = make([]FieldDescription, fieldCount)
|
||||
for i := int16(0); i < fieldCount; i++ {
|
||||
|
@ -1184,26 +1112,26 @@ func (c *Conn) rxRowDescription(r *MessageReader) (fields []FieldDescription) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *Conn) rxParameterDescription(r *MessageReader) (parameters []Oid) {
|
||||
func (c *Conn) rxParameterDescription(r *MsgReader) (parameters []Oid) {
|
||||
parameterCount := r.ReadInt16()
|
||||
parameters = make([]Oid, 0, parameterCount)
|
||||
|
||||
for i := int16(0); i < parameterCount; i++ {
|
||||
parameters = append(parameters, r.ReadOid())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) rxCommandComplete(r *MessageReader) string {
|
||||
func (c *Conn) rxCommandComplete(r *MsgReader) string {
|
||||
return r.ReadCString()
|
||||
}
|
||||
|
||||
func (c *Conn) rxNotificationResponse(r *MessageReader) (err error) {
|
||||
func (c *Conn) rxNotificationResponse(r *MsgReader) {
|
||||
n := new(Notification)
|
||||
n.Pid = r.ReadInt32()
|
||||
n.Channel = r.ReadCString()
|
||||
n.Payload = r.ReadCString()
|
||||
c.notifications = append(c.notifications, n)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) startTLS() (err error) {
|
||||
|
@ -1280,12 +1208,7 @@ func (c *Conn) txPasswordMessage(password string) (err error) {
|
|||
// its internal byte array, check on the size and create a new bytes.Buffer so the
|
||||
// old one can get GC'ed
|
||||
func (c *Conn) getBuf() *bytes.Buffer {
|
||||
c.buf.Reset()
|
||||
if cap(c.buf.Bytes()) > c.bufSize {
|
||||
c.logger.Debug(fmt.Sprintf("c.buf (%d) is larger than c.bufSize (%d) -- resetting", cap(c.buf.Bytes()), c.bufSize))
|
||||
c.buf = bytes.NewBuffer(make([]byte, 0, c.bufSize))
|
||||
}
|
||||
return c.buf
|
||||
return &bytes.Buffer{}
|
||||
}
|
||||
|
||||
func (c *Conn) die(err error) {
|
||||
|
|
|
@ -51,7 +51,7 @@ func decodePoint(qr *pgx.QueryResult, fd *pgx.FieldDescription, size int32) Poin
|
|||
|
||||
switch fd.FormatCode {
|
||||
case pgx.TextFormatCode:
|
||||
s := qr.MessageReader().ReadString(size)
|
||||
s := qr.MsgReader().ReadString(size)
|
||||
match := pointRegexp.FindStringSubmatch(s)
|
||||
if match == nil {
|
||||
qr.Fatal(pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s)))
|
||||
|
|
|
@ -1,47 +0,0 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
// MessageReader is a helper that reads values from a PostgreSQL message.
|
||||
type MessageReader bytes.Buffer
|
||||
|
||||
func (r *MessageReader) ReadByte() (b byte) {
|
||||
b, _ = (*bytes.Buffer)(r).ReadByte()
|
||||
return
|
||||
}
|
||||
|
||||
func (r *MessageReader) ReadInt16() (n int16) {
|
||||
b := (*bytes.Buffer)(r).Next(2)
|
||||
return int16(binary.BigEndian.Uint16(b))
|
||||
}
|
||||
|
||||
func (r *MessageReader) ReadInt32() (n int32) {
|
||||
b := (*bytes.Buffer)(r).Next(4)
|
||||
return int32(binary.BigEndian.Uint32(b))
|
||||
}
|
||||
|
||||
func (r *MessageReader) ReadInt64() (n int64) {
|
||||
b := (*bytes.Buffer)(r).Next(8)
|
||||
return int64(binary.BigEndian.Uint64(b))
|
||||
}
|
||||
|
||||
func (r *MessageReader) ReadOid() (oid Oid) {
|
||||
b := (*bytes.Buffer)(r).Next(4)
|
||||
return Oid(binary.BigEndian.Uint32(b))
|
||||
}
|
||||
|
||||
// ReadString reads a null terminated string
|
||||
func (r *MessageReader) ReadCString() (s string) {
|
||||
b, _ := (*bytes.Buffer)(r).ReadBytes(0)
|
||||
return string(b[:len(b)-1])
|
||||
}
|
||||
|
||||
// ReadString reads count bytes and returns as string
|
||||
func (r *MessageReader) ReadString(count int32) (s string) {
|
||||
size := int(count)
|
||||
b := (*bytes.Buffer)(r).Next(size)
|
||||
return string(b)
|
||||
}
|
|
@ -0,0 +1,197 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
// MsgReader is a helper that reads values from a PostgreSQL message.
|
||||
type MsgReader struct {
|
||||
reader *bufio.Reader
|
||||
buf [128]byte
|
||||
msgBytesRemaining int32
|
||||
err error
|
||||
}
|
||||
|
||||
// Err returns any error that the MsgReader has experienced
|
||||
func (r *MsgReader) Err() error {
|
||||
return r.err
|
||||
}
|
||||
|
||||
// Fatal tells r that a Fatal error has occurred
|
||||
func (r *MsgReader) Fatal(err error) {
|
||||
r.err = err
|
||||
}
|
||||
|
||||
// rxMsg reads the type and size of the next message.
|
||||
func (r *MsgReader) rxMsg() (t byte, err error) {
|
||||
if r.err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if r.msgBytesRemaining > 0 {
|
||||
io.CopyN(ioutil.Discard, r.reader, int64(r.msgBytesRemaining))
|
||||
}
|
||||
|
||||
t, err = r.reader.ReadByte()
|
||||
b := r.buf[0:4]
|
||||
_, err = io.ReadFull(r.reader, b)
|
||||
r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b)) - 4
|
||||
return t, err
|
||||
}
|
||||
|
||||
func (r *MsgReader) ReadByte() byte {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 1
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.Fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.ReadByte()
|
||||
if err != nil {
|
||||
r.Fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (r *MsgReader) ReadInt16() int16 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 2
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.Fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b := r.buf[0:2]
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
if err != nil {
|
||||
r.Fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return int16(binary.BigEndian.Uint16(b))
|
||||
}
|
||||
|
||||
func (r *MsgReader) ReadInt32() int32 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 4
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.Fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b := r.buf[0:4]
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
if err != nil {
|
||||
r.Fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return int32(binary.BigEndian.Uint32(b))
|
||||
}
|
||||
|
||||
func (r *MsgReader) ReadInt64() int64 {
|
||||
if r.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 8
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.Fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b := r.buf[0:8]
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
if err != nil {
|
||||
r.Fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return int64(binary.BigEndian.Uint64(b))
|
||||
}
|
||||
|
||||
func (r *MsgReader) ReadOid() Oid {
|
||||
return Oid(r.ReadInt32())
|
||||
}
|
||||
|
||||
// ReadCString reads a null terminated string
|
||||
func (r *MsgReader) ReadCString() string {
|
||||
if r.err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
b, err := r.reader.ReadBytes(0)
|
||||
if err != nil {
|
||||
r.Fatal(err)
|
||||
return ""
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= int32(len(b))
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.Fatal(errors.New("read past end of message"))
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(b[0 : len(b)-1])
|
||||
}
|
||||
|
||||
// ReadString reads count bytes and returns as string
|
||||
func (r *MsgReader) ReadString(count int32) string {
|
||||
if r.err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= count
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.Fatal(errors.New("read past end of message"))
|
||||
return ""
|
||||
}
|
||||
|
||||
var b []byte
|
||||
if count <= int32(len(r.buf)) {
|
||||
b = r.buf[0:int(count)]
|
||||
} else {
|
||||
b = make([]byte, int(count))
|
||||
}
|
||||
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
if err != nil {
|
||||
r.Fatal(err)
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func (r *MsgReader) CopyN(w io.Writer, count int32) {
|
||||
if r.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= count
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.Fatal(errors.New("read past end of message"))
|
||||
return
|
||||
}
|
||||
|
||||
_, err := io.CopyN(w, r.reader, int64(count))
|
||||
if err != nil {
|
||||
r.Fatal(err)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue