From 5d2be99c254e76f7dfb8b481db1791dd613b5d4c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 7 Apr 2020 19:38:21 -0500 Subject: [PATCH] Fix panic when closing conn during cancellable query fixes #29 --- internal/ctxwatch/context_watcher_test.go | 11 +++++++++++ pgconn.go | 7 +++++++ pgconn_test.go | 13 +++++++++++++ 3 files changed, 31 insertions(+) diff --git a/internal/ctxwatch/context_watcher_test.go b/internal/ctxwatch/context_watcher_test.go index 0b491bf8..6348b729 100644 --- a/internal/ctxwatch/context_watcher_test.go +++ b/internal/ctxwatch/context_watcher_test.go @@ -59,6 +59,17 @@ func TestContextWatcherMultipleWatchPanics(t *testing.T) { require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times") } +func TestContextWatcherUnwatchIsAlwaysSafe(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + cw.Unwatch() // unwatch when not / never watching + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cw.Watch(ctx) + cw.Unwatch() + cw.Unwatch() // double unwatch +} + func TestContextWatcherStress(t *testing.T) { var cancelFuncCalls int64 var cleanupFuncCalls int64 diff --git a/pgconn.go b/pgconn.go index 6155281d..d5a424ac 100644 --- a/pgconn.go +++ b/pgconn.go @@ -494,6 +494,13 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() if ctx != context.Background() { + // Close may be called while a cancellable query is in progress. This will most often be triggered by panic when + // a defer closes the connection (possibly indirectly via a transaction or a connection pool). Unwatch to end any + // previous watch. It is safe to Unwatch regardless of whether a watch is already is progress. + // + // See https://github.com/jackc/pgconn/issues/29 + pgConn.contextWatcher.Unwatch() + pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() } diff --git a/pgconn_test.go b/pgconn_test.go index 17b40343..e29a36b2 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1708,6 +1708,19 @@ func TestHijackAndConstruct(t *testing.T) { ensureConnValid(t, newConn) } +func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + ctx, _ := context.WithCancel(context.Background()) + pgConn.Exec(ctx, "select n from generate_series(1,10) n") + + closeCtx, _ := context.WithCancel(context.Background()) + pgConn.Close(closeCtx) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil {