Restore pgx v4 style CopyFrom implementation

This approach uses an extra goroutine to write while the main goroutine
continues to read. This avoids the need to use non-blocking I/O.
pull/1644/head
Jack Christensen 2023-06-03 09:23:49 -05:00 committed by Jack Christensen
parent 4410fc0a65
commit 85136a8efe
1 changed files with 93 additions and 53 deletions

View File

@ -13,6 +13,7 @@ import (
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/jackc/pgx/v5/internal/iobufpool"
@ -75,6 +76,11 @@ type PgConn struct {
status byte // One of connStatus* constants
bufferingReceive bool
bufferingReceiveMux sync.Mutex
bufferingReceiveMsg pgproto3.BackendMessage
bufferingReceiveErr error
peekedMsg pgproto3.BackendMessage
// Reusable / preallocated resources
@ -419,6 +425,24 @@ func hexMD5(s string) string {
return hex.EncodeToString(hash.Sum(nil))
}
func (pgConn *PgConn) signalMessage() chan struct{} {
if pgConn.bufferingReceive {
panic("BUG: signalMessage when already in progress")
}
pgConn.bufferingReceive = true
pgConn.bufferingReceiveMux.Lock()
ch := make(chan struct{})
go func() {
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
pgConn.bufferingReceiveMux.Unlock()
close(ch)
}()
return ch
}
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
@ -458,7 +482,23 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
return pgConn.peekedMsg, nil
}
msg, err := pgConn.frontend.Receive()
var msg pgproto3.BackendMessage
var err error
if pgConn.bufferingReceive {
pgConn.bufferingReceiveMux.Lock()
msg = pgConn.bufferingReceiveMsg
err = pgConn.bufferingReceiveErr
pgConn.bufferingReceiveMux.Unlock()
pgConn.bufferingReceive = false
// If a timeout error happened in the background try the read again.
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
msg, err = pgConn.frontend.Receive()
}
} else {
msg, err = pgConn.frontend.Receive()
}
if err != nil {
// Close on anything other than timeout error - everything else is fatal
@ -1155,7 +1195,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
defer pgConn.contextWatcher.Unwatch()
}
// Send copy to command
// Send copy from query
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
err := pgConn.frontend.Flush()
if err != nil {
@ -1163,52 +1203,55 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
return CommandTag{}, err
}
// err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline)
// if err != nil {
// pgConn.asyncClose()
// return CommandTag{}, err
// }
nonblocking := true
defer func() {
if nonblocking {
pgConn.conn.SetReadDeadline(time.Time{})
// Send copy data
abortCopyChan := make(chan struct{})
copyErrChan := make(chan error, 1)
signalMessageChan := pgConn.signalMessage()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
buf := iobufpool.Get(65536)
defer iobufpool.Put(buf)
(*buf)[0] = 'd'
for {
n, readErr := r.Read((*buf)[5:cap(*buf)])
if n > 0 {
*buf = (*buf)[0 : n+5]
pgio.SetInt32((*buf)[1:], int32(n+4))
writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
if writeErr != nil {
// Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine.
pgConn.conn.Close()
copyErrChan <- writeErr
return
}
}
if readErr != nil {
copyErrChan <- readErr
return
}
select {
case <-abortCopyChan:
return
default:
}
}
}()
buf := iobufpool.Get(65536)
defer iobufpool.Put(buf)
(*buf)[0] = 'd'
var readErr, pgErr error
for pgErr == nil {
// Read chunk from r.
var n int
n, readErr = r.Read((*buf)[5:cap(*buf)])
// Send chunk to PostgreSQL.
if n > 0 {
*buf = (*buf)[0 : n+5]
pgio.SetInt32((*buf)[1:], int32(n+4))
writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
if writeErr != nil {
pgConn.asyncClose()
return CommandTag{}, err
}
}
// Abort loop if there was a read error.
if readErr != nil {
break
}
// Read messages until error or none available.
for pgErr == nil {
var pgErr error
var copyErr error
for copyErr == nil && pgErr == nil {
select {
case copyErr = <-copyErrChan:
case <-signalMessageChan:
msg, err := pgConn.receiveMessage()
if err != nil {
// if errors.Is(err, nbconn.ErrWouldBlock) {
// break
// }
pgConn.asyncClose()
return CommandTag{}, normalizeTimeoutError(ctx, err)
}
@ -1216,22 +1259,19 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
switch msg := msg.(type) {
case *pgproto3.ErrorResponse:
pgErr = ErrorResponseToPgError(msg)
break
default:
signalMessageChan = pgConn.signalMessage()
}
}
}
close(abortCopyChan)
// Make sure io goroutine finishes before writing.
wg.Wait()
err = pgConn.conn.SetReadDeadline(time.Time{})
if err != nil {
pgConn.asyncClose()
return CommandTag{}, err
}
nonblocking = false
if readErr == io.EOF || pgErr != nil {
if copyErr == io.EOF || pgErr != nil {
pgConn.frontend.Send(&pgproto3.CopyDone{})
} else {
pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()})
pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
}
err = pgConn.frontend.Flush()
if err != nil {