diff --git a/conn.go b/conn.go index 462fc693..dbf6ac68 100644 --- a/conn.go +++ b/conn.go @@ -49,6 +49,8 @@ type Conn struct { causeOfDeath error + notifications []*pgconn.Notification + doneChan chan struct{} closedChan chan error @@ -144,13 +146,21 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { panic("config must be created by ParseConfig") } - c = new(Conn) + c = &Conn{ + config: config, + ConnInfo: pgtype.NewConnInfo(), + logLevel: config.LogLevel, + logger: config.Logger, + } - c.config = config - c.ConnInfo = pgtype.NewConnInfo() - - c.logLevel = c.config.LogLevel - c.logger = c.config.Logger + // Only install pgx notification system if no other callback handler is present. + if config.Config.OnNotification == nil { + config.Config.OnNotification = c.bufferNotifications + } else { + if c.shouldLog(LogLevelDebug) { + c.log(ctx, LogLevelDebug, "pgx notification handler disabled by application supplied OnNotification", map[string]interface{}{"host": config.Config.Host}) + } + } if c.shouldLog(LogLevelInfo) { c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host}) @@ -247,6 +257,30 @@ func (c *Conn) Deallocate(ctx context.Context, name string) error { return err } +func (c *Conn) bufferNotifications(_ *pgconn.PgConn, n *pgconn.Notification) { + c.notifications = append(c.notifications, n) +} + +// WaitForNotification waits for a PostgreSQL notification. It wraps the underlying pgconn notification system in a +// slightly more convenient form. +func (c *Conn) WaitForNotification(ctx context.Context) (*pgconn.Notification, error) { + var n *pgconn.Notification + + // Return already received notification immediately + if len(c.notifications) > 0 { + n = c.notifications[0] + c.notifications = c.notifications[1:] + return n, nil + } + + err := c.pgConn.WaitForNotification(ctx) + if len(c.notifications) > 0 { + n = c.notifications[0] + c.notifications = c.notifications[1:] + } + return n, err +} + func (c *Conn) IsAlive() bool { return c.pgConn.IsAlive() } diff --git a/conn_test.go b/conn_test.go index 2437d91c..0d25cb8d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -11,6 +11,7 @@ import ( "github.com/jackc/pgconn" "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" errors "golang.org/x/xerrors" ) @@ -410,6 +411,140 @@ func TestPrepareIdempotency(t *testing.T) { } } +func TestListenNotify(t *testing.T) { + t.Parallel() + + listener := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, listener) + + mustExec(t, listener, "listen chat") + + notifier := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, notifier) + + mustExec(t, notifier, "notify chat") + + // when notification is waiting on the socket to be read + notification, err := listener.WaitForNotification(context.Background()) + require.NoError(t, err) + assert.Equal(t, "chat", notification.Channel) + + // when notification has already been read during previous query + mustExec(t, notifier, "notify chat") + rows, _ := listener.Query(context.Background(), "select 1") + rows.Close() + require.NoError(t, rows.Err()) + + ctx, cancelFn := context.WithCancel(context.Background()) + cancelFn() + notification, err = listener.WaitForNotification(ctx) + require.NoError(t, err) + assert.Equal(t, "chat", notification.Channel) + + // when timeout occurs + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + notification, err = listener.WaitForNotification(ctx) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) + + // listener can listen again after a timeout + mustExec(t, notifier, "notify chat") + notification, err = listener.WaitForNotification(context.Background()) + require.NoError(t, err) + assert.Equal(t, "chat", notification.Channel) +} + +func TestListenNotifyWhileBusyIsSafe(t *testing.T) { + t.Parallel() + + listenerDone := make(chan bool) + go func() { + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + defer func() { + listenerDone <- true + }() + + mustExec(t, conn, "listen busysafe") + + for i := 0; i < 5000; i++ { + var sum int32 + var rowCount int32 + + rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 100) + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + for rows.Next() { + var n int32 + rows.Scan(&n) + sum += n + rowCount++ + } + + if rows.Err() != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + if sum != 5050 { + t.Fatalf("Wrong rows sum: %v", sum) + } + + if rowCount != 100 { + t.Fatalf("Wrong number of rows: %v", rowCount) + } + + time.Sleep(1 * time.Microsecond) + } + }() + + go func() { + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + for i := 0; i < 100000; i++ { + mustExec(t, conn, "notify busysafe, 'hello'") + time.Sleep(1 * time.Microsecond) + } + }() + + <-listenerDone +} + +func TestListenNotifySelfNotification(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, "listen self") + + // Notify self and WaitForNotification immediately + mustExec(t, conn, "notify self") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + notification, err := conn.WaitForNotification(ctx) + require.NoError(t, err) + assert.Equal(t, "self", notification.Channel) + + // Notify self and do something else before WaitForNotification + mustExec(t, conn, "notify self") + + rows, _ := conn.Query(context.Background(), "select 1") + rows.Close() + if rows.Err() != nil { + t.Fatalf("Unexpected error on Query: %v", rows.Err()) + } + + ctx, cncl := context.WithTimeout(context.Background(), time.Second) + defer cncl() + notification, err = conn.WaitForNotification(ctx) + require.NoError(t, err) + assert.Equal(t, "self", notification.Channel) +} + func TestFatalRxError(t *testing.T) { t.Parallel() diff --git a/examples/chat/main.go b/examples/chat/main.go index 54ffdd1a..6be4ee1c 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -55,17 +55,19 @@ func listen() { } defer conn.Release() - // TODO - determine how listen should be handled in pgx vs. pgconn + _, err = conn.Exec(context.Background(), "listen chat") + if err != nil { + fmt.Fprintln(os.Stderr, "Error listening to chat channel:", err) + os.Exit(1) + } - conn.Exec(context.Background(), "listen chat") + for { + notification, err := conn.Conn().WaitForNotification(context.Background()) + if err != nil { + fmt.Fprintln(os.Stderr, "Error waiting for notification:", err) + os.Exit(1) + } - // for { - // notification, err := conn.WaitForNotification(context.Background()) - // if err != nil { - // fmt.Fprintln(os.Stderr, "Error waiting for notification:", err) - // os.Exit(1) - // } - - // fmt.Println("PID:", notification.PID, "Channel:", notification.Channel, "Payload:", notification.Payload) - // } + fmt.Println("PID:", notification.PID, "Channel:", notification.Channel, "Payload:", notification.Payload) + } }