Restore pgx v4 style CopyFrom implementation

This approach uses an extra goroutine to write while the main goroutine
continues to read. This avoids the need to use non-blocking I/O.
pull/1644/head
Jack Christensen 2023-06-03 09:23:49 -05:00 committed by Jack Christensen
parent 4410fc0a65
commit 85136a8efe
1 changed files with 93 additions and 53 deletions

View File

@ -13,6 +13,7 @@ import (
"net" "net"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/jackc/pgx/v5/internal/iobufpool" "github.com/jackc/pgx/v5/internal/iobufpool"
@ -75,6 +76,11 @@ type PgConn struct {
status byte // One of connStatus* constants status byte // One of connStatus* constants
bufferingReceive bool
bufferingReceiveMux sync.Mutex
bufferingReceiveMsg pgproto3.BackendMessage
bufferingReceiveErr error
peekedMsg pgproto3.BackendMessage peekedMsg pgproto3.BackendMessage
// Reusable / preallocated resources // Reusable / preallocated resources
@ -419,6 +425,24 @@ func hexMD5(s string) string {
return hex.EncodeToString(hash.Sum(nil)) return hex.EncodeToString(hash.Sum(nil))
} }
func (pgConn *PgConn) signalMessage() chan struct{} {
if pgConn.bufferingReceive {
panic("BUG: signalMessage when already in progress")
}
pgConn.bufferingReceive = true
pgConn.bufferingReceiveMux.Lock()
ch := make(chan struct{})
go func() {
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
pgConn.bufferingReceiveMux.Unlock()
close(ch)
}()
return ch
}
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the // ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages // connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger // are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
@ -458,7 +482,23 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
return pgConn.peekedMsg, nil return pgConn.peekedMsg, nil
} }
msg, err := pgConn.frontend.Receive() var msg pgproto3.BackendMessage
var err error
if pgConn.bufferingReceive {
pgConn.bufferingReceiveMux.Lock()
msg = pgConn.bufferingReceiveMsg
err = pgConn.bufferingReceiveErr
pgConn.bufferingReceiveMux.Unlock()
pgConn.bufferingReceive = false
// If a timeout error happened in the background try the read again.
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
msg, err = pgConn.frontend.Receive()
}
} else {
msg, err = pgConn.frontend.Receive()
}
if err != nil { if err != nil {
// Close on anything other than timeout error - everything else is fatal // Close on anything other than timeout error - everything else is fatal
@ -1155,7 +1195,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
} }
// Send copy to command // Send copy from query
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
err := pgConn.frontend.Flush() err := pgConn.frontend.Flush()
if err != nil { if err != nil {
@ -1163,52 +1203,55 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
return CommandTag{}, err return CommandTag{}, err
} }
// err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline) // Send copy data
// if err != nil { abortCopyChan := make(chan struct{})
// pgConn.asyncClose() copyErrChan := make(chan error, 1)
// return CommandTag{}, err signalMessageChan := pgConn.signalMessage()
// } var wg sync.WaitGroup
nonblocking := true wg.Add(1)
defer func() {
if nonblocking {
pgConn.conn.SetReadDeadline(time.Time{})
}
}()
go func() {
defer wg.Done()
buf := iobufpool.Get(65536) buf := iobufpool.Get(65536)
defer iobufpool.Put(buf) defer iobufpool.Put(buf)
(*buf)[0] = 'd' (*buf)[0] = 'd'
var readErr, pgErr error for {
for pgErr == nil { n, readErr := r.Read((*buf)[5:cap(*buf)])
// Read chunk from r.
var n int
n, readErr = r.Read((*buf)[5:cap(*buf)])
// Send chunk to PostgreSQL.
if n > 0 { if n > 0 {
*buf = (*buf)[0 : n+5] *buf = (*buf)[0 : n+5]
pgio.SetInt32((*buf)[1:], int32(n+4)) pgio.SetInt32((*buf)[1:], int32(n+4))
writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf) writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
if writeErr != nil { if writeErr != nil {
pgConn.asyncClose() // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine.
return CommandTag{}, err pgConn.conn.Close()
}
}
// Abort loop if there was a read error. copyErrChan <- writeErr
return
}
}
if readErr != nil { if readErr != nil {
break copyErrChan <- readErr
return
} }
// Read messages until error or none available. select {
for pgErr == nil { case <-abortCopyChan:
return
default:
}
}
}()
var pgErr error
var copyErr error
for copyErr == nil && pgErr == nil {
select {
case copyErr = <-copyErrChan:
case <-signalMessageChan:
msg, err := pgConn.receiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
// if errors.Is(err, nbconn.ErrWouldBlock) {
// break
// }
pgConn.asyncClose() pgConn.asyncClose()
return CommandTag{}, normalizeTimeoutError(ctx, err) return CommandTag{}, normalizeTimeoutError(ctx, err)
} }
@ -1216,22 +1259,19 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
pgErr = ErrorResponseToPgError(msg) pgErr = ErrorResponseToPgError(msg)
break default:
signalMessageChan = pgConn.signalMessage()
} }
} }
} }
close(abortCopyChan)
// Make sure io goroutine finishes before writing.
wg.Wait()
err = pgConn.conn.SetReadDeadline(time.Time{}) if copyErr == io.EOF || pgErr != nil {
if err != nil {
pgConn.asyncClose()
return CommandTag{}, err
}
nonblocking = false
if readErr == io.EOF || pgErr != nil {
pgConn.frontend.Send(&pgproto3.CopyDone{}) pgConn.frontend.Send(&pgproto3.CopyDone{})
} else { } else {
pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()}) pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
} }
err = pgConn.frontend.Flush() err = pgConn.frontend.Flush()
if err != nil { if err != nil {