From 482e56a79be9b32b5991175ab09080977fd9278d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Jun 2023 20:35:53 -0500 Subject: [PATCH] Fix race condition when CopyFrom is cancelled. --- pgconn/pgconn.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 5143244d..8b5e9b87 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1229,7 +1229,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf) if writeErr != nil { - // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. + // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. Not + // setting pgConn.status or closing pgConn.cleanupDone for the same reason. pgConn.conn.Close() copyErrChan <- writeErr @@ -1255,11 +1256,16 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co select { case copyErr = <-copyErrChan: case <-signalMessageChan: - msg, err := pgConn.receiveMessage() - if err != nil { - pgConn.asyncClose() + // If pgConn.receiveMessage encounters an error it will call pgConn.asyncClose. But that is a race condition with + // the goroutine. So instead check pgConn.bufferingReceiveErr which will have been set by the signalMessage. If an + // error is found then forcibly close the connection without sending the Terminate message. + if err := pgConn.bufferingReceiveErr; err != nil { + pgConn.status = connStatusClosed + pgConn.conn.Close() + close(pgConn.cleanupDone) return CommandTag{}, normalizeTimeoutError(ctx, err) } + msg, _ := pgConn.receiveMessage() switch msg := msg.(type) { case *pgproto3.ErrorResponse: