msgReader implemented in terms of ChunkReader

This should substantially reduce memory allocations and memory copies.

It also means that PostgreSQL messages are always entirely buffered in memory
before processing begins. This simplifies the message processing code.

In particular, Conn.WaitForNotification is dramatically simplified by this
change.
v3-experimental
Jack Christensen 2017-02-13 20:41:58 -06:00
parent 84802ece05
commit 11b82b3ca4
7 changed files with 130 additions and 467 deletions

115
conn.go
View File

@ -1,7 +1,6 @@
package pgx
import (
"bufio"
"crypto/md5"
"crypto/tls"
"encoding/binary"
@ -20,6 +19,8 @@ import (
"strings"
"sync/atomic"
"time"
"github.com/jackc/pgx/chunkreader"
)
const (
@ -283,7 +284,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
}
}
c.mr.reader = bufio.NewReader(c.conn)
c.mr.cr = chunkreader.NewChunkReader(c.conn)
msg := newStartupMessage()
@ -844,9 +845,8 @@ func (c *Conn) Unlisten(channel string) error {
return nil
}
// WaitForNotification waits for a PostgreSQL notification for up to timeout.
// If the timeout occurs it returns pgx.ErrNotificationTimeout
func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) {
// WaitForNotification waits for a PostgreSQL notification.
func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notification, err error) {
// Return already received notification immediately
if len(c.notifications) > 0 {
notification := c.notifications[0]
@ -854,97 +854,40 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error)
return notification, nil
}
ctx, cancelFn := context.WithTimeout(context.Background(), timeout)
if err := c.waitForPreviousCancelQuery(ctx); err != nil {
cancelFn()
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return nil, err
}
cancelFn()
err = c.initContext(ctx)
if err != nil {
return nil, err
}
defer func() {
err = c.termContext(err)
}()
if err = c.lock(); err != nil {
return nil, err
}
defer func() {
if unlockErr := c.unlock(); unlockErr != nil && err == nil {
err = unlockErr
}
}()
if err := c.ensureConnectionReadyForQuery(); err != nil {
return nil, err
}
stopTime := time.Now().Add(timeout)
for {
now := time.Now()
if now.After(stopTime) {
return nil, ErrNotificationTimeout
}
// If there has been no activity on this connection for a while send a nop message just to ensure
// the connection is alive
nextEnsureAliveTime := c.lastActivityTime.Add(15 * time.Second)
if nextEnsureAliveTime.Before(now) {
// If the server can't respond to a nop in 15 seconds, assume it's dead
err := c.conn.SetReadDeadline(now.Add(15 * time.Second))
if err != nil {
return nil, err
}
_, err = c.Exec("--;")
if err != nil {
return nil, err
}
c.lastActivityTime = now
}
var deadline time.Time
if stopTime.Before(nextEnsureAliveTime) {
deadline = stopTime
} else {
deadline = nextEnsureAliveTime
}
notification, err := c.waitForNotification(deadline)
if err != ErrNotificationTimeout {
return notification, err
}
}
}
func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) {
var zeroTime time.Time
for {
// Use SetReadDeadline to implement the timeout. SetReadDeadline will
// cause operations to fail with a *net.OpError that has a Timeout()
// of true. Because the normal pgx rxMsg path considers any error to
// have potentially corrupted the state of the connection, it dies
// on any errors. So to avoid timeout errors in rxMsg we set the
// deadline and peek into the reader. If a timeout error occurs there
// we don't break the pgx connection. If the Peek returns that data
// is available then we turn off the read deadline before the rxMsg.
err := c.conn.SetReadDeadline(deadline)
t, r, err := c.rxMsg()
if err != nil {
return nil, err
}
// Wait until there is a byte available before continuing onto the normal msg reading path
_, err = c.mr.reader.Peek(1)
err = c.processContextFreeMsg(t, r)
if err != nil {
c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline
if err, ok := err.(*net.OpError); ok && err.Timeout() {
return nil, ErrNotificationTimeout
}
return nil, err
}
err = c.conn.SetReadDeadline(zeroTime)
if err != nil {
return nil, err
}
var t byte
var r *msgReader
if t, r, err = c.rxMsg(); err == nil {
if err = c.processContextFreeMsg(t, r); err != nil {
return nil, err
}
} else {
return nil, err
}
@ -1114,7 +1057,7 @@ func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {
c.lastActivityTime = time.Now()
if c.shouldLog(LogLevelTrace) {
c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBytesRemaining", c.mr.msgBytesRemaining)
c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBodyLen", len(c.mr.msgBody))
}
return t, &c.mr, err
@ -1236,11 +1179,11 @@ func (c *Conn) rxParameterDescription(r *msgReader) (parameters []OID) {
// wrong. So read the count, ignore it, and compute the proper value from
// the size of the message.
r.readInt16()
parameterCount := r.msgBytesRemaining / 4
parameterCount := len(r.msgBody[r.rp:]) / 4
parameters = make([]OID, 0, parameterCount)
for i := int32(0); i < parameterCount; i++ {
for i := 0; i < parameterCount; i++ {
parameters = append(parameters, r.readOID())
}
return

View File

@ -1084,7 +1084,7 @@ func TestListenNotify(t *testing.T) {
mustExec(t, notifier, "notify chat")
// when notification is waiting on the socket to be read
notification, err := listener.WaitForNotification(time.Second)
notification, err := listener.WaitForNotification(context.Background())
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}
@ -1099,7 +1099,10 @@ func TestListenNotify(t *testing.T) {
if rows.Err() != nil {
t.Fatalf("Unexpected error on Query: %v", rows.Err())
}
notification, err = listener.WaitForNotification(0)
ctx, cancelFn := context.WithCancel(context.Background())
cancelFn()
notification, err = listener.WaitForNotification(ctx)
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}
@ -1108,8 +1111,9 @@ func TestListenNotify(t *testing.T) {
}
// when timeout occurs
notification, err = listener.WaitForNotification(time.Millisecond)
if err != pgx.ErrNotificationTimeout {
ctx, _ = context.WithTimeout(context.Background(), time.Millisecond)
notification, err = listener.WaitForNotification(ctx)
if err != context.DeadlineExceeded {
t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
}
if notification != nil {
@ -1118,7 +1122,7 @@ func TestListenNotify(t *testing.T) {
// listener can listen again after a timeout
mustExec(t, notifier, "notify chat")
notification, err = listener.WaitForNotification(time.Second)
notification, err = listener.WaitForNotification(context.Background())
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}
@ -1143,7 +1147,7 @@ func TestUnlistenSpecificChannel(t *testing.T) {
mustExec(t, notifier, "notify unlisten_test")
// when notification is waiting on the socket to be read
notification, err := listener.WaitForNotification(time.Second)
notification, err := listener.WaitForNotification(context.Background())
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}
@ -1163,8 +1167,10 @@ func TestUnlistenSpecificChannel(t *testing.T) {
if rows.Err() != nil {
t.Fatalf("Unexpected error on Query: %v", rows.Err())
}
notification, err = listener.WaitForNotification(100 * time.Millisecond)
if err != pgx.ErrNotificationTimeout {
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
notification, err = listener.WaitForNotification(ctx)
if err != context.DeadlineExceeded {
t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
}
}
@ -1246,7 +1252,8 @@ func TestListenNotifySelfNotification(t *testing.T) {
// Notify self and WaitForNotification immediately
mustExec(t, conn, "notify self")
notification, err := conn.WaitForNotification(time.Second)
ctx, _ := context.WithTimeout(context.Background(), time.Second)
notification, err := conn.WaitForNotification(ctx)
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}
@ -1263,7 +1270,8 @@ func TestListenNotifySelfNotification(t *testing.T) {
t.Fatalf("Unexpected error on Query: %v", rows.Err())
}
notification, err = conn.WaitForNotification(time.Second)
ctx, _ = context.WithTimeout(context.Background(), time.Second)
notification, err = conn.WaitForNotification(ctx)
if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
}

View File

@ -1,26 +1,29 @@
package pgx
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"github.com/jackc/pgx/chunkreader"
)
// msgReader is a helper that reads values from a PostgreSQL message.
type msgReader struct {
reader *bufio.Reader
msgBytesRemaining int32
err error
log func(lvl int, msg string, ctx ...interface{})
shouldLog func(lvl int) bool
cr *chunkreader.ChunkReader
msgType byte
msgBody []byte
rp int // read position
err error
log func(lvl int, msg string, ctx ...interface{})
shouldLog func(lvl int) bool
}
// fatal tells rc that a Fatal error has occurred
func (r *msgReader) fatal(err error) {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgType", r.msgType, "msgBody", r.msgBody, "rp", r.rp)
}
r.err = err
}
@ -31,22 +34,7 @@ func (r *msgReader) rxMsg() (byte, error) {
return 0, r.err
}
if r.msgBytesRemaining > 0 {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining)
}
n, err := r.reader.Discard(int(r.msgBytesRemaining))
r.msgBytesRemaining -= int32(n)
if err != nil {
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
r.fatal(err)
}
return 0, err
}
}
b, err := r.reader.Peek(5)
header, err := r.cr.Next(5)
if err != nil {
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
r.fatal(err)
@ -54,22 +42,20 @@ func (r *msgReader) rxMsg() (byte, error) {
return 0, err
}
msgType := b[0]
payloadSize := int32(binary.BigEndian.Uint32(b[1:])) - 4
r.msgType = header[0]
bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
// Try to preload bufio.Reader with entire message
b, err = r.reader.Peek(5 + int(payloadSize))
if err != nil && err != bufio.ErrBufferFull {
r.msgBody, err = r.cr.Next(bodyLen)
if err != nil {
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
r.fatal(err)
}
return 0, err
}
r.msgBytesRemaining = payloadSize
r.reader.Discard(5)
r.rp = 0
return msgType, nil
return r.msgType, nil
}
func (r *msgReader) readByte() byte {
@ -77,20 +63,16 @@ func (r *msgReader) readByte() byte {
return 0
}
r.msgBytesRemaining--
if r.msgBytesRemaining < 0 {
if len(r.msgBody)-r.rp < 1 {
r.fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.ReadByte()
if err != nil {
r.fatal(err)
return 0
}
b := r.msgBody[r.rp]
r.rp++
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgType", r.msgType, "rp", r.rp)
}
return b
@ -101,24 +83,16 @@ func (r *msgReader) readInt16() int16 {
return 0
}
r.msgBytesRemaining -= 2
if r.msgBytesRemaining < 0 {
if len(r.msgBody)-r.rp < 2 {
r.fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.Peek(2)
if err != nil {
r.fatal(err)
return 0
}
n := int16(binary.BigEndian.Uint16(b))
r.reader.Discard(2)
n := int16(binary.BigEndian.Uint16(r.msgBody[r.rp:]))
r.rp += 2
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgType", r.msgType, "rp", r.rp)
}
return n
@ -129,24 +103,16 @@ func (r *msgReader) readInt32() int32 {
return 0
}
r.msgBytesRemaining -= 4
if r.msgBytesRemaining < 0 {
if len(r.msgBody)-r.rp < 4 {
r.fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.Peek(4)
if err != nil {
r.fatal(err)
return 0
}
n := int32(binary.BigEndian.Uint32(b))
r.reader.Discard(4)
n := int32(binary.BigEndian.Uint32(r.msgBody[r.rp:]))
r.rp += 4
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgType", r.msgType, "rp", r.rp)
}
return n
@ -157,24 +123,16 @@ func (r *msgReader) readUint16() uint16 {
return 0
}
r.msgBytesRemaining -= 2
if r.msgBytesRemaining < 0 {
if len(r.msgBody)-r.rp < 2 {
r.fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.Peek(2)
if err != nil {
r.fatal(err)
return 0
}
n := uint16(binary.BigEndian.Uint16(b))
r.reader.Discard(2)
n := binary.BigEndian.Uint16(r.msgBody[r.rp:])
r.rp += 2
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgType", r.msgType, "rp", r.rp)
}
return n
@ -185,24 +143,16 @@ func (r *msgReader) readUint32() uint32 {
return 0
}
r.msgBytesRemaining -= 4
if r.msgBytesRemaining < 0 {
if len(r.msgBody)-r.rp < 4 {
r.fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.Peek(4)
if err != nil {
r.fatal(err)
return 0
}
n := uint32(binary.BigEndian.Uint32(b))
r.reader.Discard(4)
n := binary.BigEndian.Uint32(r.msgBody[r.rp:])
r.rp += 4
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgType", r.msgType, "rp", r.rp)
}
return n
@ -213,24 +163,16 @@ func (r *msgReader) readInt64() int64 {
return 0
}
r.msgBytesRemaining -= 8
if r.msgBytesRemaining < 0 {
if len(r.msgBody)-r.rp < 8 {
r.fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.Peek(8)
if err != nil {
r.fatal(err)
return 0
}
n := int64(binary.BigEndian.Uint64(b))
r.reader.Discard(8)
n := int64(binary.BigEndian.Uint64(r.msgBody[r.rp:]))
r.rp += 8
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgType", r.msgType, "rp", r.rp)
}
return n
@ -246,22 +188,17 @@ func (r *msgReader) readCString() string {
return ""
}
b, err := r.reader.ReadBytes(0)
if err != nil {
r.fatal(err)
nullIdx := bytes.IndexByte(r.msgBody[r.rp:], 0)
if nullIdx == -1 {
r.fatal(errors.New("null terminated string not found"))
return ""
}
r.msgBytesRemaining -= int32(len(b))
if r.msgBytesRemaining < 0 {
r.fatal(errors.New("read past end of message"))
return ""
}
s := string(b[0 : len(b)-1])
s := string(r.msgBody[r.rp : r.rp+nullIdx])
r.rp += nullIdx + 1
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgType", r.msgType, "rp", r.rp)
}
return s
@ -273,58 +210,43 @@ func (r *msgReader) readString(countI32 int32) string {
return ""
}
r.msgBytesRemaining -= countI32
if r.msgBytesRemaining < 0 {
count := int(countI32)
if len(r.msgBody)-r.rp < count {
r.fatal(errors.New("read past end of message"))
return ""
}
count := int(countI32)
var s string
if r.reader.Buffered() >= count {
buf, _ := r.reader.Peek(count)
s = string(buf)
r.reader.Discard(count)
} else {
buf := make([]byte, count)
_, err := io.ReadFull(r.reader, buf)
if err != nil {
r.fatal(err)
return ""
}
s = string(buf)
}
s := string(r.msgBody[r.rp : r.rp+count])
r.rp += count
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgType", r.msgType, "rp", r.rp)
}
return s
}
// readBytes reads count bytes and returns as []byte
func (r *msgReader) readBytes(count int32) []byte {
func (r *msgReader) readBytes(countI32 int32) []byte {
if r.err != nil {
return nil
}
r.msgBytesRemaining -= count
if r.msgBytesRemaining < 0 {
count := int(countI32)
if len(r.msgBody)-r.rp < count {
r.fatal(errors.New("read past end of message"))
return nil
}
b := make([]byte, int(count))
b := r.msgBody[r.rp : r.rp+count]
r.rp += count
_, err := io.ReadFull(r.reader, b)
if err != nil {
r.fatal(err)
return nil
}
r.cr.KeepLast()
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readBytes", "value", b, r.msgType, "rp", r.rp)
}
return b

View File

@ -1,189 +0,0 @@
package pgx
import (
"bufio"
"net"
"testing"
"time"
"github.com/jackc/pgmock/pgmsg"
)
func TestMsgReaderPrebuffersWhenPossible(t *testing.T) {
t.Parallel()
tests := []struct {
msgType byte
payloadSize int32
buffered bool
}{
{1, 50, true},
{2, 0, true},
{3, 500, true},
{4, 1050, true},
{5, 1500, true},
{6, 1500, true},
{7, 4000, true},
{8, 24000, false},
{9, 4000, true},
{1, 1500, true},
{2, 0, true},
{3, 500, true},
{4, 1050, true},
{5, 1500, true},
{6, 1500, true},
{7, 4000, true},
{8, 14000, false},
{9, 0, true},
{1, 500, true},
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
go func() {
var bigEndian pgmsg.BigEndianBuf
conn, err := ln.Accept()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
for _, tt := range tests {
_, err = conn.Write([]byte{tt.msgType})
if err != nil {
t.Fatal(err)
}
_, err = conn.Write(bigEndian.Int32(tt.payloadSize + 4))
if err != nil {
t.Fatal(err)
}
payload := make([]byte, int(tt.payloadSize))
_, err = conn.Write(payload)
if err != nil {
t.Fatal(err)
}
}
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
mr := &msgReader{
reader: bufio.NewReader(conn),
shouldLog: func(int) bool { return false },
}
for i, tt := range tests {
msgType, err := mr.rxMsg()
if err != nil {
t.Fatalf("%d. Unexpected error: %v", i, err)
}
if msgType != tt.msgType {
t.Fatalf("%d. Expected %v, got %v", 1, i, tt.msgType, msgType)
}
if mr.reader.Buffered() < int(tt.payloadSize) && tt.buffered {
t.Fatalf("%d. Expected message to be buffered with at least %d bytes, but only %v bytes buffered", i, tt.payloadSize, mr.reader.Buffered())
}
}
}
func TestMsgReaderDeadlineNeverInterruptsNormalSizedMessages(t *testing.T) {
t.Parallel()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
testCount := 10000
go func() {
var bigEndian pgmsg.BigEndianBuf
conn, err := ln.Accept()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
for i := 0; i < testCount; i++ {
msgType := byte(i)
_, err = conn.Write([]byte{msgType})
if err != nil {
t.Fatal(err)
}
msgSize := i % 4000
_, err = conn.Write(bigEndian.Int32(int32(msgSize + 4)))
if err != nil {
t.Fatal(err)
}
payload := make([]byte, msgSize)
_, err = conn.Write(payload)
if err != nil {
t.Fatal(err)
}
}
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
mr := &msgReader{
reader: bufio.NewReader(conn),
shouldLog: func(int) bool { return false },
}
conn.SetReadDeadline(time.Now().Add(time.Millisecond))
i := 0
for {
msgType, err := mr.rxMsg()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
conn.SetReadDeadline(time.Now().Add(time.Millisecond))
continue
} else {
t.Fatalf("%d. Unexpected error: %v", i, err)
}
}
expectedMsgType := byte(i)
if msgType != expectedMsgType {
t.Fatalf("%d. Expected %v, got %v", i, expectedMsgType, msgType)
}
expectedMsgSize := i % 4000
payload := mr.readBytes(mr.msgBytesRemaining)
if mr.err != nil {
t.Fatalf("%d. readBytes killed msgReader: %v", i, mr.err)
}
if len(payload) != expectedMsgSize {
t.Fatalf("%d. Expected %v, got %v", i, expectedMsgSize, len(payload))
}
i++
if i == testCount {
break
}
}
}

View File

@ -1,9 +1,9 @@
package pgx
import (
"context"
"errors"
"fmt"
"net"
"time"
)
@ -234,7 +234,7 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err
walStart := reader.readInt64()
serverWalEnd := reader.readInt64()
serverTime := reader.readInt64()
walData := reader.readBytes(reader.msgBytesRemaining)
walData := reader.readBytes(int32(len(reader.msgBody) - reader.rp))
walMessage := WalMessage{WalStart: uint64(walStart),
ServerWalEnd: uint64(serverWalEnd),
ServerTime: uint64(serverTime),
@ -261,47 +261,23 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err
return
}
// Wait for a single replication message up to timeout time.
// Wait for a single replication message.
//
// Properly using this requires some knowledge of the postgres replication mechanisms,
// as the client can receive both WAL data (the ultimate payload) and server heartbeat
// updates. The caller also must send standby status updates in order to keep the connection
// alive and working.
//
// This returns pgx.ErrNotificationTimeout when there is no replication message by the specified
// duration.
func (rc *ReplicationConn) WaitForReplicationMessage(timeout time.Duration) (r *ReplicationMessage, err error) {
var zeroTime time.Time
deadline := time.Now().Add(timeout)
// Use SetReadDeadline to implement the timeout. SetReadDeadline will
// cause operations to fail with a *net.OpError that has a Timeout()
// of true. Because the normal pgx rxMsg path considers any error to
// have potentially corrupted the state of the connection, it dies
// on any errors. So to avoid timeout errors in rxMsg we set the
// deadline and peek into the reader. If a timeout error occurs there
// we don't break the pgx connection. If the Peek returns that data
// is available then we turn off the read deadline before the rxMsg.
err = rc.c.conn.SetReadDeadline(deadline)
if err != nil {
return nil, err
}
// Wait until there is a byte available before continuing onto the normal msg reading path
_, err = rc.c.mr.reader.Peek(1)
if err != nil {
rc.c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline
if err, ok := err.(*net.OpError); ok && err.Timeout() {
return nil, ErrNotificationTimeout
}
return nil, err
}
err = rc.c.conn.SetReadDeadline(zeroTime)
// This returns the context error when there is no replication message before
// the context is canceled.
func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (r *ReplicationMessage, err error) {
err = rc.c.initContext(ctx)
if err != nil {
return nil, err
}
defer func() {
err = rc.c.termContext(err)
}()
return rc.readReplicationMessage()
}
@ -401,12 +377,14 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti
return
}
ctx, _ := context.WithTimeout(context.Background(), initialReplicationResponseTimeout)
// The first replication message that comes back here will be (in a success case)
// a empty CopyBoth that is (apparently) sent as the confirmation that the replication has
// started. This call will either return nil, nil or if it returns an error
// that indicates the start replication command failed
var r *ReplicationMessage
r, err = rc.WaitForReplicationMessage(initialReplicationResponseTimeout)
r, err = rc.WaitForReplicationMessage(ctx)
if err != nil && r != nil {
if rc.c.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, "Unxpected replication message %v", r)

View File

@ -1,6 +1,7 @@
package pgx_test
import (
"context"
"fmt"
"github.com/jackc/pgx"
"reflect"
@ -88,11 +89,10 @@ func TestSimpleReplicationConnection(t *testing.T) {
for {
var message *pgx.ReplicationMessage
message, err = replicationConn.WaitForReplicationMessage(time.Duration(1 * time.Second))
if err != nil {
if err != pgx.ErrNotificationTimeout {
t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err))
}
ctx, _ := context.WithTimeout(context.Background(), time.Second)
message, err = replicationConn.WaitForReplicationMessage(ctx)
if err != nil && err != context.DeadlineExceeded {
t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err))
}
if message != nil {
if message.WalMessage != nil {

View File

@ -244,8 +244,9 @@ func listenAndPoolUnlistens(pool *pgx.ConnPool, actionNum int) error {
return err
}
_, err = conn.WaitForNotification(100 * time.Millisecond)
if err == pgx.ErrNotificationTimeout {
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
_, err = conn.WaitForNotification(ctx)
if err == context.DeadlineExceeded {
return nil
}
return err