From e83d1d2228e858edf309a703de9c612eed9de298 Mon Sep 17 00:00:00 2001
From: Jack Christensen
Date: Sat, 26 Jan 2019 12:20:36 -0600
Subject: [PATCH] Recover from context cancellation during CopyFrom

---
 pgconn/pgconn.go      | 131 ++++++++++++++++++++++++++++++++++++++----
 pgconn/pgconn_test.go |  36 ++++++++++++
 2 files changed, 155 insertions(+), 12 deletions(-)

diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go
index d8ec6b07..e34853a0 100644
--- a/pgconn/pgconn.go
+++ b/pgconn/pgconn.go
@@ -12,6 +12,7 @@ import (
 	"net"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/jackc/pgx/pgio"
@@ -91,6 +92,11 @@ type PgConn struct {
 
 	controller chan interface{}
 	closed     bool
+
+	bufferingReceive    bool
+	bufferingReceiveMux sync.Mutex
+	bufferingReceiveMsg pgproto3.BackendMessage
+	bufferingReceiveErr error
 }
 
 // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
@@ -273,8 +279,42 @@ func hexMD5(s string) string {
 	return hex.EncodeToString(hash.Sum(nil))
 }
 
+func (pgConn *PgConn) signalMessage() chan struct{} {
+	if pgConn.bufferingReceive {
+		panic("BUG: signalMessage when already in progress")
+	}
+
+	pgConn.bufferingReceive = true
+	pgConn.bufferingReceiveMux.Lock()
+
+	ch := make(chan struct{})
+	go func() {
+		pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.Frontend.Receive()
+		pgConn.bufferingReceiveMux.Unlock()
+		close(ch)
+	}()
+
+	return ch
+}
+
 func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
-	msg, err := pgConn.Frontend.Receive()
+	var msg pgproto3.BackendMessage
+	var err error
+	if pgConn.bufferingReceive {
+		pgConn.bufferingReceiveMux.Lock()
+		msg = pgConn.bufferingReceiveMsg
+		err = pgConn.bufferingReceiveErr
+		pgConn.bufferingReceiveMux.Unlock()
+		pgConn.bufferingReceive = false
+
+		// If a timeout error happened in the background try the read again.
+		if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
+			msg, err = pgConn.Frontend.Receive()
+		}
+	} else {
+		msg, err = pgConn.Frontend.Receive()
+	}
+
 	if err != nil {
 		// Close on anything other than timeout error - everything else is fatal
 		if err, ok := err.(net.Error); !(ok && err.Timeout()) {
@@ -853,7 +893,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
 	if err != nil {
 		cleanupContextDeadline()
 		if err, ok := err.(net.Error); ok && err.Timeout() {
-			go pgConn.recoverFromTimeout()
+			go pgConn.recoverFromTimeoutDuringCopyFrom()
 		} else {
 			<-pgConn.controller
 		}
@@ -877,30 +917,56 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
 	buf = append(buf, 'd')
 	sp := len(buf)
 	var readErr error
-	for readErr == nil {
+	signalMessageChan := pgConn.signalMessage()
+	for readErr == nil && pgErr == nil {
 		n, readErr = r.Read(buf[5:cap(buf)])
 		if n > 0 {
 			buf = buf[0 : n+5]
 			pgio.SetInt32(buf[sp:], int32(n+4))
 
-			_, err = pgConn.conn.Write(buf)
+			n, err = pgConn.conn.Write(buf)
 			if err != nil {
-				// Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to
-				// recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to
-				// close the connection.
-				pgConn.conn.Close()
-				pgConn.closed = true
-
+				// Partially sent messages are a fatal error for the connection.
+				if n > 0 {
+					// Close connection because cannot recover from partially sent message.
+					pgConn.conn.Close()
+					pgConn.closed = true
+				}
 				cleanupContextDeadline()
-				<-pgConn.controller
+				if err, ok := err.(net.Error); ok && err.Timeout() {
+					go pgConn.recoverFromTimeoutDuringCopyFrom()
+				} else {
+					<-pgConn.controller
+				}
 				return "", preferContextOverNetTimeoutError(ctx, err)
 			}
 		}
+
+		select {
+		case <-signalMessageChan:
+			msg, err := pgConn.ReceiveMessage()
+			if err != nil {
+				cleanupContextDeadline()
+				if err, ok := err.(net.Error); ok && err.Timeout() {
+					go pgConn.recoverFromTimeoutDuringCopyFrom()
+				} else {
+					<-pgConn.controller
+				}
+
+				return "", preferContextOverNetTimeoutError(ctx, err)
+			}
+
+			switch msg := msg.(type) {
+			case *pgproto3.ErrorResponse:
+				pgErr = errorResponseToPgError(msg)
+			}
+		default:
+		}
 	}
 
 	buf = buf[:0]
-	if readErr == io.EOF {
+	if readErr == io.EOF || pgErr != nil {
 		copyDone := &pgproto3.CopyDone{}
 		buf = copyDone.Encode(buf)
 	} else {
@@ -944,6 +1010,47 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
 	}
 }
 
+func (pgConn *PgConn) recoverFromTimeoutDuringCopyFrom() {
+	// Regardless of recovery outcome the lock on the pgConn must be released.
+	defer func() { <-pgConn.controller }()
+
+	// Limit time to wait for entire cancellation process.
+	err := pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second))
+	if err != nil {
+		pgConn.hardClose()
+		return
+	}
+
+	copyFail := &pgproto3.CopyFail{Error: "client cancel"}
+	buf := copyFail.Encode(nil)
+
+	_, err = pgConn.conn.Write(buf)
+	if err != nil {
+		pgConn.hardClose()
+		return
+	}
+
+	pendingReadyForQuery := true
+
+	for pendingReadyForQuery {
+		msg, err := pgConn.ReceiveMessage()
+		if err != nil {
+			pgConn.hardClose()
+			return
+		}
+
+		switch msg.(type) {
+		case *pgproto3.ReadyForQuery:
+			pendingReadyForQuery = false
+		}
+	}
+
+	err = pgConn.conn.SetDeadline(time.Time{})
+	if err != nil {
+		pgConn.hardClose()
+	}
+}
+
 // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
 type MultiResultReader struct {
 	pgConn *PgConn
diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go
index 47b3b3fb..7fb01e2c 100644
--- a/pgconn/pgconn_test.go
+++ b/pgconn/pgconn_test.go
@@ -6,6 +6,7 @@ import (
 	"context"
 	"crypto/tls"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"log"
 	"net"
@@ -830,6 +831,41 @@ func TestConnCopyFrom(t *testing.T) {
 	ensureConnValid(t, pgConn)
 }
 
+func TestConnCopyFromCanceled(t *testing.T) {
+	t.Parallel()
+
+	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
+	require.NoError(t, err)
+	defer closeConn(t, pgConn)
+
+	_, err = pgConn.Exec(context.Background(), `create temporary table foo(
+		a int4,
+		b varchar
+	)`).ReadAll()
+	require.NoError(t, err)
+
+	r, w := io.Pipe()
+	go func() {
+		for i := 0; i < 1000000; i++ {
+			a := strconv.Itoa(i)
+			b := "foo " + a + " bar"
+			_, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
+			if err != nil {
+				return
+			}
+			time.Sleep(time.Microsecond)
+		}
+	}()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
+	cancel()
+	assert.Equal(t, int64(0), ct.RowsAffected())
+	require.Equal(t, context.DeadlineExceeded, err)
+
+	ensureConnValid(t, pgConn)
+}
+
 func TestConnCopyFromGzipReader(t *testing.T) {
 	t.Parallel()