diff --git a/conn.go b/conn.go index 07422a32..a8b0b22c 100644 --- a/conn.go +++ b/conn.go @@ -1,7 +1,6 @@ package pgx import ( - "bufio" "crypto/md5" "crypto/tls" "encoding/binary" @@ -20,6 +19,8 @@ import ( "strings" "sync/atomic" "time" + + "github.com/jackc/pgx/chunkreader" ) const ( @@ -283,7 +284,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } - c.mr.reader = bufio.NewReader(c.conn) + c.mr.cr = chunkreader.NewChunkReader(c.conn) msg := newStartupMessage() @@ -844,9 +845,8 @@ func (c *Conn) Unlisten(channel string) error { return nil } -// WaitForNotification waits for a PostgreSQL notification for up to timeout. -// If the timeout occurs it returns pgx.ErrNotificationTimeout -func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) { +// WaitForNotification waits for a PostgreSQL notification. +func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notification, err error) { // Return already received notification immediately if len(c.notifications) > 0 { notification := c.notifications[0] @@ -854,97 +854,40 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) return notification, nil } - ctx, cancelFn := context.WithTimeout(context.Background(), timeout) - if err := c.waitForPreviousCancelQuery(ctx); err != nil { - cancelFn() + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { return nil, err } - cancelFn() + + err = c.initContext(ctx) + if err != nil { + return nil, err + } + defer func() { + err = c.termContext(err) + }() + + if err = c.lock(); err != nil { + return nil, err + } + defer func() { + if unlockErr := c.unlock(); unlockErr != nil && err == nil { + err = unlockErr + } + }() if err := c.ensureConnectionReadyForQuery(); err != nil { return nil, err } - stopTime := time.Now().Add(timeout) - for { - now := time.Now() - - if now.After(stopTime) { - return nil, ErrNotificationTimeout - } - - // If there has been no activity on this connection for a while send a nop message just to ensure - // the connection is alive - nextEnsureAliveTime := c.lastActivityTime.Add(15 * time.Second) - if nextEnsureAliveTime.Before(now) { - // If the server can't respond to a nop in 15 seconds, assume it's dead - err := c.conn.SetReadDeadline(now.Add(15 * time.Second)) - if err != nil { - return nil, err - } - - _, err = c.Exec("--;") - if err != nil { - return nil, err - } - - c.lastActivityTime = now - } - - var deadline time.Time - if stopTime.Before(nextEnsureAliveTime) { - deadline = stopTime - } else { - deadline = nextEnsureAliveTime - } - - notification, err := c.waitForNotification(deadline) - if err != ErrNotificationTimeout { - return notification, err - } - } -} - -func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) { - var zeroTime time.Time - - for { - // Use SetReadDeadline to implement the timeout. SetReadDeadline will - // cause operations to fail with a *net.OpError that has a Timeout() - // of true. Because the normal pgx rxMsg path considers any error to - // have potentially corrupted the state of the connection, it dies - // on any errors. So to avoid timeout errors in rxMsg we set the - // deadline and peek into the reader. If a timeout error occurs there - // we don't break the pgx connection. If the Peek returns that data - // is available then we turn off the read deadline before the rxMsg. - err := c.conn.SetReadDeadline(deadline) + t, r, err := c.rxMsg() if err != nil { return nil, err } - // Wait until there is a byte available before continuing onto the normal msg reading path - _, err = c.mr.reader.Peek(1) + err = c.processContextFreeMsg(t, r) if err != nil { - c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline - if err, ok := err.(*net.OpError); ok && err.Timeout() { - return nil, ErrNotificationTimeout - } - return nil, err - } - - err = c.conn.SetReadDeadline(zeroTime) - if err != nil { - return nil, err - } - - var t byte - var r *msgReader - if t, r, err = c.rxMsg(); err == nil { - if err = c.processContextFreeMsg(t, r); err != nil { - return nil, err - } - } else { return nil, err } @@ -1114,7 +1057,7 @@ func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { c.lastActivityTime = time.Now() if c.shouldLog(LogLevelTrace) { - c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBytesRemaining", c.mr.msgBytesRemaining) + c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBodyLen", len(c.mr.msgBody)) } return t, &c.mr, err @@ -1236,11 +1179,11 @@ func (c *Conn) rxParameterDescription(r *msgReader) (parameters []OID) { // wrong. So read the count, ignore it, and compute the proper value from // the size of the message. r.readInt16() - parameterCount := r.msgBytesRemaining / 4 + parameterCount := len(r.msgBody[r.rp:]) / 4 parameters = make([]OID, 0, parameterCount) - for i := int32(0); i < parameterCount; i++ { + for i := 0; i < parameterCount; i++ { parameters = append(parameters, r.readOID()) } return diff --git a/conn_test.go b/conn_test.go index a8398507..63b486a6 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1084,7 +1084,7 @@ func TestListenNotify(t *testing.T) { mustExec(t, notifier, "notify chat") // when notification is waiting on the socket to be read - notification, err := listener.WaitForNotification(time.Second) + notification, err := listener.WaitForNotification(context.Background()) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1099,7 +1099,10 @@ func TestListenNotify(t *testing.T) { if rows.Err() != nil { t.Fatalf("Unexpected error on Query: %v", rows.Err()) } - notification, err = listener.WaitForNotification(0) + + ctx, cancelFn := context.WithCancel(context.Background()) + cancelFn() + notification, err = listener.WaitForNotification(ctx) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1108,8 +1111,9 @@ func TestListenNotify(t *testing.T) { } // when timeout occurs - notification, err = listener.WaitForNotification(time.Millisecond) - if err != pgx.ErrNotificationTimeout { + ctx, _ = context.WithTimeout(context.Background(), time.Millisecond) + notification, err = listener.WaitForNotification(ctx) + if err != context.DeadlineExceeded { t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) } if notification != nil { @@ -1118,7 +1122,7 @@ func TestListenNotify(t *testing.T) { // listener can listen again after a timeout mustExec(t, notifier, "notify chat") - notification, err = listener.WaitForNotification(time.Second) + notification, err = listener.WaitForNotification(context.Background()) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1143,7 +1147,7 @@ func TestUnlistenSpecificChannel(t *testing.T) { mustExec(t, notifier, "notify unlisten_test") // when notification is waiting on the socket to be read - notification, err := listener.WaitForNotification(time.Second) + notification, err := listener.WaitForNotification(context.Background()) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1163,8 +1167,10 @@ func TestUnlistenSpecificChannel(t *testing.T) { if rows.Err() != nil { t.Fatalf("Unexpected error on Query: %v", rows.Err()) } - notification, err = listener.WaitForNotification(100 * time.Millisecond) - if err != pgx.ErrNotificationTimeout { + + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + notification, err = listener.WaitForNotification(ctx) + if err != context.DeadlineExceeded { t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) } } @@ -1246,7 +1252,8 @@ func TestListenNotifySelfNotification(t *testing.T) { // Notify self and WaitForNotification immediately mustExec(t, conn, "notify self") - notification, err := conn.WaitForNotification(time.Second) + ctx, _ := context.WithTimeout(context.Background(), time.Second) + notification, err := conn.WaitForNotification(ctx) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1263,7 +1270,8 @@ func TestListenNotifySelfNotification(t *testing.T) { t.Fatalf("Unexpected error on Query: %v", rows.Err()) } - notification, err = conn.WaitForNotification(time.Second) + ctx, _ = context.WithTimeout(context.Background(), time.Second) + notification, err = conn.WaitForNotification(ctx) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } diff --git a/msg_reader.go b/msg_reader.go index f507c198..53e944bb 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -1,26 +1,29 @@ package pgx import ( - "bufio" + "bytes" "encoding/binary" "errors" - "io" "net" + + "github.com/jackc/pgx/chunkreader" ) // msgReader is a helper that reads values from a PostgreSQL message. type msgReader struct { - reader *bufio.Reader - msgBytesRemaining int32 - err error - log func(lvl int, msg string, ctx ...interface{}) - shouldLog func(lvl int) bool + cr *chunkreader.ChunkReader + msgType byte + msgBody []byte + rp int // read position + err error + log func(lvl int, msg string, ctx ...interface{}) + shouldLog func(lvl int) bool } // fatal tells rc that a Fatal error has occurred func (r *msgReader) fatal(err error) { if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgType", r.msgType, "msgBody", r.msgBody, "rp", r.rp) } r.err = err } @@ -31,22 +34,7 @@ func (r *msgReader) rxMsg() (byte, error) { return 0, r.err } - if r.msgBytesRemaining > 0 { - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining) - } - - n, err := r.reader.Discard(int(r.msgBytesRemaining)) - r.msgBytesRemaining -= int32(n) - if err != nil { - if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { - r.fatal(err) - } - return 0, err - } - } - - b, err := r.reader.Peek(5) + header, err := r.cr.Next(5) if err != nil { if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { r.fatal(err) @@ -54,22 +42,20 @@ func (r *msgReader) rxMsg() (byte, error) { return 0, err } - msgType := b[0] - payloadSize := int32(binary.BigEndian.Uint32(b[1:])) - 4 + r.msgType = header[0] + bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 - // Try to preload bufio.Reader with entire message - b, err = r.reader.Peek(5 + int(payloadSize)) - if err != nil && err != bufio.ErrBufferFull { + r.msgBody, err = r.cr.Next(bodyLen) + if err != nil { if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { r.fatal(err) } return 0, err } - r.msgBytesRemaining = payloadSize - r.reader.Discard(5) + r.rp = 0 - return msgType, nil + return r.msgType, nil } func (r *msgReader) readByte() byte { @@ -77,20 +63,16 @@ func (r *msgReader) readByte() byte { return 0 } - r.msgBytesRemaining-- - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 1 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.ReadByte() - if err != nil { - r.fatal(err) - return 0 - } + b := r.msgBody[r.rp] + r.rp++ if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgType", r.msgType, "rp", r.rp) } return b @@ -101,24 +83,16 @@ func (r *msgReader) readInt16() int16 { return 0 } - r.msgBytesRemaining -= 2 - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 2 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.Peek(2) - if err != nil { - r.fatal(err) - return 0 - } - - n := int16(binary.BigEndian.Uint16(b)) - - r.reader.Discard(2) + n := int16(binary.BigEndian.Uint16(r.msgBody[r.rp:])) + r.rp += 2 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgType", r.msgType, "rp", r.rp) } return n @@ -129,24 +103,16 @@ func (r *msgReader) readInt32() int32 { return 0 } - r.msgBytesRemaining -= 4 - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 4 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.Peek(4) - if err != nil { - r.fatal(err) - return 0 - } - - n := int32(binary.BigEndian.Uint32(b)) - - r.reader.Discard(4) + n := int32(binary.BigEndian.Uint32(r.msgBody[r.rp:])) + r.rp += 4 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgType", r.msgType, "rp", r.rp) } return n @@ -157,24 +123,16 @@ func (r *msgReader) readUint16() uint16 { return 0 } - r.msgBytesRemaining -= 2 - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 2 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.Peek(2) - if err != nil { - r.fatal(err) - return 0 - } - - n := uint16(binary.BigEndian.Uint16(b)) - - r.reader.Discard(2) + n := binary.BigEndian.Uint16(r.msgBody[r.rp:]) + r.rp += 2 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgType", r.msgType, "rp", r.rp) } return n @@ -185,24 +143,16 @@ func (r *msgReader) readUint32() uint32 { return 0 } - r.msgBytesRemaining -= 4 - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 4 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.Peek(4) - if err != nil { - r.fatal(err) - return 0 - } - - n := uint32(binary.BigEndian.Uint32(b)) - - r.reader.Discard(4) + n := binary.BigEndian.Uint32(r.msgBody[r.rp:]) + r.rp += 4 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgType", r.msgType, "rp", r.rp) } return n @@ -213,24 +163,16 @@ func (r *msgReader) readInt64() int64 { return 0 } - r.msgBytesRemaining -= 8 - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 8 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.Peek(8) - if err != nil { - r.fatal(err) - return 0 - } - - n := int64(binary.BigEndian.Uint64(b)) - - r.reader.Discard(8) + n := int64(binary.BigEndian.Uint64(r.msgBody[r.rp:])) + r.rp += 8 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgType", r.msgType, "rp", r.rp) } return n @@ -246,22 +188,17 @@ func (r *msgReader) readCString() string { return "" } - b, err := r.reader.ReadBytes(0) - if err != nil { - r.fatal(err) + nullIdx := bytes.IndexByte(r.msgBody[r.rp:], 0) + if nullIdx == -1 { + r.fatal(errors.New("null terminated string not found")) return "" } - r.msgBytesRemaining -= int32(len(b)) - if r.msgBytesRemaining < 0 { - r.fatal(errors.New("read past end of message")) - return "" - } - - s := string(b[0 : len(b)-1]) + s := string(r.msgBody[r.rp : r.rp+nullIdx]) + r.rp += nullIdx + 1 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgType", r.msgType, "rp", r.rp) } return s @@ -273,58 +210,43 @@ func (r *msgReader) readString(countI32 int32) string { return "" } - r.msgBytesRemaining -= countI32 - if r.msgBytesRemaining < 0 { + count := int(countI32) + + if len(r.msgBody)-r.rp < count { r.fatal(errors.New("read past end of message")) return "" } - count := int(countI32) - var s string - - if r.reader.Buffered() >= count { - buf, _ := r.reader.Peek(count) - s = string(buf) - r.reader.Discard(count) - } else { - buf := make([]byte, count) - _, err := io.ReadFull(r.reader, buf) - if err != nil { - r.fatal(err) - return "" - } - s = string(buf) - } + s := string(r.msgBody[r.rp : r.rp+count]) + r.rp += count if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgType", r.msgType, "rp", r.rp) } return s } // readBytes reads count bytes and returns as []byte -func (r *msgReader) readBytes(count int32) []byte { +func (r *msgReader) readBytes(countI32 int32) []byte { if r.err != nil { return nil } - r.msgBytesRemaining -= count - if r.msgBytesRemaining < 0 { + count := int(countI32) + + if len(r.msgBody)-r.rp < count { r.fatal(errors.New("read past end of message")) return nil } - b := make([]byte, int(count)) + b := r.msgBody[r.rp : r.rp+count] + r.rp += count - _, err := io.ReadFull(r.reader, b) - if err != nil { - r.fatal(err) - return nil - } + r.cr.KeepLast() if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readBytes", "value", b, r.msgType, "rp", r.rp) } return b diff --git a/msg_reader_test.go b/msg_reader_test.go deleted file mode 100644 index 2bbd53c9..00000000 --- a/msg_reader_test.go +++ /dev/null @@ -1,189 +0,0 @@ -package pgx - -import ( - "bufio" - "net" - "testing" - "time" - - "github.com/jackc/pgmock/pgmsg" -) - -func TestMsgReaderPrebuffersWhenPossible(t *testing.T) { - t.Parallel() - - tests := []struct { - msgType byte - payloadSize int32 - buffered bool - }{ - {1, 50, true}, - {2, 0, true}, - {3, 500, true}, - {4, 1050, true}, - {5, 1500, true}, - {6, 1500, true}, - {7, 4000, true}, - {8, 24000, false}, - {9, 4000, true}, - {1, 1500, true}, - {2, 0, true}, - {3, 500, true}, - {4, 1050, true}, - {5, 1500, true}, - {6, 1500, true}, - {7, 4000, true}, - {8, 14000, false}, - {9, 0, true}, - {1, 500, true}, - } - - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer ln.Close() - - go func() { - var bigEndian pgmsg.BigEndianBuf - - conn, err := ln.Accept() - if err != nil { - t.Fatal(err) - } - defer conn.Close() - - for _, tt := range tests { - _, err = conn.Write([]byte{tt.msgType}) - if err != nil { - t.Fatal(err) - } - - _, err = conn.Write(bigEndian.Int32(tt.payloadSize + 4)) - if err != nil { - t.Fatal(err) - } - - payload := make([]byte, int(tt.payloadSize)) - _, err = conn.Write(payload) - if err != nil { - t.Fatal(err) - } - } - }() - - conn, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer conn.Close() - - mr := &msgReader{ - reader: bufio.NewReader(conn), - shouldLog: func(int) bool { return false }, - } - - for i, tt := range tests { - msgType, err := mr.rxMsg() - if err != nil { - t.Fatalf("%d. Unexpected error: %v", i, err) - } - - if msgType != tt.msgType { - t.Fatalf("%d. Expected %v, got %v", 1, i, tt.msgType, msgType) - } - - if mr.reader.Buffered() < int(tt.payloadSize) && tt.buffered { - t.Fatalf("%d. Expected message to be buffered with at least %d bytes, but only %v bytes buffered", i, tt.payloadSize, mr.reader.Buffered()) - } - } -} - -func TestMsgReaderDeadlineNeverInterruptsNormalSizedMessages(t *testing.T) { - t.Parallel() - - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer ln.Close() - - testCount := 10000 - - go func() { - var bigEndian pgmsg.BigEndianBuf - - conn, err := ln.Accept() - if err != nil { - t.Fatal(err) - } - defer conn.Close() - - for i := 0; i < testCount; i++ { - msgType := byte(i) - - _, err = conn.Write([]byte{msgType}) - if err != nil { - t.Fatal(err) - } - - msgSize := i % 4000 - - _, err = conn.Write(bigEndian.Int32(int32(msgSize + 4))) - if err != nil { - t.Fatal(err) - } - - payload := make([]byte, msgSize) - _, err = conn.Write(payload) - if err != nil { - t.Fatal(err) - } - } - }() - - conn, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer conn.Close() - - mr := &msgReader{ - reader: bufio.NewReader(conn), - shouldLog: func(int) bool { return false }, - } - - conn.SetReadDeadline(time.Now().Add(time.Millisecond)) - - i := 0 - for { - msgType, err := mr.rxMsg() - if err != nil { - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - conn.SetReadDeadline(time.Now().Add(time.Millisecond)) - continue - } else { - t.Fatalf("%d. Unexpected error: %v", i, err) - } - } - - expectedMsgType := byte(i) - if msgType != expectedMsgType { - t.Fatalf("%d. Expected %v, got %v", i, expectedMsgType, msgType) - } - - expectedMsgSize := i % 4000 - payload := mr.readBytes(mr.msgBytesRemaining) - if mr.err != nil { - t.Fatalf("%d. readBytes killed msgReader: %v", i, mr.err) - } - if len(payload) != expectedMsgSize { - t.Fatalf("%d. Expected %v, got %v", i, expectedMsgSize, len(payload)) - } - - i++ - if i == testCount { - break - } - } -} diff --git a/replication.go b/replication.go index 0acc9df9..a3e58fa3 100644 --- a/replication.go +++ b/replication.go @@ -1,9 +1,9 @@ package pgx import ( + "context" "errors" "fmt" - "net" "time" ) @@ -234,7 +234,7 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err walStart := reader.readInt64() serverWalEnd := reader.readInt64() serverTime := reader.readInt64() - walData := reader.readBytes(reader.msgBytesRemaining) + walData := reader.readBytes(int32(len(reader.msgBody) - reader.rp)) walMessage := WalMessage{WalStart: uint64(walStart), ServerWalEnd: uint64(serverWalEnd), ServerTime: uint64(serverTime), @@ -261,47 +261,23 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err return } -// Wait for a single replication message up to timeout time. +// Wait for a single replication message. // // Properly using this requires some knowledge of the postgres replication mechanisms, // as the client can receive both WAL data (the ultimate payload) and server heartbeat // updates. The caller also must send standby status updates in order to keep the connection // alive and working. // -// This returns pgx.ErrNotificationTimeout when there is no replication message by the specified -// duration. -func (rc *ReplicationConn) WaitForReplicationMessage(timeout time.Duration) (r *ReplicationMessage, err error) { - var zeroTime time.Time - - deadline := time.Now().Add(timeout) - - // Use SetReadDeadline to implement the timeout. SetReadDeadline will - // cause operations to fail with a *net.OpError that has a Timeout() - // of true. Because the normal pgx rxMsg path considers any error to - // have potentially corrupted the state of the connection, it dies - // on any errors. So to avoid timeout errors in rxMsg we set the - // deadline and peek into the reader. If a timeout error occurs there - // we don't break the pgx connection. If the Peek returns that data - // is available then we turn off the read deadline before the rxMsg. - err = rc.c.conn.SetReadDeadline(deadline) - if err != nil { - return nil, err - } - - // Wait until there is a byte available before continuing onto the normal msg reading path - _, err = rc.c.mr.reader.Peek(1) - if err != nil { - rc.c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline - if err, ok := err.(*net.OpError); ok && err.Timeout() { - return nil, ErrNotificationTimeout - } - return nil, err - } - - err = rc.c.conn.SetReadDeadline(zeroTime) +// This returns the context error when there is no replication message before +// the context is canceled. +func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (r *ReplicationMessage, err error) { + err = rc.c.initContext(ctx) if err != nil { return nil, err } + defer func() { + err = rc.c.termContext(err) + }() return rc.readReplicationMessage() } @@ -401,12 +377,14 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti return } + ctx, _ := context.WithTimeout(context.Background(), initialReplicationResponseTimeout) + // The first replication message that comes back here will be (in a success case) // a empty CopyBoth that is (apparently) sent as the confirmation that the replication has // started. This call will either return nil, nil or if it returns an error // that indicates the start replication command failed var r *ReplicationMessage - r, err = rc.WaitForReplicationMessage(initialReplicationResponseTimeout) + r, err = rc.WaitForReplicationMessage(ctx) if err != nil && r != nil { if rc.c.shouldLog(LogLevelError) { rc.c.log(LogLevelError, "Unxpected replication message %v", r) diff --git a/replication_test.go b/replication_test.go index 4f810c78..2c2d0af5 100644 --- a/replication_test.go +++ b/replication_test.go @@ -1,6 +1,7 @@ package pgx_test import ( + "context" "fmt" "github.com/jackc/pgx" "reflect" @@ -88,11 +89,10 @@ func TestSimpleReplicationConnection(t *testing.T) { for { var message *pgx.ReplicationMessage - message, err = replicationConn.WaitForReplicationMessage(time.Duration(1 * time.Second)) - if err != nil { - if err != pgx.ErrNotificationTimeout { - t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err)) - } + ctx, _ := context.WithTimeout(context.Background(), time.Second) + message, err = replicationConn.WaitForReplicationMessage(ctx) + if err != nil && err != context.DeadlineExceeded { + t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err)) } if message != nil { if message.WalMessage != nil { diff --git a/stress_test.go b/stress_test.go index 72d48a5c..82979fd6 100644 --- a/stress_test.go +++ b/stress_test.go @@ -244,8 +244,9 @@ func listenAndPoolUnlistens(pool *pgx.ConnPool, actionNum int) error { return err } - _, err = conn.WaitForNotification(100 * time.Millisecond) - if err == pgx.ErrNotificationTimeout { + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + _, err = conn.WaitForNotification(ctx) + if err == context.DeadlineExceeded { return nil } return err