From 2061295f7ff3193dac706f49c5896bacb38db002 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 14 Jan 2019 20:51:53 -0600 Subject: [PATCH] Add PgConn.WaitForNotification --- pgconn/pgconn.go | 25 ++++++++++++++++++++++ pgconn/pgconn_test.go | 50 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index b2ffe7ca..efd7686f 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -565,6 +565,31 @@ func (pgConn *PgConn) WaitUntilReady(ctx context.Context) error { return nil } +// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not +// received. +func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case pgConn.controller <- pgConn: + } + cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanupContextDeadline() + defer func() { <-pgConn.controller }() + + for { + msg, err := pgConn.ReceiveMessage() + if err != nil { + return preferContextOverNetTimeoutError(ctx, err) + } + + switch msg.(type) { + case *pgproto3.NotificationResponse: + return nil + } + } +} + // Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is // implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control // statements. diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index ad538257..07e54c75 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -629,6 +629,56 @@ func TestConnOnNotification(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnWaitForNotification(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.Nil(t, err) + + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.Nil(t, err) + + err = pgConn.WaitForNotification(context.Background()) + require.Nil(t, err) + + assert.Equal(t, "bar", msg) + + ensureConnValid(t, pgConn) +} + +func TestConnWaitForNotificationTimeout(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + err = pgConn.WaitForNotification(ctx) + cancel() + require.Equal(t, context.DeadlineExceeded, err) + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil {