mirror of https://github.com/jackc/pgx.git
Recover from context cancellation during CopyFrom
parent
68d6d1c779
commit
e83d1d2228
131
pgconn/pgconn.go
131
pgconn/pgconn.go
|
@ -12,6 +12,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgx/pgio"
|
||||||
|
@ -91,6 +92,11 @@ type PgConn struct {
|
||||||
controller chan interface{}
|
controller chan interface{}
|
||||||
|
|
||||||
closed bool
|
closed bool
|
||||||
|
|
||||||
|
bufferingReceive bool
|
||||||
|
bufferingReceiveMux sync.Mutex
|
||||||
|
bufferingReceiveMsg pgproto3.BackendMessage
|
||||||
|
bufferingReceiveErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
|
// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
|
||||||
|
@ -273,8 +279,42 @@ func hexMD5(s string) string {
|
||||||
return hex.EncodeToString(hash.Sum(nil))
|
return hex.EncodeToString(hash.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pgConn *PgConn) signalMessage() chan struct{} {
|
||||||
|
if pgConn.bufferingReceive {
|
||||||
|
panic("BUG: signalMessage when already in progress")
|
||||||
|
}
|
||||||
|
|
||||||
|
pgConn.bufferingReceive = true
|
||||||
|
pgConn.bufferingReceiveMux.Lock()
|
||||||
|
|
||||||
|
ch := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.Frontend.Receive()
|
||||||
|
pgConn.bufferingReceiveMux.Unlock()
|
||||||
|
close(ch)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
|
func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
|
||||||
msg, err := pgConn.Frontend.Receive()
|
var msg pgproto3.BackendMessage
|
||||||
|
var err error
|
||||||
|
if pgConn.bufferingReceive {
|
||||||
|
pgConn.bufferingReceiveMux.Lock()
|
||||||
|
msg = pgConn.bufferingReceiveMsg
|
||||||
|
err = pgConn.bufferingReceiveErr
|
||||||
|
pgConn.bufferingReceiveMux.Unlock()
|
||||||
|
pgConn.bufferingReceive = false
|
||||||
|
|
||||||
|
// If a timeout error happened in the background try the read again.
|
||||||
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||||
|
msg, err = pgConn.Frontend.Receive()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
msg, err = pgConn.Frontend.Receive()
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Close on anything other than timeout error - everything else is fatal
|
// Close on anything other than timeout error - everything else is fatal
|
||||||
if err, ok := err.(net.Error); !(ok && err.Timeout()) {
|
if err, ok := err.(net.Error); !(ok && err.Timeout()) {
|
||||||
|
@ -853,7 +893,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cleanupContextDeadline()
|
cleanupContextDeadline()
|
||||||
if err, ok := err.(net.Error); ok && err.Timeout() {
|
if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||||
go pgConn.recoverFromTimeout()
|
go pgConn.recoverFromTimeoutDuringCopyFrom()
|
||||||
} else {
|
} else {
|
||||||
<-pgConn.controller
|
<-pgConn.controller
|
||||||
}
|
}
|
||||||
|
@ -877,30 +917,56 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||||
buf = append(buf, 'd')
|
buf = append(buf, 'd')
|
||||||
sp := len(buf)
|
sp := len(buf)
|
||||||
var readErr error
|
var readErr error
|
||||||
for readErr == nil {
|
signalMessageChan := pgConn.signalMessage()
|
||||||
|
for readErr == nil && pgErr == nil {
|
||||||
n, readErr = r.Read(buf[5:cap(buf)])
|
n, readErr = r.Read(buf[5:cap(buf)])
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
buf = buf[0 : n+5]
|
buf = buf[0 : n+5]
|
||||||
pgio.SetInt32(buf[sp:], int32(n+4))
|
pgio.SetInt32(buf[sp:], int32(n+4))
|
||||||
|
|
||||||
_, err = pgConn.conn.Write(buf)
|
n, err = pgConn.conn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to
|
// Partially sent messages are a fatal error for the connection.
|
||||||
// recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to
|
if n > 0 {
|
||||||
// close the connection.
|
// Close connection because cannot recover from partially sent message.
|
||||||
pgConn.conn.Close()
|
pgConn.conn.Close()
|
||||||
pgConn.closed = true
|
pgConn.closed = true
|
||||||
|
}
|
||||||
cleanupContextDeadline()
|
cleanupContextDeadline()
|
||||||
<-pgConn.controller
|
if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||||
|
go pgConn.recoverFromTimeoutDuringCopyFrom()
|
||||||
|
} else {
|
||||||
|
<-pgConn.controller
|
||||||
|
}
|
||||||
|
|
||||||
return "", preferContextOverNetTimeoutError(ctx, err)
|
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-signalMessageChan:
|
||||||
|
msg, err := pgConn.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
cleanupContextDeadline()
|
||||||
|
if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||||
|
go pgConn.recoverFromTimeoutDuringCopyFrom()
|
||||||
|
} else {
|
||||||
|
<-pgConn.controller
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case *pgproto3.ErrorResponse:
|
||||||
|
pgErr = errorResponseToPgError(msg)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
buf = buf[:0]
|
buf = buf[:0]
|
||||||
if readErr == io.EOF {
|
if readErr == io.EOF || pgErr != nil {
|
||||||
copyDone := &pgproto3.CopyDone{}
|
copyDone := &pgproto3.CopyDone{}
|
||||||
buf = copyDone.Encode(buf)
|
buf = copyDone.Encode(buf)
|
||||||
} else {
|
} else {
|
||||||
|
@ -944,6 +1010,47 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pgConn *PgConn) recoverFromTimeoutDuringCopyFrom() {
|
||||||
|
// Regardless of recovery outcome the lock on the pgConn must be released.
|
||||||
|
defer func() { <-pgConn.controller }()
|
||||||
|
|
||||||
|
// Limit time to wait for entire cancellation process.
|
||||||
|
err := pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second))
|
||||||
|
if err != nil {
|
||||||
|
pgConn.hardClose()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
copyFail := &pgproto3.CopyFail{Error: "client cancel"}
|
||||||
|
buf := copyFail.Encode(nil)
|
||||||
|
|
||||||
|
_, err = pgConn.conn.Write(buf)
|
||||||
|
if err != nil {
|
||||||
|
pgConn.hardClose()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pendingReadyForQuery := true
|
||||||
|
|
||||||
|
for pendingReadyForQuery {
|
||||||
|
msg, err := pgConn.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
pgConn.hardClose()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg.(type) {
|
||||||
|
case *pgproto3.ReadyForQuery:
|
||||||
|
pendingReadyForQuery = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pgConn.conn.SetDeadline(time.Time{})
|
||||||
|
if err != nil {
|
||||||
|
pgConn.hardClose()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
|
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
|
||||||
type MultiResultReader struct {
|
type MultiResultReader struct {
|
||||||
pgConn *PgConn
|
pgConn *PgConn
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
@ -830,6 +831,41 @@ func TestConnCopyFrom(t *testing.T) {
|
||||||
ensureConnValid(t, pgConn)
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnCopyFromCanceled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
_, err = pgConn.Exec(context.Background(), `create temporary table foo(
|
||||||
|
a int4,
|
||||||
|
b varchar
|
||||||
|
)`).ReadAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
r, w := io.Pipe()
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 1000000; i++ {
|
||||||
|
a := strconv.Itoa(i)
|
||||||
|
b := "foo " + a + " bar"
|
||||||
|
_, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(time.Microsecond)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
|
||||||
|
cancel()
|
||||||
|
assert.Equal(t, int64(0), ct.RowsAffected())
|
||||||
|
require.Equal(t, context.DeadlineExceeded, err)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnCopyFromGzipReader(t *testing.T) {
|
func TestConnCopyFromGzipReader(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue