diff --git a/connection.go b/connection.go index 70762635..7d32238e 100644 --- a/connection.go +++ b/connection.go @@ -12,6 +12,7 @@ import ( "fmt" "io" "net" + "time" ) // ConnectionParameters contains all the options used to establish a connection. @@ -36,6 +37,7 @@ type Connection struct { parameters ConnectionParameters // parameters used when establishing this connection TxStatus byte preparedStatements map[string]*preparedStatement + notifications []*Notification } type preparedStatement struct { @@ -44,6 +46,12 @@ type preparedStatement struct { ParameterOids []Oid } +type Notification struct { + Pid int32 // backend pid that sent the notification + Channel string // channel from which notification was received + Payload string +} + // NotSingleRowError is returned when exactly 1 row is expected, but 0 or more than // 1 row is returned type NotSingleRowError struct { @@ -336,6 +344,45 @@ func (c *Connection) Deallocate(name string) (err error) { return } +// Listen establishes a PostgreSQL listen/notify to channel +func (c *Connection) Listen(channel string) (err error) { + _, err = c.Execute("listen " + channel) + return +} + +// WaitForNotification waits for a PostgreSQL notification for up to timeout +func (c *Connection) WaitForNotification(timeout time.Duration) (notification *Notification, err error) { + err = c.conn.SetReadDeadline(time.Now().Add(timeout)) + if err != nil { + return + } + defer func() { + var zeroTime time.Time + e := c.conn.SetReadDeadline(zeroTime) + if err == nil && e != nil { + err = e + } + }() + + for { + if len(c.notifications) > 0 { + notification = c.notifications[0] + c.notifications = c.notifications[1:] + return + } + + var t byte + var r *MessageReader + if t, r, err = c.rxMsg(); err == nil { + if err = c.processContextFreeMsg(t, r); err != nil { + return + } + } else { + return + } + } +} + func (c *Connection) sendQuery(sql string, arguments ...interface{}) (err error) { if ps, present := c.preparedStatements[sql]; present { return c.sendPreparedQuery(ps, arguments...) @@ -525,6 +572,8 @@ func (c *Connection) processContextFreeMsg(t byte, r *MessageReader) (err error) return c.rxErrorResponse(r) case noticeResponse: return nil + case notificationResponse: + return c.rxNotificationResponse(r) default: return fmt.Errorf("Received unknown message type: %c", t) } @@ -661,6 +710,15 @@ func (c *Connection) rxCommandComplete(r *MessageReader) string { return r.ReadString() } +func (c *Connection) rxNotificationResponse(r *MessageReader) (err error) { + n := new(Notification) + n.Pid = r.ReadInt32() + n.Channel = r.ReadString() + n.Payload = r.ReadString() + c.notifications = append(c.notifications, n) + return +} + func (c *Connection) txStartupMessage(msg *startupMessage) (err error) { _, err = c.conn.Write(msg.Bytes()) return diff --git a/connection_test.go b/connection_test.go index 393d35f8..486abe5c 100644 --- a/connection_test.go +++ b/connection_test.go @@ -4,8 +4,10 @@ import ( "bytes" "fmt" "github.com/JackC/pgx" + "net" "strings" "testing" + "time" ) func TestConnect(t *testing.T) { @@ -565,3 +567,47 @@ func TestTransactionIso(t *testing.T) { } } } + +func TestListenNotify(t *testing.T) { + listener, err := pgx.Connect(*defaultConnectionParameters) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + defer listener.Close() + + if err := listener.Listen("chat"); err != nil { + t.Fatalf("Unable to start listening: %v", err) + } + + notifier := getSharedConnection() + mustExecute(t, notifier, "notify chat") + + // when notification is waiting on the socket to be read + notification, err := listener.WaitForNotification(time.Millisecond) + if err != nil { + t.Fatalf("Unexpected error on WaitForNotification: %v", err) + } + if notification.Channel != "chat" { + t.Errorf("Did not receive notification on expected channel: %v", notification.Channel) + } + + // when notification has already been read during previous query + mustExecute(t, notifier, "notify chat") + mustSelectValue(t, listener, "select 1") + notification, err = listener.WaitForNotification(0) + if err != nil { + t.Fatalf("Unexpected error on WaitForNotification: %v", err) + } + if notification.Channel != "chat" { + t.Errorf("Did not receive notification on expected channel: %v", notification.Channel) + } + + // when timeout occurs + notification, err = listener.WaitForNotification(time.Millisecond) + if _, ok := err.(*net.OpError); !ok { + t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) + } + if notification != nil { + t.Errorf("WaitForNotification returned an unexpected notification: %v", notification) + } +} diff --git a/messages.go b/messages.go index 8784dee7..a8231ddf 100644 --- a/messages.go +++ b/messages.go @@ -20,6 +20,7 @@ const ( parseComplete = '1' parameterDescription = 't' bindComplete = '2' + notificationResponse = 'A' ) type startupMessage struct {