From a0bfe4eab87534e1f1c8351c470a5de93f6bc5fd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 25 Apr 2014 10:03:11 -0600 Subject: [PATCH] Fix WaitForNotification when it times out --- connection.go | 64 +++++++++++++++++++++++++++++++++------------- connection_test.go | 13 ++++++++-- 2 files changed, 57 insertions(+), 20 deletions(-) diff --git a/connection.go b/connection.go index 65007d26..baaa92a2 100644 --- a/connection.go +++ b/connection.go @@ -91,6 +91,8 @@ func (e ProtocolError) Error() string { return string(e) } +var NotificationTimeoutError = errors.New("Notification Timeout") + // Connect establishes a connection with a PostgreSQL server using parameters. One // of parameters.Socket or parameters.Host must be specified. parameters.User // will default to the OS user name. Other parameters fields are optional. @@ -542,35 +544,61 @@ func (c *Connection) Listen(channel string) (err error) { return } -// WaitForNotification waits for a PostgreSQL notification for up to timeout -func (c *Connection) WaitForNotification(timeout time.Duration) (notification *Notification, err error) { - err = c.conn.SetReadDeadline(time.Now().Add(timeout)) - if err != nil { - return +// WaitForNotification waits for a PostgreSQL notification for up to timeout. +// If the timeout occurs it returns pgx.NotificationTimeoutError +func (c *Connection) WaitForNotification(timeout time.Duration) (*Notification, error) { + if len(c.notifications) > 0 { + notification := c.notifications[0] + c.notifications = c.notifications[1:] + return notification, nil } - defer func() { - var zeroTime time.Time - e := c.conn.SetReadDeadline(zeroTime) - if err == nil && e != nil { - err = e - } - }() + + var zeroTime time.Time + stopTime := time.Now().Add(timeout) for { - if len(c.notifications) > 0 { - notification = c.notifications[0] - c.notifications = c.notifications[1:] - return + // Use SetReadDeadline to implement the timeout. SetReadDeadline will + // cause operations to fail with a *net.OpError that has a Timeout() + // of true. Because the normal pgx rxMsg path considers any error to + // have potentially corrupted the state of the connection, it dies + // on any errors. So to avoid timeout errors in rxMsg we set the + // deadline and peek into the reader. If a timeout error occurs there + // we don't break the pgx connection. If the Peek returns that data + // is available then we turn off the read deadline before the rxMsg. + err := c.conn.SetReadDeadline(stopTime) + if err != nil { + return nil, err + } + + // Wait until there is a byte available before continuing onto the normal msg reading path + _, err = c.reader.Peek(1) + if err != nil { + c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline + if err, ok := err.(*net.OpError); ok && err.Timeout() { + return nil, NotificationTimeoutError + } + return nil, err + } + + err = c.conn.SetReadDeadline(zeroTime) + if err != nil { + return nil, err } var t byte var r *MessageReader if t, r, err = c.rxMsg(); err == nil { if err = c.processContextFreeMsg(t, r); err != nil { - return + return nil, err } } else { - return + return nil, err + } + + if len(c.notifications) > 0 { + notification := c.notifications[0] + c.notifications = c.notifications[1:] + return notification, nil } } } diff --git a/connection_test.go b/connection_test.go index 2bc41459..c8cd22b8 100644 --- a/connection_test.go +++ b/connection_test.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "github.com/JackC/pgx" - "net" "strings" "sync" "testing" @@ -702,12 +701,22 @@ func TestListenNotify(t *testing.T) { // when timeout occurs notification, err = listener.WaitForNotification(time.Millisecond) - if _, ok := err.(*net.OpError); !ok { + if err != pgx.NotificationTimeoutError { t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) } if notification != nil { t.Errorf("WaitForNotification returned an unexpected notification: %v", notification) } + + // listener can listen again after a timeout + mustExecute(t, notifier, "notify chat") + notification, err = listener.WaitForNotification(time.Second) + if err != nil { + t.Fatalf("Unexpected error on WaitForNotification: %v", err) + } + if notification.Channel != "chat" { + t.Errorf("Did not receive notification on expected channel: %v", notification.Channel) + } } func TestFatalRxError(t *testing.T) {