mirror of https://github.com/jackc/pgx.git
msgReader implemented in terms of ChunkReader
This should substantially reduce memory allocations and memory copies. It also means that PostgreSQL messages are always entirely buffered in memory before processing begins. This simplifies the message processing code. In particular, Conn.WaitForNotification is dramatically simplified by this change.v3-experimental
parent
84802ece05
commit
11b82b3ca4
115
conn.go
115
conn.go
|
@ -1,7 +1,6 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/md5"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
|
@ -20,6 +19,8 @@ import (
|
|||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/chunkreader"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -283,7 +284,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
}
|
||||
}
|
||||
|
||||
c.mr.reader = bufio.NewReader(c.conn)
|
||||
c.mr.cr = chunkreader.NewChunkReader(c.conn)
|
||||
|
||||
msg := newStartupMessage()
|
||||
|
||||
|
@ -844,9 +845,8 @@ func (c *Conn) Unlisten(channel string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// WaitForNotification waits for a PostgreSQL notification for up to timeout.
|
||||
// If the timeout occurs it returns pgx.ErrNotificationTimeout
|
||||
func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) {
|
||||
// WaitForNotification waits for a PostgreSQL notification.
|
||||
func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notification, err error) {
|
||||
// Return already received notification immediately
|
||||
if len(c.notifications) > 0 {
|
||||
notification := c.notifications[0]
|
||||
|
@ -854,97 +854,40 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error)
|
|||
return notification, nil
|
||||
}
|
||||
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), timeout)
|
||||
if err := c.waitForPreviousCancelQuery(ctx); err != nil {
|
||||
cancelFn()
|
||||
err = c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cancelFn()
|
||||
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
err = c.termContext(err)
|
||||
}()
|
||||
|
||||
if err = c.lock(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if unlockErr := c.unlock(); unlockErr != nil && err == nil {
|
||||
err = unlockErr
|
||||
}
|
||||
}()
|
||||
|
||||
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stopTime := time.Now().Add(timeout)
|
||||
|
||||
for {
|
||||
now := time.Now()
|
||||
|
||||
if now.After(stopTime) {
|
||||
return nil, ErrNotificationTimeout
|
||||
}
|
||||
|
||||
// If there has been no activity on this connection for a while send a nop message just to ensure
|
||||
// the connection is alive
|
||||
nextEnsureAliveTime := c.lastActivityTime.Add(15 * time.Second)
|
||||
if nextEnsureAliveTime.Before(now) {
|
||||
// If the server can't respond to a nop in 15 seconds, assume it's dead
|
||||
err := c.conn.SetReadDeadline(now.Add(15 * time.Second))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = c.Exec("--;")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.lastActivityTime = now
|
||||
}
|
||||
|
||||
var deadline time.Time
|
||||
if stopTime.Before(nextEnsureAliveTime) {
|
||||
deadline = stopTime
|
||||
} else {
|
||||
deadline = nextEnsureAliveTime
|
||||
}
|
||||
|
||||
notification, err := c.waitForNotification(deadline)
|
||||
if err != ErrNotificationTimeout {
|
||||
return notification, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) {
|
||||
var zeroTime time.Time
|
||||
|
||||
for {
|
||||
// Use SetReadDeadline to implement the timeout. SetReadDeadline will
|
||||
// cause operations to fail with a *net.OpError that has a Timeout()
|
||||
// of true. Because the normal pgx rxMsg path considers any error to
|
||||
// have potentially corrupted the state of the connection, it dies
|
||||
// on any errors. So to avoid timeout errors in rxMsg we set the
|
||||
// deadline and peek into the reader. If a timeout error occurs there
|
||||
// we don't break the pgx connection. If the Peek returns that data
|
||||
// is available then we turn off the read deadline before the rxMsg.
|
||||
err := c.conn.SetReadDeadline(deadline)
|
||||
t, r, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wait until there is a byte available before continuing onto the normal msg reading path
|
||||
_, err = c.mr.reader.Peek(1)
|
||||
err = c.processContextFreeMsg(t, r)
|
||||
if err != nil {
|
||||
c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline
|
||||
if err, ok := err.(*net.OpError); ok && err.Timeout() {
|
||||
return nil, ErrNotificationTimeout
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = c.conn.SetReadDeadline(zeroTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var t byte
|
||||
var r *msgReader
|
||||
if t, r, err = c.rxMsg(); err == nil {
|
||||
if err = c.processContextFreeMsg(t, r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -1114,7 +1057,7 @@ func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {
|
|||
c.lastActivityTime = time.Now()
|
||||
|
||||
if c.shouldLog(LogLevelTrace) {
|
||||
c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBytesRemaining", c.mr.msgBytesRemaining)
|
||||
c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBodyLen", len(c.mr.msgBody))
|
||||
}
|
||||
|
||||
return t, &c.mr, err
|
||||
|
@ -1236,11 +1179,11 @@ func (c *Conn) rxParameterDescription(r *msgReader) (parameters []OID) {
|
|||
// wrong. So read the count, ignore it, and compute the proper value from
|
||||
// the size of the message.
|
||||
r.readInt16()
|
||||
parameterCount := r.msgBytesRemaining / 4
|
||||
parameterCount := len(r.msgBody[r.rp:]) / 4
|
||||
|
||||
parameters = make([]OID, 0, parameterCount)
|
||||
|
||||
for i := int32(0); i < parameterCount; i++ {
|
||||
for i := 0; i < parameterCount; i++ {
|
||||
parameters = append(parameters, r.readOID())
|
||||
}
|
||||
return
|
||||
|
|
28
conn_test.go
28
conn_test.go
|
@ -1084,7 +1084,7 @@ func TestListenNotify(t *testing.T) {
|
|||
mustExec(t, notifier, "notify chat")
|
||||
|
||||
// when notification is waiting on the socket to be read
|
||||
notification, err := listener.WaitForNotification(time.Second)
|
||||
notification, err := listener.WaitForNotification(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
|
@ -1099,7 +1099,10 @@ func TestListenNotify(t *testing.T) {
|
|||
if rows.Err() != nil {
|
||||
t.Fatalf("Unexpected error on Query: %v", rows.Err())
|
||||
}
|
||||
notification, err = listener.WaitForNotification(0)
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
cancelFn()
|
||||
notification, err = listener.WaitForNotification(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
|
@ -1108,8 +1111,9 @@ func TestListenNotify(t *testing.T) {
|
|||
}
|
||||
|
||||
// when timeout occurs
|
||||
notification, err = listener.WaitForNotification(time.Millisecond)
|
||||
if err != pgx.ErrNotificationTimeout {
|
||||
ctx, _ = context.WithTimeout(context.Background(), time.Millisecond)
|
||||
notification, err = listener.WaitForNotification(ctx)
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
|
||||
}
|
||||
if notification != nil {
|
||||
|
@ -1118,7 +1122,7 @@ func TestListenNotify(t *testing.T) {
|
|||
|
||||
// listener can listen again after a timeout
|
||||
mustExec(t, notifier, "notify chat")
|
||||
notification, err = listener.WaitForNotification(time.Second)
|
||||
notification, err = listener.WaitForNotification(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
|
@ -1143,7 +1147,7 @@ func TestUnlistenSpecificChannel(t *testing.T) {
|
|||
mustExec(t, notifier, "notify unlisten_test")
|
||||
|
||||
// when notification is waiting on the socket to be read
|
||||
notification, err := listener.WaitForNotification(time.Second)
|
||||
notification, err := listener.WaitForNotification(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
|
@ -1163,8 +1167,10 @@ func TestUnlistenSpecificChannel(t *testing.T) {
|
|||
if rows.Err() != nil {
|
||||
t.Fatalf("Unexpected error on Query: %v", rows.Err())
|
||||
}
|
||||
notification, err = listener.WaitForNotification(100 * time.Millisecond)
|
||||
if err != pgx.ErrNotificationTimeout {
|
||||
|
||||
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
notification, err = listener.WaitForNotification(ctx)
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
|
||||
}
|
||||
}
|
||||
|
@ -1246,7 +1252,8 @@ func TestListenNotifySelfNotification(t *testing.T) {
|
|||
// Notify self and WaitForNotification immediately
|
||||
mustExec(t, conn, "notify self")
|
||||
|
||||
notification, err := conn.WaitForNotification(time.Second)
|
||||
ctx, _ := context.WithTimeout(context.Background(), time.Second)
|
||||
notification, err := conn.WaitForNotification(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
|
@ -1263,7 +1270,8 @@ func TestListenNotifySelfNotification(t *testing.T) {
|
|||
t.Fatalf("Unexpected error on Query: %v", rows.Err())
|
||||
}
|
||||
|
||||
notification, err = conn.WaitForNotification(time.Second)
|
||||
ctx, _ = context.WithTimeout(context.Background(), time.Second)
|
||||
notification, err = conn.WaitForNotification(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
|
|
202
msg_reader.go
202
msg_reader.go
|
@ -1,26 +1,29 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/jackc/pgx/chunkreader"
|
||||
)
|
||||
|
||||
// msgReader is a helper that reads values from a PostgreSQL message.
|
||||
type msgReader struct {
|
||||
reader *bufio.Reader
|
||||
msgBytesRemaining int32
|
||||
err error
|
||||
log func(lvl int, msg string, ctx ...interface{})
|
||||
shouldLog func(lvl int) bool
|
||||
cr *chunkreader.ChunkReader
|
||||
msgType byte
|
||||
msgBody []byte
|
||||
rp int // read position
|
||||
err error
|
||||
log func(lvl int, msg string, ctx ...interface{})
|
||||
shouldLog func(lvl int) bool
|
||||
}
|
||||
|
||||
// fatal tells rc that a Fatal error has occurred
|
||||
func (r *msgReader) fatal(err error) {
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgType", r.msgType, "msgBody", r.msgBody, "rp", r.rp)
|
||||
}
|
||||
r.err = err
|
||||
}
|
||||
|
@ -31,22 +34,7 @@ func (r *msgReader) rxMsg() (byte, error) {
|
|||
return 0, r.err
|
||||
}
|
||||
|
||||
if r.msgBytesRemaining > 0 {
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
n, err := r.reader.Discard(int(r.msgBytesRemaining))
|
||||
r.msgBytesRemaining -= int32(n)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||
r.fatal(err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(5)
|
||||
header, err := r.cr.Next(5)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||
r.fatal(err)
|
||||
|
@ -54,22 +42,20 @@ func (r *msgReader) rxMsg() (byte, error) {
|
|||
return 0, err
|
||||
}
|
||||
|
||||
msgType := b[0]
|
||||
payloadSize := int32(binary.BigEndian.Uint32(b[1:])) - 4
|
||||
r.msgType = header[0]
|
||||
bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||
|
||||
// Try to preload bufio.Reader with entire message
|
||||
b, err = r.reader.Peek(5 + int(payloadSize))
|
||||
if err != nil && err != bufio.ErrBufferFull {
|
||||
r.msgBody, err = r.cr.Next(bodyLen)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||
r.fatal(err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
r.msgBytesRemaining = payloadSize
|
||||
r.reader.Discard(5)
|
||||
r.rp = 0
|
||||
|
||||
return msgType, nil
|
||||
return r.msgType, nil
|
||||
}
|
||||
|
||||
func (r *msgReader) readByte() byte {
|
||||
|
@ -77,20 +63,16 @@ func (r *msgReader) readByte() byte {
|
|||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining--
|
||||
if r.msgBytesRemaining < 0 {
|
||||
if len(r.msgBody)-r.rp < 1 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.ReadByte()
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
b := r.msgBody[r.rp]
|
||||
r.rp++
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgType", r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return b
|
||||
|
@ -101,24 +83,16 @@ func (r *msgReader) readInt16() int16 {
|
|||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 2
|
||||
if r.msgBytesRemaining < 0 {
|
||||
if len(r.msgBody)-r.rp < 2 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(2)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
n := int16(binary.BigEndian.Uint16(b))
|
||||
|
||||
r.reader.Discard(2)
|
||||
n := int16(binary.BigEndian.Uint16(r.msgBody[r.rp:]))
|
||||
r.rp += 2
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgType", r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return n
|
||||
|
@ -129,24 +103,16 @@ func (r *msgReader) readInt32() int32 {
|
|||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 4
|
||||
if r.msgBytesRemaining < 0 {
|
||||
if len(r.msgBody)-r.rp < 4 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(4)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
n := int32(binary.BigEndian.Uint32(b))
|
||||
|
||||
r.reader.Discard(4)
|
||||
n := int32(binary.BigEndian.Uint32(r.msgBody[r.rp:]))
|
||||
r.rp += 4
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgType", r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return n
|
||||
|
@ -157,24 +123,16 @@ func (r *msgReader) readUint16() uint16 {
|
|||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 2
|
||||
if r.msgBytesRemaining < 0 {
|
||||
if len(r.msgBody)-r.rp < 2 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(2)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
n := uint16(binary.BigEndian.Uint16(b))
|
||||
|
||||
r.reader.Discard(2)
|
||||
n := binary.BigEndian.Uint16(r.msgBody[r.rp:])
|
||||
r.rp += 2
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgType", r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return n
|
||||
|
@ -185,24 +143,16 @@ func (r *msgReader) readUint32() uint32 {
|
|||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 4
|
||||
if r.msgBytesRemaining < 0 {
|
||||
if len(r.msgBody)-r.rp < 4 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(4)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
n := uint32(binary.BigEndian.Uint32(b))
|
||||
|
||||
r.reader.Discard(4)
|
||||
n := binary.BigEndian.Uint32(r.msgBody[r.rp:])
|
||||
r.rp += 4
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgType", r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return n
|
||||
|
@ -213,24 +163,16 @@ func (r *msgReader) readInt64() int64 {
|
|||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 8
|
||||
if r.msgBytesRemaining < 0 {
|
||||
if len(r.msgBody)-r.rp < 8 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(8)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
n := int64(binary.BigEndian.Uint64(b))
|
||||
|
||||
r.reader.Discard(8)
|
||||
n := int64(binary.BigEndian.Uint64(r.msgBody[r.rp:]))
|
||||
r.rp += 8
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgType", r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return n
|
||||
|
@ -246,22 +188,17 @@ func (r *msgReader) readCString() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
b, err := r.reader.ReadBytes(0)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
nullIdx := bytes.IndexByte(r.msgBody[r.rp:], 0)
|
||||
if nullIdx == -1 {
|
||||
r.fatal(errors.New("null terminated string not found"))
|
||||
return ""
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= int32(len(b))
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return ""
|
||||
}
|
||||
|
||||
s := string(b[0 : len(b)-1])
|
||||
s := string(r.msgBody[r.rp : r.rp+nullIdx])
|
||||
r.rp += nullIdx + 1
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgType", r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return s
|
||||
|
@ -273,58 +210,43 @@ func (r *msgReader) readString(countI32 int32) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= countI32
|
||||
if r.msgBytesRemaining < 0 {
|
||||
count := int(countI32)
|
||||
|
||||
if len(r.msgBody)-r.rp < count {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return ""
|
||||
}
|
||||
|
||||
count := int(countI32)
|
||||
var s string
|
||||
|
||||
if r.reader.Buffered() >= count {
|
||||
buf, _ := r.reader.Peek(count)
|
||||
s = string(buf)
|
||||
r.reader.Discard(count)
|
||||
} else {
|
||||
buf := make([]byte, count)
|
||||
_, err := io.ReadFull(r.reader, buf)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return ""
|
||||
}
|
||||
s = string(buf)
|
||||
}
|
||||
s := string(r.msgBody[r.rp : r.rp+count])
|
||||
r.rp += count
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgType", r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// readBytes reads count bytes and returns as []byte
|
||||
func (r *msgReader) readBytes(count int32) []byte {
|
||||
func (r *msgReader) readBytes(countI32 int32) []byte {
|
||||
if r.err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= count
|
||||
if r.msgBytesRemaining < 0 {
|
||||
count := int(countI32)
|
||||
|
||||
if len(r.msgBody)-r.rp < count {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return nil
|
||||
}
|
||||
|
||||
b := make([]byte, int(count))
|
||||
b := r.msgBody[r.rp : r.rp+count]
|
||||
r.rp += count
|
||||
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return nil
|
||||
}
|
||||
r.cr.KeepLast()
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readBytes", "value", b, r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return b
|
||||
|
|
|
@ -1,189 +0,0 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgmock/pgmsg"
|
||||
)
|
||||
|
||||
func TestMsgReaderPrebuffersWhenPossible(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
msgType byte
|
||||
payloadSize int32
|
||||
buffered bool
|
||||
}{
|
||||
{1, 50, true},
|
||||
{2, 0, true},
|
||||
{3, 500, true},
|
||||
{4, 1050, true},
|
||||
{5, 1500, true},
|
||||
{6, 1500, true},
|
||||
{7, 4000, true},
|
||||
{8, 24000, false},
|
||||
{9, 4000, true},
|
||||
{1, 1500, true},
|
||||
{2, 0, true},
|
||||
{3, 500, true},
|
||||
{4, 1050, true},
|
||||
{5, 1500, true},
|
||||
{6, 1500, true},
|
||||
{7, 4000, true},
|
||||
{8, 14000, false},
|
||||
{9, 0, true},
|
||||
{1, 500, true},
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
go func() {
|
||||
var bigEndian pgmsg.BigEndianBuf
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for _, tt := range tests {
|
||||
_, err = conn.Write([]byte{tt.msgType})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = conn.Write(bigEndian.Int32(tt.payloadSize + 4))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
payload := make([]byte, int(tt.payloadSize))
|
||||
_, err = conn.Write(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
mr := &msgReader{
|
||||
reader: bufio.NewReader(conn),
|
||||
shouldLog: func(int) bool { return false },
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
msgType, err := mr.rxMsg()
|
||||
if err != nil {
|
||||
t.Fatalf("%d. Unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
if msgType != tt.msgType {
|
||||
t.Fatalf("%d. Expected %v, got %v", 1, i, tt.msgType, msgType)
|
||||
}
|
||||
|
||||
if mr.reader.Buffered() < int(tt.payloadSize) && tt.buffered {
|
||||
t.Fatalf("%d. Expected message to be buffered with at least %d bytes, but only %v bytes buffered", i, tt.payloadSize, mr.reader.Buffered())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMsgReaderDeadlineNeverInterruptsNormalSizedMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
testCount := 10000
|
||||
|
||||
go func() {
|
||||
var bigEndian pgmsg.BigEndianBuf
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for i := 0; i < testCount; i++ {
|
||||
msgType := byte(i)
|
||||
|
||||
_, err = conn.Write([]byte{msgType})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
msgSize := i % 4000
|
||||
|
||||
_, err = conn.Write(bigEndian.Int32(int32(msgSize + 4)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
payload := make([]byte, msgSize)
|
||||
_, err = conn.Write(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
mr := &msgReader{
|
||||
reader: bufio.NewReader(conn),
|
||||
shouldLog: func(int) bool { return false },
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(time.Millisecond))
|
||||
|
||||
i := 0
|
||||
for {
|
||||
msgType, err := mr.rxMsg()
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
conn.SetReadDeadline(time.Now().Add(time.Millisecond))
|
||||
continue
|
||||
} else {
|
||||
t.Fatalf("%d. Unexpected error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
expectedMsgType := byte(i)
|
||||
if msgType != expectedMsgType {
|
||||
t.Fatalf("%d. Expected %v, got %v", i, expectedMsgType, msgType)
|
||||
}
|
||||
|
||||
expectedMsgSize := i % 4000
|
||||
payload := mr.readBytes(mr.msgBytesRemaining)
|
||||
if mr.err != nil {
|
||||
t.Fatalf("%d. readBytes killed msgReader: %v", i, mr.err)
|
||||
}
|
||||
if len(payload) != expectedMsgSize {
|
||||
t.Fatalf("%d. Expected %v, got %v", i, expectedMsgSize, len(payload))
|
||||
}
|
||||
|
||||
i++
|
||||
if i == testCount {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,9 +1,9 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -234,7 +234,7 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err
|
|||
walStart := reader.readInt64()
|
||||
serverWalEnd := reader.readInt64()
|
||||
serverTime := reader.readInt64()
|
||||
walData := reader.readBytes(reader.msgBytesRemaining)
|
||||
walData := reader.readBytes(int32(len(reader.msgBody) - reader.rp))
|
||||
walMessage := WalMessage{WalStart: uint64(walStart),
|
||||
ServerWalEnd: uint64(serverWalEnd),
|
||||
ServerTime: uint64(serverTime),
|
||||
|
@ -261,47 +261,23 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err
|
|||
return
|
||||
}
|
||||
|
||||
// Wait for a single replication message up to timeout time.
|
||||
// Wait for a single replication message.
|
||||
//
|
||||
// Properly using this requires some knowledge of the postgres replication mechanisms,
|
||||
// as the client can receive both WAL data (the ultimate payload) and server heartbeat
|
||||
// updates. The caller also must send standby status updates in order to keep the connection
|
||||
// alive and working.
|
||||
//
|
||||
// This returns pgx.ErrNotificationTimeout when there is no replication message by the specified
|
||||
// duration.
|
||||
func (rc *ReplicationConn) WaitForReplicationMessage(timeout time.Duration) (r *ReplicationMessage, err error) {
|
||||
var zeroTime time.Time
|
||||
|
||||
deadline := time.Now().Add(timeout)
|
||||
|
||||
// Use SetReadDeadline to implement the timeout. SetReadDeadline will
|
||||
// cause operations to fail with a *net.OpError that has a Timeout()
|
||||
// of true. Because the normal pgx rxMsg path considers any error to
|
||||
// have potentially corrupted the state of the connection, it dies
|
||||
// on any errors. So to avoid timeout errors in rxMsg we set the
|
||||
// deadline and peek into the reader. If a timeout error occurs there
|
||||
// we don't break the pgx connection. If the Peek returns that data
|
||||
// is available then we turn off the read deadline before the rxMsg.
|
||||
err = rc.c.conn.SetReadDeadline(deadline)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wait until there is a byte available before continuing onto the normal msg reading path
|
||||
_, err = rc.c.mr.reader.Peek(1)
|
||||
if err != nil {
|
||||
rc.c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline
|
||||
if err, ok := err.(*net.OpError); ok && err.Timeout() {
|
||||
return nil, ErrNotificationTimeout
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = rc.c.conn.SetReadDeadline(zeroTime)
|
||||
// This returns the context error when there is no replication message before
|
||||
// the context is canceled.
|
||||
func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (r *ReplicationMessage, err error) {
|
||||
err = rc.c.initContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
err = rc.c.termContext(err)
|
||||
}()
|
||||
|
||||
return rc.readReplicationMessage()
|
||||
}
|
||||
|
@ -401,12 +377,14 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti
|
|||
return
|
||||
}
|
||||
|
||||
ctx, _ := context.WithTimeout(context.Background(), initialReplicationResponseTimeout)
|
||||
|
||||
// The first replication message that comes back here will be (in a success case)
|
||||
// a empty CopyBoth that is (apparently) sent as the confirmation that the replication has
|
||||
// started. This call will either return nil, nil or if it returns an error
|
||||
// that indicates the start replication command failed
|
||||
var r *ReplicationMessage
|
||||
r, err = rc.WaitForReplicationMessage(initialReplicationResponseTimeout)
|
||||
r, err = rc.WaitForReplicationMessage(ctx)
|
||||
if err != nil && r != nil {
|
||||
if rc.c.shouldLog(LogLevelError) {
|
||||
rc.c.log(LogLevelError, "Unxpected replication message %v", r)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/jackc/pgx"
|
||||
"reflect"
|
||||
|
@ -88,11 +89,10 @@ func TestSimpleReplicationConnection(t *testing.T) {
|
|||
for {
|
||||
var message *pgx.ReplicationMessage
|
||||
|
||||
message, err = replicationConn.WaitForReplicationMessage(time.Duration(1 * time.Second))
|
||||
if err != nil {
|
||||
if err != pgx.ErrNotificationTimeout {
|
||||
t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err))
|
||||
}
|
||||
ctx, _ := context.WithTimeout(context.Background(), time.Second)
|
||||
message, err = replicationConn.WaitForReplicationMessage(ctx)
|
||||
if err != nil && err != context.DeadlineExceeded {
|
||||
t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err))
|
||||
}
|
||||
if message != nil {
|
||||
if message.WalMessage != nil {
|
||||
|
|
|
@ -244,8 +244,9 @@ func listenAndPoolUnlistens(pool *pgx.ConnPool, actionNum int) error {
|
|||
return err
|
||||
}
|
||||
|
||||
_, err = conn.WaitForNotification(100 * time.Millisecond)
|
||||
if err == pgx.ErrNotificationTimeout {
|
||||
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
_, err = conn.WaitForNotification(ctx)
|
||||
if err == context.DeadlineExceeded {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
|
|
Loading…
Reference in New Issue