Recover from context cancellation during CopyFrom

pull/483/head
Jack Christensen 2019-01-26 12:20:36 -06:00
parent 68d6d1c779
commit e83d1d2228
2 changed files with 155 additions and 12 deletions

View File

@ -12,6 +12,7 @@ import (
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/jackc/pgx/pgio"
@ -91,6 +92,11 @@ type PgConn struct {
controller chan interface{}
closed bool
bufferingReceive bool
bufferingReceiveMux sync.Mutex
bufferingReceiveMsg pgproto3.BackendMessage
bufferingReceiveErr error
}
// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
@ -273,8 +279,42 @@ func hexMD5(s string) string {
return hex.EncodeToString(hash.Sum(nil))
}
func (pgConn *PgConn) signalMessage() chan struct{} {
if pgConn.bufferingReceive {
panic("BUG: signalMessage when already in progress")
}
pgConn.bufferingReceive = true
pgConn.bufferingReceiveMux.Lock()
ch := make(chan struct{})
go func() {
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.Frontend.Receive()
pgConn.bufferingReceiveMux.Unlock()
close(ch)
}()
return ch
}
func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
msg, err := pgConn.Frontend.Receive()
var msg pgproto3.BackendMessage
var err error
if pgConn.bufferingReceive {
pgConn.bufferingReceiveMux.Lock()
msg = pgConn.bufferingReceiveMsg
err = pgConn.bufferingReceiveErr
pgConn.bufferingReceiveMux.Unlock()
pgConn.bufferingReceive = false
// If a timeout error happened in the background try the read again.
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
msg, err = pgConn.Frontend.Receive()
}
} else {
msg, err = pgConn.Frontend.Receive()
}
if err != nil {
// Close on anything other than timeout error - everything else is fatal
if err, ok := err.(net.Error); !(ok && err.Timeout()) {
@ -853,7 +893,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
if err != nil {
cleanupContextDeadline()
if err, ok := err.(net.Error); ok && err.Timeout() {
go pgConn.recoverFromTimeout()
go pgConn.recoverFromTimeoutDuringCopyFrom()
} else {
<-pgConn.controller
}
@ -877,30 +917,56 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
buf = append(buf, 'd')
sp := len(buf)
var readErr error
for readErr == nil {
signalMessageChan := pgConn.signalMessage()
for readErr == nil && pgErr == nil {
n, readErr = r.Read(buf[5:cap(buf)])
if n > 0 {
buf = buf[0 : n+5]
pgio.SetInt32(buf[sp:], int32(n+4))
_, err = pgConn.conn.Write(buf)
n, err = pgConn.conn.Write(buf)
if err != nil {
// Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to
// recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to
// close the connection.
pgConn.conn.Close()
pgConn.closed = true
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
cleanupContextDeadline()
<-pgConn.controller
if err, ok := err.(net.Error); ok && err.Timeout() {
go pgConn.recoverFromTimeoutDuringCopyFrom()
} else {
<-pgConn.controller
}
return "", preferContextOverNetTimeoutError(ctx, err)
}
}
select {
case <-signalMessageChan:
msg, err := pgConn.ReceiveMessage()
if err != nil {
cleanupContextDeadline()
if err, ok := err.(net.Error); ok && err.Timeout() {
go pgConn.recoverFromTimeoutDuringCopyFrom()
} else {
<-pgConn.controller
}
return "", preferContextOverNetTimeoutError(ctx, err)
}
switch msg := msg.(type) {
case *pgproto3.ErrorResponse:
pgErr = errorResponseToPgError(msg)
}
default:
}
}
buf = buf[:0]
if readErr == io.EOF {
if readErr == io.EOF || pgErr != nil {
copyDone := &pgproto3.CopyDone{}
buf = copyDone.Encode(buf)
} else {
@ -944,6 +1010,47 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
}
}
func (pgConn *PgConn) recoverFromTimeoutDuringCopyFrom() {
// Regardless of recovery outcome the lock on the pgConn must be released.
defer func() { <-pgConn.controller }()
// Limit time to wait for entire cancellation process.
err := pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second))
if err != nil {
pgConn.hardClose()
return
}
copyFail := &pgproto3.CopyFail{Error: "client cancel"}
buf := copyFail.Encode(nil)
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
return
}
pendingReadyForQuery := true
for pendingReadyForQuery {
msg, err := pgConn.ReceiveMessage()
if err != nil {
pgConn.hardClose()
return
}
switch msg.(type) {
case *pgproto3.ReadyForQuery:
pendingReadyForQuery = false
}
}
err = pgConn.conn.SetDeadline(time.Time{})
if err != nil {
pgConn.hardClose()
}
}
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
type MultiResultReader struct {
pgConn *PgConn

View File

@ -6,6 +6,7 @@ import (
"context"
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"log"
"net"
@ -830,6 +831,41 @@ func TestConnCopyFrom(t *testing.T) {
ensureConnValid(t, pgConn)
}
func TestConnCopyFromCanceled(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
_, err = pgConn.Exec(context.Background(), `create temporary table foo(
a int4,
b varchar
)`).ReadAll()
require.NoError(t, err)
r, w := io.Pipe()
go func() {
for i := 0; i < 1000000; i++ {
a := strconv.Itoa(i)
b := "foo " + a + " bar"
_, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
if err != nil {
return
}
time.Sleep(time.Microsecond)
}
}()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
cancel()
assert.Equal(t, int64(0), ct.RowsAffected())
require.Equal(t, context.DeadlineExceeded, err)
ensureConnValid(t, pgConn)
}
func TestConnCopyFromGzipReader(t *testing.T) {
t.Parallel()