diff --git a/conn.go b/conn.go index 87bdff5d..1c11c5d8 100644 --- a/conn.go +++ b/conn.go @@ -14,7 +14,6 @@ import ( "fmt" log "gopkg.in/inconshreveable/log15.v2" "io" - "io/ioutil" "net" "net/url" "os" @@ -52,8 +51,8 @@ type Conn struct { conn net.Conn // the underlying TCP or unix domain socket connection reader *bufio.Reader // buffered reader to improve read performance wbuf [1024]byte - buf *bytes.Buffer // work buffer to avoid constant alloc and dealloc - bufSize int // desired size of buf + buf *bytes.Buffer + bufSize int Pid int32 // backend pid SecretKey int32 // key to use to send a cancel query message to the server RuntimeParams map[string]string // parameters that have been reported by the server @@ -65,6 +64,7 @@ type Conn struct { causeOfDeath error logger log.Logger qr QueryResult + mr MsgReader } type PreparedStatement struct { @@ -197,6 +197,7 @@ func Connect(config ConnConfig) (c *Conn, err error) { } c.reader = bufio.NewReader(c.conn) + c.mr.reader = c.reader msg := newStartupMessage() msg.options["user"] = c.config.User @@ -209,7 +210,7 @@ func Connect(config ConnConfig) (c *Conn, err error) { for { var t byte - var r *MessageReader + var r *MsgReader t, r, err = c.rxMsg() if err != nil { return nil, err @@ -348,9 +349,9 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{}) for { var t byte - var bodySize int32 + var r *MsgReader - t, bodySize, err = c.rxMsgHeader() + t, r, err = c.rxMsg() if err != nil { return err } @@ -363,19 +364,12 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{}) } if softErr != nil { - c.rxMsgBody(bodySize) // Read and discard rest of message + // Read and discard rest of message continue } - softErr = c.rxDataRowValueTo(w, bodySize) + softErr = c.rxDataRowValueTo(w, r) } else { - var body *bytes.Buffer - body, err = c.rxMsgBody(bodySize) - if err != nil { - return err - } - - r := (*MessageReader)(body) switch t { case readyForQuery: c.rxReadyForQuery(r) @@ -392,44 +386,20 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{}) } } -func (c *Conn) rxDataRowValueTo(w io.Writer, bodySize int32) (err error) { - b := make([]byte, 2) - _, err = io.ReadFull(c.reader, b) - if err != nil { - c.die(err) - return - } - columnCount := int16(binary.BigEndian.Uint16(b)) - +func (c *Conn) rxDataRowValueTo(w io.Writer, r *MsgReader) error { + columnCount := r.ReadInt16() if columnCount != 1 { - // Read the rest of the data row so it can be discarded - if _, err = io.CopyN(ioutil.Discard, c.reader, int64(bodySize-2)); err != nil { - c.die(err) - return - } - err = UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: columnCount} - return + return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: columnCount} } - b = make([]byte, 4) - _, err = io.ReadFull(c.reader, b) - if err != nil { - c.die(err) - return - } - valueSize := int32(binary.BigEndian.Uint32(b)) - + valueSize := r.ReadInt32() if valueSize == -1 { - err = errors.New("SelectValueTo cannot handle null") - return + return errors.New("SelectValueTo cannot handle null") } - _, err = io.CopyN(w, c.reader, int64(valueSize)) - if err != nil { - c.die(err) - } + r.CopyN(w, valueSize) - return + return r.Err() } // Prepare creates a prepared statement with name and sql. sql can contain placeholders @@ -475,7 +445,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { for { var t byte - var r *MessageReader + var r *MsgReader t, r, err := c.rxMsg() if err != nil { return nil, err @@ -563,7 +533,7 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) } var t byte - var r *MessageReader + var r *MsgReader if t, r, err = c.rxMsg(); err == nil { if err = c.processContextFreeMsg(t, r); err != nil { return nil, err @@ -653,12 +623,12 @@ func (rr *RowReader) ReadDate(qr *QueryResult) time.Time { } func (rr *RowReader) ReadString(qr *QueryResult) string { - _, size, ok := qr.NextColumn() + fd, size, ok := qr.NextColumn() if !ok { return "" } - return qr.mr.ReadString(size) + return decodeText(qr, fd, size) } func (rr *RowReader) ReadValue(qr *QueryResult) interface{} { @@ -671,7 +641,7 @@ func (rr *RowReader) ReadValue(qr *QueryResult) interface{} { if vt, present := ValueTranscoders[fd.DataType]; present && vt.Decode != nil { return vt.Decode(qr, fd, size) } else { - return qr.mr.ReadString(size) + return decodeText(qr, fd, size) } } else { return nil @@ -681,7 +651,7 @@ func (rr *RowReader) ReadValue(qr *QueryResult) interface{} { type QueryResult struct { pool *ConnPool conn *Conn - mr *MessageReader + mr *MsgReader fields []FieldDescription rowCount int columnIdx int @@ -693,7 +663,7 @@ func (qr *QueryResult) FieldDescriptions() []FieldDescription { return qr.fields } -func (qr *QueryResult) MessageReader() *MessageReader { +func (qr *QueryResult) MsgReader() *MsgReader { return qr.mr } @@ -771,8 +741,8 @@ func (qr *QueryResult) NextRow() bool { qr.close() return false case dataRow: - fieldCount := int(r.ReadInt16()) - if fieldCount != len(qr.fields) { + fieldCount := r.ReadInt16() + if int(fieldCount) != len(qr.fields) { qr.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(qr.fields), fieldCount))) return false } @@ -803,7 +773,6 @@ func (qr *QueryResult) NextColumn() (*FieldDescription, int32, bool) { fd := &qr.fields[qr.columnIdx] qr.columnIdx++ size := qr.mr.ReadInt32() - return fd, size, true } @@ -965,7 +934,7 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag for { var t byte - var r *MessageReader + var r *MsgReader t, r, err = c.rxMsg() if err != nil { return commandTag, err @@ -1040,7 +1009,7 @@ func (c *Conn) transaction(isoLevel string, f func() bool) (committed bool, err // Processes messages that are not exclusive to one context such as // authentication or query response. The response to these messages // is the same regardless of when they occur. -func (c *Conn) processContextFreeMsg(t byte, r *MessageReader) (err error) { +func (c *Conn) processContextFreeMsg(t byte, r *MsgReader) (err error) { switch t { case 'S': c.rxParameterStatus(r) @@ -1050,69 +1019,28 @@ func (c *Conn) processContextFreeMsg(t byte, r *MessageReader) (err error) { case noticeResponse: return nil case notificationResponse: - return c.rxNotificationResponse(r) + c.rxNotificationResponse(r) + return nil default: return fmt.Errorf("Received unknown message type: %c", t) } } -func (c *Conn) rxMsg() (t byte, r *MessageReader, err error) { - var bodySize int32 - t, bodySize, err = c.rxMsgHeader() - if err != nil { - return - } - - var body *bytes.Buffer - if body, err = c.rxMsgBody(bodySize); err != nil { - return - } - r = (*MessageReader)(body) - return -} - -func (c *Conn) rxMsgHeader() (t byte, bodySize int32, err error) { +func (c *Conn) rxMsg() (t byte, r *MsgReader, err error) { if !c.alive { - return 0, 0, DeadConnError + return 0, nil, DeadConnError } - t, err = c.reader.ReadByte() + t, err = c.mr.rxMsg() if err != nil { c.die(err) - return 0, 0, err } - b := make([]byte, 4) - _, err = io.ReadFull(c.reader, b) - if err != nil { - c.die(err) - return 0, 0, err - } - - bodySize = int32(binary.BigEndian.Uint32(b)) - - bodySize -= 4 // remove self from size - return t, bodySize, err + return t, &c.mr, err } -func (c *Conn) rxMsgBody(bodySize int32) (*bytes.Buffer, error) { - if !c.alive { - return nil, DeadConnError - } - - buf := c.getBuf() - _, err := io.CopyN(buf, c.reader, int64(bodySize)) - if err != nil { - c.die(err) - return nil, err - } - - return buf, nil -} - -func (c *Conn) rxAuthenticationX(r *MessageReader) (err error) { - code := r.ReadInt32() - switch code { +func (c *Conn) rxAuthenticationX(r *MsgReader) (err error) { + switch r.ReadInt32() { case 0: // AuthenticationOk case 3: // AuthenticationCleartextPassword err = c.txPasswordMessage(c.config.Password) @@ -1133,13 +1061,13 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } -func (c *Conn) rxParameterStatus(r *MessageReader) { +func (c *Conn) rxParameterStatus(r *MsgReader) { key := r.ReadCString() value := r.ReadCString() c.RuntimeParams[key] = value } -func (c *Conn) rxErrorResponse(r *MessageReader) (err PgError) { +func (c *Conn) rxErrorResponse(r *MsgReader) (err PgError) { for { switch r.ReadByte() { case 'S': @@ -1159,16 +1087,16 @@ func (c *Conn) rxErrorResponse(r *MessageReader) (err PgError) { } } -func (c *Conn) rxBackendKeyData(r *MessageReader) { +func (c *Conn) rxBackendKeyData(r *MsgReader) { c.Pid = r.ReadInt32() c.SecretKey = r.ReadInt32() } -func (c *Conn) rxReadyForQuery(r *MessageReader) { +func (c *Conn) rxReadyForQuery(r *MsgReader) { c.TxStatus = r.ReadByte() } -func (c *Conn) rxRowDescription(r *MessageReader) (fields []FieldDescription) { +func (c *Conn) rxRowDescription(r *MsgReader) (fields []FieldDescription) { fieldCount := r.ReadInt16() fields = make([]FieldDescription, fieldCount) for i := int16(0); i < fieldCount; i++ { @@ -1184,26 +1112,26 @@ func (c *Conn) rxRowDescription(r *MessageReader) (fields []FieldDescription) { return } -func (c *Conn) rxParameterDescription(r *MessageReader) (parameters []Oid) { +func (c *Conn) rxParameterDescription(r *MsgReader) (parameters []Oid) { parameterCount := r.ReadInt16() parameters = make([]Oid, 0, parameterCount) + for i := int16(0); i < parameterCount; i++ { parameters = append(parameters, r.ReadOid()) } return } -func (c *Conn) rxCommandComplete(r *MessageReader) string { +func (c *Conn) rxCommandComplete(r *MsgReader) string { return r.ReadCString() } -func (c *Conn) rxNotificationResponse(r *MessageReader) (err error) { +func (c *Conn) rxNotificationResponse(r *MsgReader) { n := new(Notification) n.Pid = r.ReadInt32() n.Channel = r.ReadCString() n.Payload = r.ReadCString() c.notifications = append(c.notifications, n) - return } func (c *Conn) startTLS() (err error) { @@ -1280,12 +1208,7 @@ func (c *Conn) txPasswordMessage(password string) (err error) { // its internal byte array, check on the size and create a new bytes.Buffer so the // old one can get GC'ed func (c *Conn) getBuf() *bytes.Buffer { - c.buf.Reset() - if cap(c.buf.Bytes()) > c.bufSize { - c.logger.Debug(fmt.Sprintf("c.buf (%d) is larger than c.bufSize (%d) -- resetting", cap(c.buf.Bytes()), c.bufSize)) - c.buf = bytes.NewBuffer(make([]byte, 0, c.bufSize)) - } - return c.buf + return &bytes.Buffer{} } func (c *Conn) die(err error) { diff --git a/example_value_transcoder_test.go b/example_value_transcoder_test.go index 554c9343..fe5cb632 100644 --- a/example_value_transcoder_test.go +++ b/example_value_transcoder_test.go @@ -51,7 +51,7 @@ func decodePoint(qr *pgx.QueryResult, fd *pgx.FieldDescription, size int32) Poin switch fd.FormatCode { case pgx.TextFormatCode: - s := qr.MessageReader().ReadString(size) + s := qr.MsgReader().ReadString(size) match := pointRegexp.FindStringSubmatch(s) if match == nil { qr.Fatal(pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s))) diff --git a/message_reader.go b/message_reader.go deleted file mode 100644 index 938dbd17..00000000 --- a/message_reader.go +++ /dev/null @@ -1,47 +0,0 @@ -package pgx - -import ( - "bytes" - "encoding/binary" -) - -// MessageReader is a helper that reads values from a PostgreSQL message. -type MessageReader bytes.Buffer - -func (r *MessageReader) ReadByte() (b byte) { - b, _ = (*bytes.Buffer)(r).ReadByte() - return -} - -func (r *MessageReader) ReadInt16() (n int16) { - b := (*bytes.Buffer)(r).Next(2) - return int16(binary.BigEndian.Uint16(b)) -} - -func (r *MessageReader) ReadInt32() (n int32) { - b := (*bytes.Buffer)(r).Next(4) - return int32(binary.BigEndian.Uint32(b)) -} - -func (r *MessageReader) ReadInt64() (n int64) { - b := (*bytes.Buffer)(r).Next(8) - return int64(binary.BigEndian.Uint64(b)) -} - -func (r *MessageReader) ReadOid() (oid Oid) { - b := (*bytes.Buffer)(r).Next(4) - return Oid(binary.BigEndian.Uint32(b)) -} - -// ReadString reads a null terminated string -func (r *MessageReader) ReadCString() (s string) { - b, _ := (*bytes.Buffer)(r).ReadBytes(0) - return string(b[:len(b)-1]) -} - -// ReadString reads count bytes and returns as string -func (r *MessageReader) ReadString(count int32) (s string) { - size := int(count) - b := (*bytes.Buffer)(r).Next(size) - return string(b) -} diff --git a/msg_reader.go b/msg_reader.go new file mode 100644 index 00000000..55bd6a11 --- /dev/null +++ b/msg_reader.go @@ -0,0 +1,197 @@ +package pgx + +import ( + "bufio" + "encoding/binary" + "errors" + "io" + "io/ioutil" +) + +// MsgReader is a helper that reads values from a PostgreSQL message. +type MsgReader struct { + reader *bufio.Reader + buf [128]byte + msgBytesRemaining int32 + err error +} + +// Err returns any error that the MsgReader has experienced +func (r *MsgReader) Err() error { + return r.err +} + +// Fatal tells r that a Fatal error has occurred +func (r *MsgReader) Fatal(err error) { + r.err = err +} + +// rxMsg reads the type and size of the next message. +func (r *MsgReader) rxMsg() (t byte, err error) { + if r.err != nil { + return 0, err + } + + if r.msgBytesRemaining > 0 { + io.CopyN(ioutil.Discard, r.reader, int64(r.msgBytesRemaining)) + } + + t, err = r.reader.ReadByte() + b := r.buf[0:4] + _, err = io.ReadFull(r.reader, b) + r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b)) - 4 + return t, err +} + +func (r *MsgReader) ReadByte() byte { + if r.err != nil { + return 0 + } + + r.msgBytesRemaining -= 1 + if r.msgBytesRemaining < 0 { + r.Fatal(errors.New("read past end of message")) + return 0 + } + + b, err := r.reader.ReadByte() + if err != nil { + r.Fatal(err) + return 0 + } + + return b +} + +func (r *MsgReader) ReadInt16() int16 { + if r.err != nil { + return 0 + } + + r.msgBytesRemaining -= 2 + if r.msgBytesRemaining < 0 { + r.Fatal(errors.New("read past end of message")) + return 0 + } + + b := r.buf[0:2] + _, err := io.ReadFull(r.reader, b) + if err != nil { + r.Fatal(err) + return 0 + } + + return int16(binary.BigEndian.Uint16(b)) +} + +func (r *MsgReader) ReadInt32() int32 { + if r.err != nil { + return 0 + } + + r.msgBytesRemaining -= 4 + if r.msgBytesRemaining < 0 { + r.Fatal(errors.New("read past end of message")) + return 0 + } + + b := r.buf[0:4] + _, err := io.ReadFull(r.reader, b) + if err != nil { + r.Fatal(err) + return 0 + } + + return int32(binary.BigEndian.Uint32(b)) +} + +func (r *MsgReader) ReadInt64() int64 { + if r.err != nil { + return 0 + } + + r.msgBytesRemaining -= 8 + if r.msgBytesRemaining < 0 { + r.Fatal(errors.New("read past end of message")) + return 0 + } + + b := r.buf[0:8] + _, err := io.ReadFull(r.reader, b) + if err != nil { + r.Fatal(err) + return 0 + } + + return int64(binary.BigEndian.Uint64(b)) +} + +func (r *MsgReader) ReadOid() Oid { + return Oid(r.ReadInt32()) +} + +// ReadCString reads a null terminated string +func (r *MsgReader) ReadCString() string { + if r.err != nil { + return "" + } + + b, err := r.reader.ReadBytes(0) + if err != nil { + r.Fatal(err) + return "" + } + + r.msgBytesRemaining -= int32(len(b)) + if r.msgBytesRemaining < 0 { + r.Fatal(errors.New("read past end of message")) + return "" + } + + return string(b[0 : len(b)-1]) +} + +// ReadString reads count bytes and returns as string +func (r *MsgReader) ReadString(count int32) string { + if r.err != nil { + return "" + } + + r.msgBytesRemaining -= count + if r.msgBytesRemaining < 0 { + r.Fatal(errors.New("read past end of message")) + return "" + } + + var b []byte + if count <= int32(len(r.buf)) { + b = r.buf[0:int(count)] + } else { + b = make([]byte, int(count)) + } + + _, err := io.ReadFull(r.reader, b) + if err != nil { + r.Fatal(err) + return "" + } + + return string(b) +} + +func (r *MsgReader) CopyN(w io.Writer, count int32) { + if r.err != nil { + return + } + + r.msgBytesRemaining -= count + if r.msgBytesRemaining < 0 { + r.Fatal(errors.New("read past end of message")) + return + } + + _, err := io.CopyN(w, r.reader, int64(count)) + if err != nil { + r.Fatal(err) + } +}