MsgReader no longer uses double buffering

scan-io
Jack Christensen 2014-07-04 13:08:37 -05:00
parent 78b8e0b6f2
commit b25aea5c52
4 changed files with 243 additions and 170 deletions

167
conn.go
View File

@ -14,7 +14,6 @@ import (
"fmt" "fmt"
log "gopkg.in/inconshreveable/log15.v2" log "gopkg.in/inconshreveable/log15.v2"
"io" "io"
"io/ioutil"
"net" "net"
"net/url" "net/url"
"os" "os"
@ -52,8 +51,8 @@ type Conn struct {
conn net.Conn // the underlying TCP or unix domain socket connection conn net.Conn // the underlying TCP or unix domain socket connection
reader *bufio.Reader // buffered reader to improve read performance reader *bufio.Reader // buffered reader to improve read performance
wbuf [1024]byte wbuf [1024]byte
buf *bytes.Buffer // work buffer to avoid constant alloc and dealloc buf *bytes.Buffer
bufSize int // desired size of buf bufSize int
Pid int32 // backend pid Pid int32 // backend pid
SecretKey int32 // key to use to send a cancel query message to the server SecretKey int32 // key to use to send a cancel query message to the server
RuntimeParams map[string]string // parameters that have been reported by the server RuntimeParams map[string]string // parameters that have been reported by the server
@ -65,6 +64,7 @@ type Conn struct {
causeOfDeath error causeOfDeath error
logger log.Logger logger log.Logger
qr QueryResult qr QueryResult
mr MsgReader
} }
type PreparedStatement struct { type PreparedStatement struct {
@ -197,6 +197,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
} }
c.reader = bufio.NewReader(c.conn) c.reader = bufio.NewReader(c.conn)
c.mr.reader = c.reader
msg := newStartupMessage() msg := newStartupMessage()
msg.options["user"] = c.config.User msg.options["user"] = c.config.User
@ -209,7 +210,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
for { for {
var t byte var t byte
var r *MessageReader var r *MsgReader
t, r, err = c.rxMsg() t, r, err = c.rxMsg()
if err != nil { if err != nil {
return nil, err return nil, err
@ -348,9 +349,9 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{})
for { for {
var t byte var t byte
var bodySize int32 var r *MsgReader
t, bodySize, err = c.rxMsgHeader() t, r, err = c.rxMsg()
if err != nil { if err != nil {
return err return err
} }
@ -363,19 +364,12 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{})
} }
if softErr != nil { if softErr != nil {
c.rxMsgBody(bodySize) // Read and discard rest of message // Read and discard rest of message
continue continue
} }
softErr = c.rxDataRowValueTo(w, bodySize) softErr = c.rxDataRowValueTo(w, r)
} else { } else {
var body *bytes.Buffer
body, err = c.rxMsgBody(bodySize)
if err != nil {
return err
}
r := (*MessageReader)(body)
switch t { switch t {
case readyForQuery: case readyForQuery:
c.rxReadyForQuery(r) c.rxReadyForQuery(r)
@ -392,44 +386,20 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{})
} }
} }
func (c *Conn) rxDataRowValueTo(w io.Writer, bodySize int32) (err error) { func (c *Conn) rxDataRowValueTo(w io.Writer, r *MsgReader) error {
b := make([]byte, 2) columnCount := r.ReadInt16()
_, err = io.ReadFull(c.reader, b)
if err != nil {
c.die(err)
return
}
columnCount := int16(binary.BigEndian.Uint16(b))
if columnCount != 1 { if columnCount != 1 {
// Read the rest of the data row so it can be discarded return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: columnCount}
if _, err = io.CopyN(ioutil.Discard, c.reader, int64(bodySize-2)); err != nil {
c.die(err)
return
}
err = UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: columnCount}
return
} }
b = make([]byte, 4) valueSize := r.ReadInt32()
_, err = io.ReadFull(c.reader, b)
if err != nil {
c.die(err)
return
}
valueSize := int32(binary.BigEndian.Uint32(b))
if valueSize == -1 { if valueSize == -1 {
err = errors.New("SelectValueTo cannot handle null") return errors.New("SelectValueTo cannot handle null")
return
} }
_, err = io.CopyN(w, c.reader, int64(valueSize)) r.CopyN(w, valueSize)
if err != nil {
c.die(err)
}
return return r.Err()
} }
// Prepare creates a prepared statement with name and sql. sql can contain placeholders // Prepare creates a prepared statement with name and sql. sql can contain placeholders
@ -475,7 +445,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
for { for {
var t byte var t byte
var r *MessageReader var r *MsgReader
t, r, err := c.rxMsg() t, r, err := c.rxMsg()
if err != nil { if err != nil {
return nil, err return nil, err
@ -563,7 +533,7 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error)
} }
var t byte var t byte
var r *MessageReader var r *MsgReader
if t, r, err = c.rxMsg(); err == nil { if t, r, err = c.rxMsg(); err == nil {
if err = c.processContextFreeMsg(t, r); err != nil { if err = c.processContextFreeMsg(t, r); err != nil {
return nil, err return nil, err
@ -653,12 +623,12 @@ func (rr *RowReader) ReadDate(qr *QueryResult) time.Time {
} }
func (rr *RowReader) ReadString(qr *QueryResult) string { func (rr *RowReader) ReadString(qr *QueryResult) string {
_, size, ok := qr.NextColumn() fd, size, ok := qr.NextColumn()
if !ok { if !ok {
return "" return ""
} }
return qr.mr.ReadString(size) return decodeText(qr, fd, size)
} }
func (rr *RowReader) ReadValue(qr *QueryResult) interface{} { func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
@ -671,7 +641,7 @@ func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
if vt, present := ValueTranscoders[fd.DataType]; present && vt.Decode != nil { if vt, present := ValueTranscoders[fd.DataType]; present && vt.Decode != nil {
return vt.Decode(qr, fd, size) return vt.Decode(qr, fd, size)
} else { } else {
return qr.mr.ReadString(size) return decodeText(qr, fd, size)
} }
} else { } else {
return nil return nil
@ -681,7 +651,7 @@ func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
type QueryResult struct { type QueryResult struct {
pool *ConnPool pool *ConnPool
conn *Conn conn *Conn
mr *MessageReader mr *MsgReader
fields []FieldDescription fields []FieldDescription
rowCount int rowCount int
columnIdx int columnIdx int
@ -693,7 +663,7 @@ func (qr *QueryResult) FieldDescriptions() []FieldDescription {
return qr.fields return qr.fields
} }
func (qr *QueryResult) MessageReader() *MessageReader { func (qr *QueryResult) MsgReader() *MsgReader {
return qr.mr return qr.mr
} }
@ -771,8 +741,8 @@ func (qr *QueryResult) NextRow() bool {
qr.close() qr.close()
return false return false
case dataRow: case dataRow:
fieldCount := int(r.ReadInt16()) fieldCount := r.ReadInt16()
if fieldCount != len(qr.fields) { if int(fieldCount) != len(qr.fields) {
qr.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(qr.fields), fieldCount))) qr.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(qr.fields), fieldCount)))
return false return false
} }
@ -803,7 +773,6 @@ func (qr *QueryResult) NextColumn() (*FieldDescription, int32, bool) {
fd := &qr.fields[qr.columnIdx] fd := &qr.fields[qr.columnIdx]
qr.columnIdx++ qr.columnIdx++
size := qr.mr.ReadInt32() size := qr.mr.ReadInt32()
return fd, size, true return fd, size, true
} }
@ -965,7 +934,7 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag
for { for {
var t byte var t byte
var r *MessageReader var r *MsgReader
t, r, err = c.rxMsg() t, r, err = c.rxMsg()
if err != nil { if err != nil {
return commandTag, err return commandTag, err
@ -1040,7 +1009,7 @@ func (c *Conn) transaction(isoLevel string, f func() bool) (committed bool, err
// Processes messages that are not exclusive to one context such as // Processes messages that are not exclusive to one context such as
// authentication or query response. The response to these messages // authentication or query response. The response to these messages
// is the same regardless of when they occur. // is the same regardless of when they occur.
func (c *Conn) processContextFreeMsg(t byte, r *MessageReader) (err error) { func (c *Conn) processContextFreeMsg(t byte, r *MsgReader) (err error) {
switch t { switch t {
case 'S': case 'S':
c.rxParameterStatus(r) c.rxParameterStatus(r)
@ -1050,69 +1019,28 @@ func (c *Conn) processContextFreeMsg(t byte, r *MessageReader) (err error) {
case noticeResponse: case noticeResponse:
return nil return nil
case notificationResponse: case notificationResponse:
return c.rxNotificationResponse(r) c.rxNotificationResponse(r)
return nil
default: default:
return fmt.Errorf("Received unknown message type: %c", t) return fmt.Errorf("Received unknown message type: %c", t)
} }
} }
func (c *Conn) rxMsg() (t byte, r *MessageReader, err error) { func (c *Conn) rxMsg() (t byte, r *MsgReader, err error) {
var bodySize int32
t, bodySize, err = c.rxMsgHeader()
if err != nil {
return
}
var body *bytes.Buffer
if body, err = c.rxMsgBody(bodySize); err != nil {
return
}
r = (*MessageReader)(body)
return
}
func (c *Conn) rxMsgHeader() (t byte, bodySize int32, err error) {
if !c.alive { if !c.alive {
return 0, 0, DeadConnError return 0, nil, DeadConnError
} }
t, err = c.reader.ReadByte() t, err = c.mr.rxMsg()
if err != nil { if err != nil {
c.die(err) c.die(err)
return 0, 0, err
} }
b := make([]byte, 4) return t, &c.mr, err
_, err = io.ReadFull(c.reader, b)
if err != nil {
c.die(err)
return 0, 0, err
} }
bodySize = int32(binary.BigEndian.Uint32(b)) func (c *Conn) rxAuthenticationX(r *MsgReader) (err error) {
switch r.ReadInt32() {
bodySize -= 4 // remove self from size
return t, bodySize, err
}
func (c *Conn) rxMsgBody(bodySize int32) (*bytes.Buffer, error) {
if !c.alive {
return nil, DeadConnError
}
buf := c.getBuf()
_, err := io.CopyN(buf, c.reader, int64(bodySize))
if err != nil {
c.die(err)
return nil, err
}
return buf, nil
}
func (c *Conn) rxAuthenticationX(r *MessageReader) (err error) {
code := r.ReadInt32()
switch code {
case 0: // AuthenticationOk case 0: // AuthenticationOk
case 3: // AuthenticationCleartextPassword case 3: // AuthenticationCleartextPassword
err = c.txPasswordMessage(c.config.Password) err = c.txPasswordMessage(c.config.Password)
@ -1133,13 +1061,13 @@ func hexMD5(s string) string {
return hex.EncodeToString(hash.Sum(nil)) return hex.EncodeToString(hash.Sum(nil))
} }
func (c *Conn) rxParameterStatus(r *MessageReader) { func (c *Conn) rxParameterStatus(r *MsgReader) {
key := r.ReadCString() key := r.ReadCString()
value := r.ReadCString() value := r.ReadCString()
c.RuntimeParams[key] = value c.RuntimeParams[key] = value
} }
func (c *Conn) rxErrorResponse(r *MessageReader) (err PgError) { func (c *Conn) rxErrorResponse(r *MsgReader) (err PgError) {
for { for {
switch r.ReadByte() { switch r.ReadByte() {
case 'S': case 'S':
@ -1159,16 +1087,16 @@ func (c *Conn) rxErrorResponse(r *MessageReader) (err PgError) {
} }
} }
func (c *Conn) rxBackendKeyData(r *MessageReader) { func (c *Conn) rxBackendKeyData(r *MsgReader) {
c.Pid = r.ReadInt32() c.Pid = r.ReadInt32()
c.SecretKey = r.ReadInt32() c.SecretKey = r.ReadInt32()
} }
func (c *Conn) rxReadyForQuery(r *MessageReader) { func (c *Conn) rxReadyForQuery(r *MsgReader) {
c.TxStatus = r.ReadByte() c.TxStatus = r.ReadByte()
} }
func (c *Conn) rxRowDescription(r *MessageReader) (fields []FieldDescription) { func (c *Conn) rxRowDescription(r *MsgReader) (fields []FieldDescription) {
fieldCount := r.ReadInt16() fieldCount := r.ReadInt16()
fields = make([]FieldDescription, fieldCount) fields = make([]FieldDescription, fieldCount)
for i := int16(0); i < fieldCount; i++ { for i := int16(0); i < fieldCount; i++ {
@ -1184,26 +1112,26 @@ func (c *Conn) rxRowDescription(r *MessageReader) (fields []FieldDescription) {
return return
} }
func (c *Conn) rxParameterDescription(r *MessageReader) (parameters []Oid) { func (c *Conn) rxParameterDescription(r *MsgReader) (parameters []Oid) {
parameterCount := r.ReadInt16() parameterCount := r.ReadInt16()
parameters = make([]Oid, 0, parameterCount) parameters = make([]Oid, 0, parameterCount)
for i := int16(0); i < parameterCount; i++ { for i := int16(0); i < parameterCount; i++ {
parameters = append(parameters, r.ReadOid()) parameters = append(parameters, r.ReadOid())
} }
return return
} }
func (c *Conn) rxCommandComplete(r *MessageReader) string { func (c *Conn) rxCommandComplete(r *MsgReader) string {
return r.ReadCString() return r.ReadCString()
} }
func (c *Conn) rxNotificationResponse(r *MessageReader) (err error) { func (c *Conn) rxNotificationResponse(r *MsgReader) {
n := new(Notification) n := new(Notification)
n.Pid = r.ReadInt32() n.Pid = r.ReadInt32()
n.Channel = r.ReadCString() n.Channel = r.ReadCString()
n.Payload = r.ReadCString() n.Payload = r.ReadCString()
c.notifications = append(c.notifications, n) c.notifications = append(c.notifications, n)
return
} }
func (c *Conn) startTLS() (err error) { func (c *Conn) startTLS() (err error) {
@ -1280,12 +1208,7 @@ func (c *Conn) txPasswordMessage(password string) (err error) {
// its internal byte array, check on the size and create a new bytes.Buffer so the // its internal byte array, check on the size and create a new bytes.Buffer so the
// old one can get GC'ed // old one can get GC'ed
func (c *Conn) getBuf() *bytes.Buffer { func (c *Conn) getBuf() *bytes.Buffer {
c.buf.Reset() return &bytes.Buffer{}
if cap(c.buf.Bytes()) > c.bufSize {
c.logger.Debug(fmt.Sprintf("c.buf (%d) is larger than c.bufSize (%d) -- resetting", cap(c.buf.Bytes()), c.bufSize))
c.buf = bytes.NewBuffer(make([]byte, 0, c.bufSize))
}
return c.buf
} }
func (c *Conn) die(err error) { func (c *Conn) die(err error) {

View File

@ -51,7 +51,7 @@ func decodePoint(qr *pgx.QueryResult, fd *pgx.FieldDescription, size int32) Poin
switch fd.FormatCode { switch fd.FormatCode {
case pgx.TextFormatCode: case pgx.TextFormatCode:
s := qr.MessageReader().ReadString(size) s := qr.MsgReader().ReadString(size)
match := pointRegexp.FindStringSubmatch(s) match := pointRegexp.FindStringSubmatch(s)
if match == nil { if match == nil {
qr.Fatal(pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s))) qr.Fatal(pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s)))

View File

@ -1,47 +0,0 @@
package pgx
import (
"bytes"
"encoding/binary"
)
// MessageReader is a helper that reads values from a PostgreSQL message.
type MessageReader bytes.Buffer
func (r *MessageReader) ReadByte() (b byte) {
b, _ = (*bytes.Buffer)(r).ReadByte()
return
}
func (r *MessageReader) ReadInt16() (n int16) {
b := (*bytes.Buffer)(r).Next(2)
return int16(binary.BigEndian.Uint16(b))
}
func (r *MessageReader) ReadInt32() (n int32) {
b := (*bytes.Buffer)(r).Next(4)
return int32(binary.BigEndian.Uint32(b))
}
func (r *MessageReader) ReadInt64() (n int64) {
b := (*bytes.Buffer)(r).Next(8)
return int64(binary.BigEndian.Uint64(b))
}
func (r *MessageReader) ReadOid() (oid Oid) {
b := (*bytes.Buffer)(r).Next(4)
return Oid(binary.BigEndian.Uint32(b))
}
// ReadString reads a null terminated string
func (r *MessageReader) ReadCString() (s string) {
b, _ := (*bytes.Buffer)(r).ReadBytes(0)
return string(b[:len(b)-1])
}
// ReadString reads count bytes and returns as string
func (r *MessageReader) ReadString(count int32) (s string) {
size := int(count)
b := (*bytes.Buffer)(r).Next(size)
return string(b)
}

197
msg_reader.go Normal file
View File

@ -0,0 +1,197 @@
package pgx
import (
"bufio"
"encoding/binary"
"errors"
"io"
"io/ioutil"
)
// MsgReader is a helper that reads values from a PostgreSQL message.
type MsgReader struct {
reader *bufio.Reader
buf [128]byte
msgBytesRemaining int32
err error
}
// Err returns any error that the MsgReader has experienced
func (r *MsgReader) Err() error {
return r.err
}
// Fatal tells r that a Fatal error has occurred
func (r *MsgReader) Fatal(err error) {
r.err = err
}
// rxMsg reads the type and size of the next message.
func (r *MsgReader) rxMsg() (t byte, err error) {
if r.err != nil {
return 0, err
}
if r.msgBytesRemaining > 0 {
io.CopyN(ioutil.Discard, r.reader, int64(r.msgBytesRemaining))
}
t, err = r.reader.ReadByte()
b := r.buf[0:4]
_, err = io.ReadFull(r.reader, b)
r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b)) - 4
return t, err
}
func (r *MsgReader) ReadByte() byte {
if r.err != nil {
return 0
}
r.msgBytesRemaining -= 1
if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.ReadByte()
if err != nil {
r.Fatal(err)
return 0
}
return b
}
func (r *MsgReader) ReadInt16() int16 {
if r.err != nil {
return 0
}
r.msgBytesRemaining -= 2
if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message"))
return 0
}
b := r.buf[0:2]
_, err := io.ReadFull(r.reader, b)
if err != nil {
r.Fatal(err)
return 0
}
return int16(binary.BigEndian.Uint16(b))
}
func (r *MsgReader) ReadInt32() int32 {
if r.err != nil {
return 0
}
r.msgBytesRemaining -= 4
if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message"))
return 0
}
b := r.buf[0:4]
_, err := io.ReadFull(r.reader, b)
if err != nil {
r.Fatal(err)
return 0
}
return int32(binary.BigEndian.Uint32(b))
}
func (r *MsgReader) ReadInt64() int64 {
if r.err != nil {
return 0
}
r.msgBytesRemaining -= 8
if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message"))
return 0
}
b := r.buf[0:8]
_, err := io.ReadFull(r.reader, b)
if err != nil {
r.Fatal(err)
return 0
}
return int64(binary.BigEndian.Uint64(b))
}
func (r *MsgReader) ReadOid() Oid {
return Oid(r.ReadInt32())
}
// ReadCString reads a null terminated string
func (r *MsgReader) ReadCString() string {
if r.err != nil {
return ""
}
b, err := r.reader.ReadBytes(0)
if err != nil {
r.Fatal(err)
return ""
}
r.msgBytesRemaining -= int32(len(b))
if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message"))
return ""
}
return string(b[0 : len(b)-1])
}
// ReadString reads count bytes and returns as string
func (r *MsgReader) ReadString(count int32) string {
if r.err != nil {
return ""
}
r.msgBytesRemaining -= count
if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message"))
return ""
}
var b []byte
if count <= int32(len(r.buf)) {
b = r.buf[0:int(count)]
} else {
b = make([]byte, int(count))
}
_, err := io.ReadFull(r.reader, b)
if err != nil {
r.Fatal(err)
return ""
}
return string(b)
}
func (r *MsgReader) CopyN(w io.Writer, count int32) {
if r.err != nil {
return
}
r.msgBytesRemaining -= count
if r.msgBytesRemaining < 0 {
r.Fatal(errors.New("read past end of message"))
return
}
_, err := io.CopyN(w, r.reader, int64(count))
if err != nil {
r.Fatal(err)
}
}