mirror of https://github.com/jackc/pgx.git
parent
9e321af35c
commit
b271dd5bf1
|
@ -12,6 +12,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConnectionParameters contains all the options used to establish a connection.
|
||||
|
@ -36,6 +37,7 @@ type Connection struct {
|
|||
parameters ConnectionParameters // parameters used when establishing this connection
|
||||
TxStatus byte
|
||||
preparedStatements map[string]*preparedStatement
|
||||
notifications []*Notification
|
||||
}
|
||||
|
||||
type preparedStatement struct {
|
||||
|
@ -44,6 +46,12 @@ type preparedStatement struct {
|
|||
ParameterOids []Oid
|
||||
}
|
||||
|
||||
type Notification struct {
|
||||
Pid int32 // backend pid that sent the notification
|
||||
Channel string // channel from which notification was received
|
||||
Payload string
|
||||
}
|
||||
|
||||
// NotSingleRowError is returned when exactly 1 row is expected, but 0 or more than
|
||||
// 1 row is returned
|
||||
type NotSingleRowError struct {
|
||||
|
@ -336,6 +344,45 @@ func (c *Connection) Deallocate(name string) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
// Listen establishes a PostgreSQL listen/notify to channel
|
||||
func (c *Connection) Listen(channel string) (err error) {
|
||||
_, err = c.Execute("listen " + channel)
|
||||
return
|
||||
}
|
||||
|
||||
// WaitForNotification waits for a PostgreSQL notification for up to timeout
|
||||
func (c *Connection) WaitForNotification(timeout time.Duration) (notification *Notification, err error) {
|
||||
err = c.conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
var zeroTime time.Time
|
||||
e := c.conn.SetReadDeadline(zeroTime)
|
||||
if err == nil && e != nil {
|
||||
err = e
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
if len(c.notifications) > 0 {
|
||||
notification = c.notifications[0]
|
||||
c.notifications = c.notifications[1:]
|
||||
return
|
||||
}
|
||||
|
||||
var t byte
|
||||
var r *MessageReader
|
||||
if t, r, err = c.rxMsg(); err == nil {
|
||||
if err = c.processContextFreeMsg(t, r); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) sendQuery(sql string, arguments ...interface{}) (err error) {
|
||||
if ps, present := c.preparedStatements[sql]; present {
|
||||
return c.sendPreparedQuery(ps, arguments...)
|
||||
|
@ -525,6 +572,8 @@ func (c *Connection) processContextFreeMsg(t byte, r *MessageReader) (err error)
|
|||
return c.rxErrorResponse(r)
|
||||
case noticeResponse:
|
||||
return nil
|
||||
case notificationResponse:
|
||||
return c.rxNotificationResponse(r)
|
||||
default:
|
||||
return fmt.Errorf("Received unknown message type: %c", t)
|
||||
}
|
||||
|
@ -661,6 +710,15 @@ func (c *Connection) rxCommandComplete(r *MessageReader) string {
|
|||
return r.ReadString()
|
||||
}
|
||||
|
||||
func (c *Connection) rxNotificationResponse(r *MessageReader) (err error) {
|
||||
n := new(Notification)
|
||||
n.Pid = r.ReadInt32()
|
||||
n.Channel = r.ReadString()
|
||||
n.Payload = r.ReadString()
|
||||
c.notifications = append(c.notifications, n)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) txStartupMessage(msg *startupMessage) (err error) {
|
||||
_, err = c.conn.Write(msg.Bytes())
|
||||
return
|
||||
|
|
|
@ -4,8 +4,10 @@ import (
|
|||
"bytes"
|
||||
"fmt"
|
||||
"github.com/JackC/pgx"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestConnect(t *testing.T) {
|
||||
|
@ -565,3 +567,47 @@ func TestTransactionIso(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenNotify(t *testing.T) {
|
||||
listener, err := pgx.Connect(*defaultConnectionParameters)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to establish connection: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
if err := listener.Listen("chat"); err != nil {
|
||||
t.Fatalf("Unable to start listening: %v", err)
|
||||
}
|
||||
|
||||
notifier := getSharedConnection()
|
||||
mustExecute(t, notifier, "notify chat")
|
||||
|
||||
// when notification is waiting on the socket to be read
|
||||
notification, err := listener.WaitForNotification(time.Millisecond)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
if notification.Channel != "chat" {
|
||||
t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
|
||||
}
|
||||
|
||||
// when notification has already been read during previous query
|
||||
mustExecute(t, notifier, "notify chat")
|
||||
mustSelectValue(t, listener, "select 1")
|
||||
notification, err = listener.WaitForNotification(0)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||
}
|
||||
if notification.Channel != "chat" {
|
||||
t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
|
||||
}
|
||||
|
||||
// when timeout occurs
|
||||
notification, err = listener.WaitForNotification(time.Millisecond)
|
||||
if _, ok := err.(*net.OpError); !ok {
|
||||
t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
|
||||
}
|
||||
if notification != nil {
|
||||
t.Errorf("WaitForNotification returned an unexpected notification: %v", notification)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ const (
|
|||
parseComplete = '1'
|
||||
parameterDescription = 't'
|
||||
bindComplete = '2'
|
||||
notificationResponse = 'A'
|
||||
)
|
||||
|
||||
type startupMessage struct {
|
||||
|
|
Loading…
Reference in New Issue