MsgReader no longer uses double buffering

scan-io
Jack Christensen 2014-07-04 13:08:37 -05:00
parent 78b8e0b6f2
commit b25aea5c52
4 changed files with 243 additions and 170 deletions

167
conn.go
View File

@ -14,7 +14,6 @@ import (
"fmt"
log "gopkg.in/inconshreveable/log15.v2"
"io"
"io/ioutil"
"net"
"net/url"
"os"
@ -52,8 +51,8 @@ type Conn struct {
conn net.Conn // the underlying TCP or unix domain socket connection
reader *bufio.Reader // buffered reader to improve read performance
wbuf [1024]byte
buf *bytes.Buffer // work buffer to avoid constant alloc and dealloc
bufSize int // desired size of buf
buf *bytes.Buffer
bufSize int
Pid int32 // backend pid
SecretKey int32 // key to use to send a cancel query message to the server
RuntimeParams map[string]string // parameters that have been reported by the server
@ -65,6 +64,7 @@ type Conn struct {
causeOfDeath error
logger log.Logger
qr QueryResult
mr MsgReader
}
type PreparedStatement struct {
@ -197,6 +197,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
}
c.reader = bufio.NewReader(c.conn)
c.mr.reader = c.reader
msg := newStartupMessage()
msg.options["user"] = c.config.User
@ -209,7 +210,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
for {
var t byte
var r *MessageReader
var r *MsgReader
t, r, err = c.rxMsg()
if err != nil {
return nil, err
@ -348,9 +349,9 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{})
for {
var t byte
var bodySize int32
var r *MsgReader
t, bodySize, err = c.rxMsgHeader()
t, r, err = c.rxMsg()
if err != nil {
return err
}
@ -363,19 +364,12 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{})
}
if softErr != nil {
c.rxMsgBody(bodySize) // Read and discard rest of message
// Read and discard rest of message
continue
}
softErr = c.rxDataRowValueTo(w, bodySize)
softErr = c.rxDataRowValueTo(w, r)
} else {
var body *bytes.Buffer
body, err = c.rxMsgBody(bodySize)
if err != nil {
return err
}
r := (*MessageReader)(body)
switch t {
case readyForQuery:
c.rxReadyForQuery(r)
@ -392,44 +386,20 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{})
}
}
func (c *Conn) rxDataRowValueTo(w io.Writer, bodySize int32) (err error) {
b := make([]byte, 2)
_, err = io.ReadFull(c.reader, b)
if err != nil {
c.die(err)
return
}
columnCount := int16(binary.BigEndian.Uint16(b))
func (c *Conn) rxDataRowValueTo(w io.Writer, r *MsgReader) error {
columnCount := r.ReadInt16()
if columnCount != 1 {
// Read the rest of the data row so it can be discarded
if _, err = io.CopyN(ioutil.Discard, c.reader, int64(bodySize-2)); err != nil {
c.die(err)
return
}
err = UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: columnCount}
return
return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: columnCount}
}
b = make([]byte, 4)
_, err = io.ReadFull(c.reader, b)
if err != nil {
c.die(err)
return
}
valueSize := int32(binary.BigEndian.Uint32(b))
valueSize := r.ReadInt32()
if valueSize == -1 {
err = errors.New("SelectValueTo cannot handle null")
return
return errors.New("SelectValueTo cannot handle null")
}
_, err = io.CopyN(w, c.reader, int64(valueSize))
if err != nil {
c.die(err)
}
r.CopyN(w, valueSize)
return
return r.Err()
}
// Prepare creates a prepared statement with name and sql. sql can contain placeholders
@ -475,7 +445,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
for {
var t byte
var r *MessageReader
var r *MsgReader
t, r, err := c.rxMsg()
if err != nil {
return nil, err
@ -563,7 +533,7 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error)
}
var t byte
var r *MessageReader
var r *MsgReader
if t, r, err = c.rxMsg(); err == nil {
if err = c.processContextFreeMsg(t, r); err != nil {
return nil, err
@ -653,12 +623,12 @@ func (rr *RowReader) ReadDate(qr *QueryResult) time.Time {
}
func (rr *RowReader) ReadString(qr *QueryResult) string {
_, size, ok := qr.NextColumn()
fd, size, ok := qr.NextColumn()
if !ok {
return ""
}
return qr.mr.ReadString(size)
return decodeText(qr, fd, size)
}
func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
@ -671,7 +641,7 @@ func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
if vt, present := ValueTranscoders[fd.DataType]; present && vt.Decode != nil {
return vt.Decode(qr, fd, size)
} else {
return qr.mr.ReadString(size)
return decodeText(qr, fd, size)
}
} else {
return nil
@ -681,7 +651,7 @@ func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
type QueryResult struct {
pool *ConnPool
conn *Conn
mr *MessageReader
mr *MsgReader
fields []FieldDescription
rowCount int
columnIdx int
@ -693,7 +663,7 @@ func (qr *QueryResult) FieldDescriptions() []FieldDescription {
return qr.fields
}
func (qr *QueryResult) MessageReader() *MessageReader {
func (qr *QueryResult) MsgReader() *MsgReader {
return qr.mr
}
@ -771,8 +741,8 @@ func (qr *QueryResult) NextRow() bool {
qr.close()
return false
case dataRow:
fieldCount := int(r.ReadInt16())
if fieldCount != len(qr.fields) {
fieldCount := r.ReadInt16()
if int(fieldCount) != len(qr.fields) {
qr.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(qr.fields), fieldCount)))
return false
}
@ -803,7 +773,6 @@ func (qr *QueryResult) NextColumn() (*FieldDescription, int32, bool) {
fd := &qr.fields[qr.columnIdx]
qr.columnIdx++
size := qr.mr.ReadInt32()
return fd, size, true
}
@ -965,7 +934,7 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag
for {
var t byte
var r *MessageReader
var r *MsgReader
t, r, err = c.rxMsg()
if err != nil {
return commandTag, err
@ -1040,7 +1009,7 @@ func (c *Conn) transaction(isoLevel string, f func() bool) (committed bool, err
// Processes messages that are not exclusive to one context such as
// authentication or query response. The response to these messages
// is the same regardless of when they occur.
func (c *Conn) processContextFreeMsg(t byte, r *MessageReader) (err error) {
func (c *Conn) processContextFreeMsg(t byte, r *MsgReader) (err error) {
switch t {
case 'S':
c.rxParameterStatus(r)
@ -1050,69 +1019,28 @@ func (c *Conn) processContextFreeMsg(t byte, r *MessageReader) (err error) {
case noticeResponse:
return nil
case notificationResponse:
return c.rxNotificationResponse(r)
c.rxNotificationResponse(r)
return nil
default:
return fmt.Errorf("Received unknown message type: %c", t)
}
}
func (c *Conn) rxMsg() (t byte, r *MessageReader, err error) {
var bodySize int32
t, bodySize, err = c.rxMsgHeader()
if err != nil {
return
}
var body *bytes.Buffer
if body, err = c.rxMsgBody(bodySize); err != nil {
return
}
r = (*MessageReader)(body)
return
}
func (c *Conn) rxMsgHeader() (t byte, bodySize int32, err error) {
func (c *Conn) rxMsg() (t byte, r *MsgReader, err error) {
if !c.alive {
return 0, 0, DeadConnError
return 0, nil, DeadConnError
}
t, err = c.reader.ReadByte()
t, err = c.mr.rxMsg()
if err != nil {
c.die(err)
return 0, 0, err
}
b := make([]byte, 4)
_, err = io.ReadFull(c.reader, b)
if err != nil {
c.die(err)
return 0, 0, err
}
bodySize = int32(binary.BigEndian.Uint32(b))
bodySize -= 4 // remove self from size
return t, bodySize, err
return t, &c.mr, err
}
func (c *Conn) rxMsgBody(bodySize int32) (*bytes.Buffer, error) {
if !c.alive {
return nil, DeadConnError
}
buf := c.getBuf()
_, err := io.CopyN(buf, c.reader, int64(bodySize))
if err != nil {
c.die(err)
return nil, err
}
return buf, nil
}
func (c *Conn) rxAuthenticationX(r *MessageReader) (err error) {
code := r.ReadInt32()
switch code {
func (c *Conn) rxAuthenticationX(r *MsgReader) (err error) {
switch r.ReadInt32() {
case 0: // AuthenticationOk
case 3: // AuthenticationCleartextPassword
err = c.txPasswordMessage(c.config.Password)
@ -1133,13 +1061,13 @@ func hexMD5(s string) string {
return hex.EncodeToString(hash.Sum(nil))
}
func (c *Conn) rxParameterStatus(r *MessageReader) {
func (c *Conn) rxParameterStatus(r *MsgReader) {
key := r.ReadCString()
value := r.ReadCString()
c.RuntimeParams[key] = value
}
func (c *Conn) rxErrorResponse(r *MessageReader) (err PgError) {
func (c *Conn) rxErrorResponse(r *MsgReader) (err PgError) {
for {
switch r.ReadByte() {
case 'S':
@ -1159,16 +1087,16 @@ func (c *Conn) rxErrorResponse(r *MessageReader) (err PgError) {
}
}
func (c *Conn) rxBackendKeyData(r *MessageReader) {
func (c *Conn) rxBackendKeyData(r *MsgReader) {
c.Pid = r.ReadInt32()
c.SecretKey = r.ReadInt32()
}
func (c *Conn) rxReadyForQuery(r *MessageReader) {
func (c *Conn) rxReadyForQuery(r *MsgReader) {
c.TxStatus = r.ReadByte()
}
func (c *Conn) rxRowDescription(r *MessageReader) (fields []FieldDescription) {
func (c *Conn) rxRowDescription(r *MsgReader) (fields []FieldDescription) {
fieldCount := r.ReadInt16()
fields = make([]FieldDescription, fieldCount)
for i := int16(0); i < fieldCount; i++ {
@ -1184,26 +1112,26 @@ func (c *Conn) rxRowDescription(r *MessageReader) (fields []FieldDescription) {
return
}
func (c *Conn) rxParameterDescription(r *MessageReader) (parameters []Oid) {
func (c *Conn) rxParameterDescription(r *MsgReader) (parameters []Oid) {
parameterCount := r.ReadInt16()
parameters = make([]Oid, 0, parameterCount)
for i := int16(0); i < parameterCount; i++ {
parameters = append(parameters, r.ReadOid())
}
return
}
func (c *Conn) rxCommandComplete(r *MessageReader) string {
func (c *Conn) rxCommandComplete(r *MsgReader) string {
return r.ReadCString()
}
func (c *Conn) rxNotificationResponse(r *MessageReader) (err error) {
func (c *Conn) rxNotificationResponse(r *MsgReader) {
n := new(Notification)
n.Pid = r.ReadInt32()
n.Channel = r.ReadCString()
n.Payload = r.ReadCString()
c.notifications = append(c.notifications, n)
return
}
func (c *Conn) startTLS() (err error) {
@ -1280,12 +1208,7 @@ func (c *Conn) txPasswordMessage(password string) (err error) {
// its internal byte array, check on the size and create a new bytes.Buffer so the
// old one can get GC'ed
func (c *Conn) getBuf() *bytes.Buffer {
c.buf.Reset()
if cap(c.buf.Bytes()) > c.bufSize {
c.logger.Debug(fmt.Sprintf("c.buf (%d) is larger than c.bufSize (%d) -- resetting", cap(c.buf.Bytes()), c.bufSize))
c.buf = bytes.NewBuffer(make([]byte, 0, c.bufSize))
}
return c.buf
return &bytes.Buffer{}
}
func (c *Conn) die(err error) {

View File

@ -51,7 +51,7 @@ func decodePoint(qr *pgx.QueryResult, fd *pgx.FieldDescription, size int32) Poin
switch fd.FormatCode {
case pgx.TextFormatCode:
s := qr.MessageReader().ReadString(size)
s := qr.MsgReader().ReadString(size)
match := pointRegexp.FindStringSubmatch(s)
if match == nil {
qr.Fatal(pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s)))

View File

@ -1,47 +0,0 @@
package pgx
import (
"bytes"
"encoding/binary"
)
// MessageReader is a helper that reads values from a PostgreSQL message.
type MessageReader bytes.Buffer
func (r *MessageReader) ReadByte() (b byte) {
b, _ = (*bytes.Buffer)(r).ReadByte()
return
}
func (r *MessageReader) ReadInt16() (n int16) {
b := (*bytes.Buffer)(r).Next(2)
return int16(binary.BigEndian.Uint16(b))
}
func (r *MessageReader) ReadInt32() (n int32) {
b := (*bytes.Buffer)(r).Next(4)
return int32(binary.BigEndian.Uint32(b))
}
func (r *MessageReader) ReadInt64() (n int64) {
b := (*bytes.Buffer)(r).Next(8)
return int64(binary.BigEndian.Uint64(b))
}
func (r *MessageReader) ReadOid() (oid Oid) {
b := (*bytes.Buffer)(r).Next(4)
return Oid(binary.BigEndian.Uint32(b))
}
// ReadString reads a null terminated string
func (r *MessageReader) ReadCString() (s string) {
b, _ := (*bytes.Buffer)(r).ReadBytes(0)
return string(b[:len(b)-1])
}
// ReadString reads count bytes and returns as string
func (r *MessageReader) ReadString(count int32) (s string) {
size := int(count)
b := (*bytes.Buffer)(r).Next(size)
return string(b)
}

197
msg_reader.go Normal file
View File

@ -0,0 +1,197 @@
package pgx
import (
"bufio"
"encoding/binary"
"errors"
"io"
"io/ioutil"
)
// MsgReader is a helper that reads values from a PostgreSQL message.
type MsgReader struct {
reader *bufio.Reader
buf [128]byte
msgBytesRemaining int32
err error
}
// Err returns any error that the MsgReader has experienced
func (r *MsgReader) Err() error {
return r.err
}
// Fatal tells r that a Fatal error has occurred
func (r *MsgReader) Fatal(err error) {
r.err = err
}
// rxMsg reads the type and size of the next message.
func (r *MsgReader) rxMsg() (t byte, err error) {
if r.err != nil {
return 0, err
}
if r.msgBytesRemaining > 0 {
io.CopyN(ioutil.Discard, r.reader, int64(r.msgBytesRemaining))
}
t, err = r.reader.ReadByte()
b := r.buf[0:4]
_, err = io.ReadFull(r.reader, b)
r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b)) - 4
return t, err
}
func (r *MsgReader) ReadByte() byte {
if r.err != nil {
return 0
}
r.msgBytesRemaining -= 1
if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.ReadByte()
if err != nil {
r.Fatal(err)
return 0
}
return b
}
func (r *MsgReader) ReadInt16() int16 {
if r.err != nil {
return 0
}
r.msgBytesRemaining -= 2
if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message"))
return 0
}
b := r.buf[0:2]
_, err := io.ReadFull(r.reader, b)
if err != nil {
r.Fatal(err)
return 0
}
return int16(binary.BigEndian.Uint16(b))
}
func (r *MsgReader) ReadInt32() int32 {
if r.err != nil {
return 0
}
r.msgBytesRemaining -= 4
if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message"))
return 0
}
b := r.buf[0:4]
_, err := io.ReadFull(r.reader, b)
if err != nil {
r.Fatal(err)
return 0
}
return int32(binary.BigEndian.Uint32(b))
}
func (r *MsgReader) ReadInt64() int64 {
if r.err != nil {
return 0
}
r.msgBytesRemaining -= 8
if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message"))
return 0
}
b := r.buf[0:8]
_, err := io.ReadFull(r.reader, b)
if err != nil {
r.Fatal(err)
return 0
}
return int64(binary.BigEndian.Uint64(b))
}
func (r *MsgReader) ReadOid() Oid {
return Oid(r.ReadInt32())
}
// ReadCString reads a null terminated string
func (r *MsgReader) ReadCString() string {
if r.err != nil {
return ""
}
b, err := r.reader.ReadBytes(0)
if err != nil {
r.Fatal(err)
return ""
}
r.msgBytesRemaining -= int32(len(b))
if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message"))
return ""
}
return string(b[0 : len(b)-1])
}
// ReadString reads count bytes and returns as string
func (r *MsgReader) ReadString(count int32) string {
if r.err != nil {
return ""
}
r.msgBytesRemaining -= count
if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message"))
return ""
}
var b []byte
if count <= int32(len(r.buf)) {
b = r.buf[0:int(count)]
} else {
b = make([]byte, int(count))
}
_, err := io.ReadFull(r.reader, b)
if err != nil {
r.Fatal(err)
return ""
}
return string(b)
}
func (r *MsgReader) CopyN(w io.Writer, count int32) {
if r.err != nil {
return
}
r.msgBytesRemaining -= count
if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message"))
return
}
_, err := io.CopyN(w, r.reader, int64(count))
if err != nil {
r.Fatal(err)
}
}