From 9ce1b2b16e67ac81291980b5ed29d22c8980bc89 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 24 Mar 2016 15:26:44 -0500 Subject: [PATCH] Fix Listen/Unlisten with special characters fixes #132 --- CHANGELOG.md | 1 + conn.go | 8 ++++++-- conn_test.go | 16 ++++++++++++++++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3eaf5d87..4d168cc6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ * Fix compilation on 32-bit architecture * Fix Tx.status not being set on error on Commit +* Fix Listen/Unlisten with special characters # 2.8.0 (March 18, 2016) diff --git a/conn.go b/conn.go index 12290507..c652b4fe 100644 --- a/conn.go +++ b/conn.go @@ -657,7 +657,7 @@ func (c *Conn) Deallocate(name string) (err error) { // Listen establishes a PostgreSQL listen/notify to channel func (c *Conn) Listen(channel string) error { - _, err := c.Exec("listen " + channel) + _, err := c.Exec("listen " + quoteIdentifier(channel)) if err != nil { return err } @@ -669,7 +669,7 @@ func (c *Conn) Listen(channel string) error { // Unlisten unsubscribes from a listen channel func (c *Conn) Unlisten(channel string) error { - _, err := c.Exec("unlisten " + channel) + _, err := c.Exec("unlisten " + quoteIdentifier(channel)) if err != nil { return err } @@ -1205,3 +1205,7 @@ func (c *Conn) SetLogLevel(lvl int) (int, error) { c.logLevel = lvl return lvl, nil } + +func quoteIdentifier(s string) string { + return `"` + strings.Replace(s, `"`, `""`, -1) + `"` +} diff --git a/conn_test.go b/conn_test.go index eac9f840..bcab2b7d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1185,6 +1185,22 @@ func TestListenNotifySelfNotification(t *testing.T) { } } +func TestListenUnlistenSpecialCharacters(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + chanName := "special characters !@#{$%^&*()}" + if err := conn.Listen(chanName); err != nil { + t.Fatalf("Unable to start listening: %v", err) + } + + if err := conn.Unlisten(chanName); err != nil { + t.Fatalf("Unable to stop listening: %v", err) + } +} + func TestFatalRxError(t *testing.T) { t.Parallel()