mirror of https://github.com/jackc/pgx.git
MsgReader no longer uses double buffering
parent
78b8e0b6f2
commit
b25aea5c52
167
conn.go
167
conn.go
|
@ -14,7 +14,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
log "gopkg.in/inconshreveable/log15.v2"
|
log "gopkg.in/inconshreveable/log15.v2"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
@ -52,8 +51,8 @@ type Conn struct {
|
||||||
conn net.Conn // the underlying TCP or unix domain socket connection
|
conn net.Conn // the underlying TCP or unix domain socket connection
|
||||||
reader *bufio.Reader // buffered reader to improve read performance
|
reader *bufio.Reader // buffered reader to improve read performance
|
||||||
wbuf [1024]byte
|
wbuf [1024]byte
|
||||||
buf *bytes.Buffer // work buffer to avoid constant alloc and dealloc
|
buf *bytes.Buffer
|
||||||
bufSize int // desired size of buf
|
bufSize int
|
||||||
Pid int32 // backend pid
|
Pid int32 // backend pid
|
||||||
SecretKey int32 // key to use to send a cancel query message to the server
|
SecretKey int32 // key to use to send a cancel query message to the server
|
||||||
RuntimeParams map[string]string // parameters that have been reported by the server
|
RuntimeParams map[string]string // parameters that have been reported by the server
|
||||||
|
@ -65,6 +64,7 @@ type Conn struct {
|
||||||
causeOfDeath error
|
causeOfDeath error
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
qr QueryResult
|
qr QueryResult
|
||||||
|
mr MsgReader
|
||||||
}
|
}
|
||||||
|
|
||||||
type PreparedStatement struct {
|
type PreparedStatement struct {
|
||||||
|
@ -197,6 +197,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
c.reader = bufio.NewReader(c.conn)
|
c.reader = bufio.NewReader(c.conn)
|
||||||
|
c.mr.reader = c.reader
|
||||||
|
|
||||||
msg := newStartupMessage()
|
msg := newStartupMessage()
|
||||||
msg.options["user"] = c.config.User
|
msg.options["user"] = c.config.User
|
||||||
|
@ -209,7 +210,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
||||||
|
|
||||||
for {
|
for {
|
||||||
var t byte
|
var t byte
|
||||||
var r *MessageReader
|
var r *MsgReader
|
||||||
t, r, err = c.rxMsg()
|
t, r, err = c.rxMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -348,9 +349,9 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{})
|
||||||
|
|
||||||
for {
|
for {
|
||||||
var t byte
|
var t byte
|
||||||
var bodySize int32
|
var r *MsgReader
|
||||||
|
|
||||||
t, bodySize, err = c.rxMsgHeader()
|
t, r, err = c.rxMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -363,19 +364,12 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
if softErr != nil {
|
if softErr != nil {
|
||||||
c.rxMsgBody(bodySize) // Read and discard rest of message
|
// Read and discard rest of message
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
softErr = c.rxDataRowValueTo(w, bodySize)
|
softErr = c.rxDataRowValueTo(w, r)
|
||||||
} else {
|
} else {
|
||||||
var body *bytes.Buffer
|
|
||||||
body, err = c.rxMsgBody(bodySize)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
r := (*MessageReader)(body)
|
|
||||||
switch t {
|
switch t {
|
||||||
case readyForQuery:
|
case readyForQuery:
|
||||||
c.rxReadyForQuery(r)
|
c.rxReadyForQuery(r)
|
||||||
|
@ -392,44 +386,20 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) rxDataRowValueTo(w io.Writer, bodySize int32) (err error) {
|
func (c *Conn) rxDataRowValueTo(w io.Writer, r *MsgReader) error {
|
||||||
b := make([]byte, 2)
|
columnCount := r.ReadInt16()
|
||||||
_, err = io.ReadFull(c.reader, b)
|
|
||||||
if err != nil {
|
|
||||||
c.die(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
columnCount := int16(binary.BigEndian.Uint16(b))
|
|
||||||
|
|
||||||
if columnCount != 1 {
|
if columnCount != 1 {
|
||||||
// Read the rest of the data row so it can be discarded
|
return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: columnCount}
|
||||||
if _, err = io.CopyN(ioutil.Discard, c.reader, int64(bodySize-2)); err != nil {
|
|
||||||
c.die(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: columnCount}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
b = make([]byte, 4)
|
valueSize := r.ReadInt32()
|
||||||
_, err = io.ReadFull(c.reader, b)
|
|
||||||
if err != nil {
|
|
||||||
c.die(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
valueSize := int32(binary.BigEndian.Uint32(b))
|
|
||||||
|
|
||||||
if valueSize == -1 {
|
if valueSize == -1 {
|
||||||
err = errors.New("SelectValueTo cannot handle null")
|
return errors.New("SelectValueTo cannot handle null")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = io.CopyN(w, c.reader, int64(valueSize))
|
r.CopyN(w, valueSize)
|
||||||
if err != nil {
|
|
||||||
c.die(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
return r.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare creates a prepared statement with name and sql. sql can contain placeholders
|
// Prepare creates a prepared statement with name and sql. sql can contain placeholders
|
||||||
|
@ -475,7 +445,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
|
||||||
|
|
||||||
for {
|
for {
|
||||||
var t byte
|
var t byte
|
||||||
var r *MessageReader
|
var r *MsgReader
|
||||||
t, r, err := c.rxMsg()
|
t, r, err := c.rxMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -563,7 +533,7 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
var t byte
|
var t byte
|
||||||
var r *MessageReader
|
var r *MsgReader
|
||||||
if t, r, err = c.rxMsg(); err == nil {
|
if t, r, err = c.rxMsg(); err == nil {
|
||||||
if err = c.processContextFreeMsg(t, r); err != nil {
|
if err = c.processContextFreeMsg(t, r); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -653,12 +623,12 @@ func (rr *RowReader) ReadDate(qr *QueryResult) time.Time {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rr *RowReader) ReadString(qr *QueryResult) string {
|
func (rr *RowReader) ReadString(qr *QueryResult) string {
|
||||||
_, size, ok := qr.NextColumn()
|
fd, size, ok := qr.NextColumn()
|
||||||
if !ok {
|
if !ok {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
return qr.mr.ReadString(size)
|
return decodeText(qr, fd, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
|
func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
|
||||||
|
@ -671,7 +641,7 @@ func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
|
||||||
if vt, present := ValueTranscoders[fd.DataType]; present && vt.Decode != nil {
|
if vt, present := ValueTranscoders[fd.DataType]; present && vt.Decode != nil {
|
||||||
return vt.Decode(qr, fd, size)
|
return vt.Decode(qr, fd, size)
|
||||||
} else {
|
} else {
|
||||||
return qr.mr.ReadString(size)
|
return decodeText(qr, fd, size)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return nil
|
return nil
|
||||||
|
@ -681,7 +651,7 @@ func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
|
||||||
type QueryResult struct {
|
type QueryResult struct {
|
||||||
pool *ConnPool
|
pool *ConnPool
|
||||||
conn *Conn
|
conn *Conn
|
||||||
mr *MessageReader
|
mr *MsgReader
|
||||||
fields []FieldDescription
|
fields []FieldDescription
|
||||||
rowCount int
|
rowCount int
|
||||||
columnIdx int
|
columnIdx int
|
||||||
|
@ -693,7 +663,7 @@ func (qr *QueryResult) FieldDescriptions() []FieldDescription {
|
||||||
return qr.fields
|
return qr.fields
|
||||||
}
|
}
|
||||||
|
|
||||||
func (qr *QueryResult) MessageReader() *MessageReader {
|
func (qr *QueryResult) MsgReader() *MsgReader {
|
||||||
return qr.mr
|
return qr.mr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -771,8 +741,8 @@ func (qr *QueryResult) NextRow() bool {
|
||||||
qr.close()
|
qr.close()
|
||||||
return false
|
return false
|
||||||
case dataRow:
|
case dataRow:
|
||||||
fieldCount := int(r.ReadInt16())
|
fieldCount := r.ReadInt16()
|
||||||
if fieldCount != len(qr.fields) {
|
if int(fieldCount) != len(qr.fields) {
|
||||||
qr.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(qr.fields), fieldCount)))
|
qr.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(qr.fields), fieldCount)))
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -803,7 +773,6 @@ func (qr *QueryResult) NextColumn() (*FieldDescription, int32, bool) {
|
||||||
fd := &qr.fields[qr.columnIdx]
|
fd := &qr.fields[qr.columnIdx]
|
||||||
qr.columnIdx++
|
qr.columnIdx++
|
||||||
size := qr.mr.ReadInt32()
|
size := qr.mr.ReadInt32()
|
||||||
|
|
||||||
return fd, size, true
|
return fd, size, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -965,7 +934,7 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag
|
||||||
|
|
||||||
for {
|
for {
|
||||||
var t byte
|
var t byte
|
||||||
var r *MessageReader
|
var r *MsgReader
|
||||||
t, r, err = c.rxMsg()
|
t, r, err = c.rxMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return commandTag, err
|
return commandTag, err
|
||||||
|
@ -1040,7 +1009,7 @@ func (c *Conn) transaction(isoLevel string, f func() bool) (committed bool, err
|
||||||
// Processes messages that are not exclusive to one context such as
|
// Processes messages that are not exclusive to one context such as
|
||||||
// authentication or query response. The response to these messages
|
// authentication or query response. The response to these messages
|
||||||
// is the same regardless of when they occur.
|
// is the same regardless of when they occur.
|
||||||
func (c *Conn) processContextFreeMsg(t byte, r *MessageReader) (err error) {
|
func (c *Conn) processContextFreeMsg(t byte, r *MsgReader) (err error) {
|
||||||
switch t {
|
switch t {
|
||||||
case 'S':
|
case 'S':
|
||||||
c.rxParameterStatus(r)
|
c.rxParameterStatus(r)
|
||||||
|
@ -1050,69 +1019,28 @@ func (c *Conn) processContextFreeMsg(t byte, r *MessageReader) (err error) {
|
||||||
case noticeResponse:
|
case noticeResponse:
|
||||||
return nil
|
return nil
|
||||||
case notificationResponse:
|
case notificationResponse:
|
||||||
return c.rxNotificationResponse(r)
|
c.rxNotificationResponse(r)
|
||||||
|
return nil
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("Received unknown message type: %c", t)
|
return fmt.Errorf("Received unknown message type: %c", t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) rxMsg() (t byte, r *MessageReader, err error) {
|
func (c *Conn) rxMsg() (t byte, r *MsgReader, err error) {
|
||||||
var bodySize int32
|
|
||||||
t, bodySize, err = c.rxMsgHeader()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var body *bytes.Buffer
|
|
||||||
if body, err = c.rxMsgBody(bodySize); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r = (*MessageReader)(body)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) rxMsgHeader() (t byte, bodySize int32, err error) {
|
|
||||||
if !c.alive {
|
if !c.alive {
|
||||||
return 0, 0, DeadConnError
|
return 0, nil, DeadConnError
|
||||||
}
|
}
|
||||||
|
|
||||||
t, err = c.reader.ReadByte()
|
t, err = c.mr.rxMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.die(err)
|
c.die(err)
|
||||||
return 0, 0, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
b := make([]byte, 4)
|
return t, &c.mr, err
|
||||||
_, err = io.ReadFull(c.reader, b)
|
|
||||||
if err != nil {
|
|
||||||
c.die(err)
|
|
||||||
return 0, 0, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bodySize = int32(binary.BigEndian.Uint32(b))
|
func (c *Conn) rxAuthenticationX(r *MsgReader) (err error) {
|
||||||
|
switch r.ReadInt32() {
|
||||||
bodySize -= 4 // remove self from size
|
|
||||||
return t, bodySize, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) rxMsgBody(bodySize int32) (*bytes.Buffer, error) {
|
|
||||||
if !c.alive {
|
|
||||||
return nil, DeadConnError
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := c.getBuf()
|
|
||||||
_, err := io.CopyN(buf, c.reader, int64(bodySize))
|
|
||||||
if err != nil {
|
|
||||||
c.die(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) rxAuthenticationX(r *MessageReader) (err error) {
|
|
||||||
code := r.ReadInt32()
|
|
||||||
switch code {
|
|
||||||
case 0: // AuthenticationOk
|
case 0: // AuthenticationOk
|
||||||
case 3: // AuthenticationCleartextPassword
|
case 3: // AuthenticationCleartextPassword
|
||||||
err = c.txPasswordMessage(c.config.Password)
|
err = c.txPasswordMessage(c.config.Password)
|
||||||
|
@ -1133,13 +1061,13 @@ func hexMD5(s string) string {
|
||||||
return hex.EncodeToString(hash.Sum(nil))
|
return hex.EncodeToString(hash.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) rxParameterStatus(r *MessageReader) {
|
func (c *Conn) rxParameterStatus(r *MsgReader) {
|
||||||
key := r.ReadCString()
|
key := r.ReadCString()
|
||||||
value := r.ReadCString()
|
value := r.ReadCString()
|
||||||
c.RuntimeParams[key] = value
|
c.RuntimeParams[key] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) rxErrorResponse(r *MessageReader) (err PgError) {
|
func (c *Conn) rxErrorResponse(r *MsgReader) (err PgError) {
|
||||||
for {
|
for {
|
||||||
switch r.ReadByte() {
|
switch r.ReadByte() {
|
||||||
case 'S':
|
case 'S':
|
||||||
|
@ -1159,16 +1087,16 @@ func (c *Conn) rxErrorResponse(r *MessageReader) (err PgError) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) rxBackendKeyData(r *MessageReader) {
|
func (c *Conn) rxBackendKeyData(r *MsgReader) {
|
||||||
c.Pid = r.ReadInt32()
|
c.Pid = r.ReadInt32()
|
||||||
c.SecretKey = r.ReadInt32()
|
c.SecretKey = r.ReadInt32()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) rxReadyForQuery(r *MessageReader) {
|
func (c *Conn) rxReadyForQuery(r *MsgReader) {
|
||||||
c.TxStatus = r.ReadByte()
|
c.TxStatus = r.ReadByte()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) rxRowDescription(r *MessageReader) (fields []FieldDescription) {
|
func (c *Conn) rxRowDescription(r *MsgReader) (fields []FieldDescription) {
|
||||||
fieldCount := r.ReadInt16()
|
fieldCount := r.ReadInt16()
|
||||||
fields = make([]FieldDescription, fieldCount)
|
fields = make([]FieldDescription, fieldCount)
|
||||||
for i := int16(0); i < fieldCount; i++ {
|
for i := int16(0); i < fieldCount; i++ {
|
||||||
|
@ -1184,26 +1112,26 @@ func (c *Conn) rxRowDescription(r *MessageReader) (fields []FieldDescription) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) rxParameterDescription(r *MessageReader) (parameters []Oid) {
|
func (c *Conn) rxParameterDescription(r *MsgReader) (parameters []Oid) {
|
||||||
parameterCount := r.ReadInt16()
|
parameterCount := r.ReadInt16()
|
||||||
parameters = make([]Oid, 0, parameterCount)
|
parameters = make([]Oid, 0, parameterCount)
|
||||||
|
|
||||||
for i := int16(0); i < parameterCount; i++ {
|
for i := int16(0); i < parameterCount; i++ {
|
||||||
parameters = append(parameters, r.ReadOid())
|
parameters = append(parameters, r.ReadOid())
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) rxCommandComplete(r *MessageReader) string {
|
func (c *Conn) rxCommandComplete(r *MsgReader) string {
|
||||||
return r.ReadCString()
|
return r.ReadCString()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) rxNotificationResponse(r *MessageReader) (err error) {
|
func (c *Conn) rxNotificationResponse(r *MsgReader) {
|
||||||
n := new(Notification)
|
n := new(Notification)
|
||||||
n.Pid = r.ReadInt32()
|
n.Pid = r.ReadInt32()
|
||||||
n.Channel = r.ReadCString()
|
n.Channel = r.ReadCString()
|
||||||
n.Payload = r.ReadCString()
|
n.Payload = r.ReadCString()
|
||||||
c.notifications = append(c.notifications, n)
|
c.notifications = append(c.notifications, n)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) startTLS() (err error) {
|
func (c *Conn) startTLS() (err error) {
|
||||||
|
@ -1280,12 +1208,7 @@ func (c *Conn) txPasswordMessage(password string) (err error) {
|
||||||
// its internal byte array, check on the size and create a new bytes.Buffer so the
|
// its internal byte array, check on the size and create a new bytes.Buffer so the
|
||||||
// old one can get GC'ed
|
// old one can get GC'ed
|
||||||
func (c *Conn) getBuf() *bytes.Buffer {
|
func (c *Conn) getBuf() *bytes.Buffer {
|
||||||
c.buf.Reset()
|
return &bytes.Buffer{}
|
||||||
if cap(c.buf.Bytes()) > c.bufSize {
|
|
||||||
c.logger.Debug(fmt.Sprintf("c.buf (%d) is larger than c.bufSize (%d) -- resetting", cap(c.buf.Bytes()), c.bufSize))
|
|
||||||
c.buf = bytes.NewBuffer(make([]byte, 0, c.bufSize))
|
|
||||||
}
|
|
||||||
return c.buf
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) die(err error) {
|
func (c *Conn) die(err error) {
|
||||||
|
|
|
@ -51,7 +51,7 @@ func decodePoint(qr *pgx.QueryResult, fd *pgx.FieldDescription, size int32) Poin
|
||||||
|
|
||||||
switch fd.FormatCode {
|
switch fd.FormatCode {
|
||||||
case pgx.TextFormatCode:
|
case pgx.TextFormatCode:
|
||||||
s := qr.MessageReader().ReadString(size)
|
s := qr.MsgReader().ReadString(size)
|
||||||
match := pointRegexp.FindStringSubmatch(s)
|
match := pointRegexp.FindStringSubmatch(s)
|
||||||
if match == nil {
|
if match == nil {
|
||||||
qr.Fatal(pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s)))
|
qr.Fatal(pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s)))
|
||||||
|
|
|
@ -1,47 +0,0 @@
|
||||||
package pgx
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MessageReader is a helper that reads values from a PostgreSQL message.
|
|
||||||
type MessageReader bytes.Buffer
|
|
||||||
|
|
||||||
func (r *MessageReader) ReadByte() (b byte) {
|
|
||||||
b, _ = (*bytes.Buffer)(r).ReadByte()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *MessageReader) ReadInt16() (n int16) {
|
|
||||||
b := (*bytes.Buffer)(r).Next(2)
|
|
||||||
return int16(binary.BigEndian.Uint16(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *MessageReader) ReadInt32() (n int32) {
|
|
||||||
b := (*bytes.Buffer)(r).Next(4)
|
|
||||||
return int32(binary.BigEndian.Uint32(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *MessageReader) ReadInt64() (n int64) {
|
|
||||||
b := (*bytes.Buffer)(r).Next(8)
|
|
||||||
return int64(binary.BigEndian.Uint64(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *MessageReader) ReadOid() (oid Oid) {
|
|
||||||
b := (*bytes.Buffer)(r).Next(4)
|
|
||||||
return Oid(binary.BigEndian.Uint32(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadString reads a null terminated string
|
|
||||||
func (r *MessageReader) ReadCString() (s string) {
|
|
||||||
b, _ := (*bytes.Buffer)(r).ReadBytes(0)
|
|
||||||
return string(b[:len(b)-1])
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadString reads count bytes and returns as string
|
|
||||||
func (r *MessageReader) ReadString(count int32) (s string) {
|
|
||||||
size := int(count)
|
|
||||||
b := (*bytes.Buffer)(r).Next(size)
|
|
||||||
return string(b)
|
|
||||||
}
|
|
|
@ -0,0 +1,197 @@
|
||||||
|
package pgx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MsgReader is a helper that reads values from a PostgreSQL message.
|
||||||
|
type MsgReader struct {
|
||||||
|
reader *bufio.Reader
|
||||||
|
buf [128]byte
|
||||||
|
msgBytesRemaining int32
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Err returns any error that the MsgReader has experienced
|
||||||
|
func (r *MsgReader) Err() error {
|
||||||
|
return r.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fatal tells r that a Fatal error has occurred
|
||||||
|
func (r *MsgReader) Fatal(err error) {
|
||||||
|
r.err = err
|
||||||
|
}
|
||||||
|
|
||||||
|
// rxMsg reads the type and size of the next message.
|
||||||
|
func (r *MsgReader) rxMsg() (t byte, err error) {
|
||||||
|
if r.err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.msgBytesRemaining > 0 {
|
||||||
|
io.CopyN(ioutil.Discard, r.reader, int64(r.msgBytesRemaining))
|
||||||
|
}
|
||||||
|
|
||||||
|
t, err = r.reader.ReadByte()
|
||||||
|
b := r.buf[0:4]
|
||||||
|
_, err = io.ReadFull(r.reader, b)
|
||||||
|
r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b)) - 4
|
||||||
|
return t, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MsgReader) ReadByte() byte {
|
||||||
|
if r.err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
r.msgBytesRemaining -= 1
|
||||||
|
if r.msgBytesRemaining < 0 {
|
||||||
|
r.Fatal(errors.New("read past end of message"))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := r.reader.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
r.Fatal(err)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MsgReader) ReadInt16() int16 {
|
||||||
|
if r.err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
r.msgBytesRemaining -= 2
|
||||||
|
if r.msgBytesRemaining < 0 {
|
||||||
|
r.Fatal(errors.New("read past end of message"))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
b := r.buf[0:2]
|
||||||
|
_, err := io.ReadFull(r.reader, b)
|
||||||
|
if err != nil {
|
||||||
|
r.Fatal(err)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return int16(binary.BigEndian.Uint16(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MsgReader) ReadInt32() int32 {
|
||||||
|
if r.err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
r.msgBytesRemaining -= 4
|
||||||
|
if r.msgBytesRemaining < 0 {
|
||||||
|
r.Fatal(errors.New("read past end of message"))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
b := r.buf[0:4]
|
||||||
|
_, err := io.ReadFull(r.reader, b)
|
||||||
|
if err != nil {
|
||||||
|
r.Fatal(err)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return int32(binary.BigEndian.Uint32(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MsgReader) ReadInt64() int64 {
|
||||||
|
if r.err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
r.msgBytesRemaining -= 8
|
||||||
|
if r.msgBytesRemaining < 0 {
|
||||||
|
r.Fatal(errors.New("read past end of message"))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
b := r.buf[0:8]
|
||||||
|
_, err := io.ReadFull(r.reader, b)
|
||||||
|
if err != nil {
|
||||||
|
r.Fatal(err)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return int64(binary.BigEndian.Uint64(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MsgReader) ReadOid() Oid {
|
||||||
|
return Oid(r.ReadInt32())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadCString reads a null terminated string
|
||||||
|
func (r *MsgReader) ReadCString() string {
|
||||||
|
if r.err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := r.reader.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
r.Fatal(err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
r.msgBytesRemaining -= int32(len(b))
|
||||||
|
if r.msgBytesRemaining < 0 {
|
||||||
|
r.Fatal(errors.New("read past end of message"))
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(b[0 : len(b)-1])
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadString reads count bytes and returns as string
|
||||||
|
func (r *MsgReader) ReadString(count int32) string {
|
||||||
|
if r.err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
r.msgBytesRemaining -= count
|
||||||
|
if r.msgBytesRemaining < 0 {
|
||||||
|
r.Fatal(errors.New("read past end of message"))
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var b []byte
|
||||||
|
if count <= int32(len(r.buf)) {
|
||||||
|
b = r.buf[0:int(count)]
|
||||||
|
} else {
|
||||||
|
b = make([]byte, int(count))
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := io.ReadFull(r.reader, b)
|
||||||
|
if err != nil {
|
||||||
|
r.Fatal(err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MsgReader) CopyN(w io.Writer, count int32) {
|
||||||
|
if r.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.msgBytesRemaining -= count
|
||||||
|
if r.msgBytesRemaining < 0 {
|
||||||
|
r.Fatal(errors.New("read past end of message"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := io.CopyN(w, r.reader, int64(count))
|
||||||
|
if err != nil {
|
||||||
|
r.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue