mirror of https://github.com/jackc/pgx.git
Restore pgx v4 style CopyFrom implementation
This approach uses an extra goroutine to write while the main goroutine continues to read. This avoids the need to use non-blocking I/O.pull/1644/head
parent
4410fc0a65
commit
85136a8efe
146
pgconn/pgconn.go
146
pgconn/pgconn.go
|
@ -13,6 +13,7 @@ import (
|
|||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||
|
@ -75,6 +76,11 @@ type PgConn struct {
|
|||
|
||||
status byte // One of connStatus* constants
|
||||
|
||||
bufferingReceive bool
|
||||
bufferingReceiveMux sync.Mutex
|
||||
bufferingReceiveMsg pgproto3.BackendMessage
|
||||
bufferingReceiveErr error
|
||||
|
||||
peekedMsg pgproto3.BackendMessage
|
||||
|
||||
// Reusable / preallocated resources
|
||||
|
@ -419,6 +425,24 @@ func hexMD5(s string) string {
|
|||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
func (pgConn *PgConn) signalMessage() chan struct{} {
|
||||
if pgConn.bufferingReceive {
|
||||
panic("BUG: signalMessage when already in progress")
|
||||
}
|
||||
|
||||
pgConn.bufferingReceive = true
|
||||
pgConn.bufferingReceiveMux.Lock()
|
||||
|
||||
ch := make(chan struct{})
|
||||
go func() {
|
||||
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
|
||||
pgConn.bufferingReceiveMux.Unlock()
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
|
||||
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
|
||||
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
|
||||
|
@ -458,7 +482,23 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
|
|||
return pgConn.peekedMsg, nil
|
||||
}
|
||||
|
||||
msg, err := pgConn.frontend.Receive()
|
||||
var msg pgproto3.BackendMessage
|
||||
var err error
|
||||
if pgConn.bufferingReceive {
|
||||
pgConn.bufferingReceiveMux.Lock()
|
||||
msg = pgConn.bufferingReceiveMsg
|
||||
err = pgConn.bufferingReceiveErr
|
||||
pgConn.bufferingReceiveMux.Unlock()
|
||||
pgConn.bufferingReceive = false
|
||||
|
||||
// If a timeout error happened in the background try the read again.
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
msg, err = pgConn.frontend.Receive()
|
||||
}
|
||||
} else {
|
||||
msg, err = pgConn.frontend.Receive()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Close on anything other than timeout error - everything else is fatal
|
||||
|
@ -1155,7 +1195,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
defer pgConn.contextWatcher.Unwatch()
|
||||
}
|
||||
|
||||
// Send copy to command
|
||||
// Send copy from query
|
||||
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
|
||||
err := pgConn.frontend.Flush()
|
||||
if err != nil {
|
||||
|
@ -1163,52 +1203,55 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
return CommandTag{}, err
|
||||
}
|
||||
|
||||
// err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline)
|
||||
// if err != nil {
|
||||
// pgConn.asyncClose()
|
||||
// return CommandTag{}, err
|
||||
// }
|
||||
nonblocking := true
|
||||
defer func() {
|
||||
if nonblocking {
|
||||
pgConn.conn.SetReadDeadline(time.Time{})
|
||||
// Send copy data
|
||||
abortCopyChan := make(chan struct{})
|
||||
copyErrChan := make(chan error, 1)
|
||||
signalMessageChan := pgConn.signalMessage()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buf := iobufpool.Get(65536)
|
||||
defer iobufpool.Put(buf)
|
||||
(*buf)[0] = 'd'
|
||||
|
||||
for {
|
||||
n, readErr := r.Read((*buf)[5:cap(*buf)])
|
||||
if n > 0 {
|
||||
*buf = (*buf)[0 : n+5]
|
||||
pgio.SetInt32((*buf)[1:], int32(n+4))
|
||||
|
||||
writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
|
||||
if writeErr != nil {
|
||||
// Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine.
|
||||
pgConn.conn.Close()
|
||||
|
||||
copyErrChan <- writeErr
|
||||
return
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
copyErrChan <- readErr
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-abortCopyChan:
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
buf := iobufpool.Get(65536)
|
||||
defer iobufpool.Put(buf)
|
||||
(*buf)[0] = 'd'
|
||||
|
||||
var readErr, pgErr error
|
||||
for pgErr == nil {
|
||||
// Read chunk from r.
|
||||
var n int
|
||||
n, readErr = r.Read((*buf)[5:cap(*buf)])
|
||||
|
||||
// Send chunk to PostgreSQL.
|
||||
if n > 0 {
|
||||
*buf = (*buf)[0 : n+5]
|
||||
pgio.SetInt32((*buf)[1:], int32(n+4))
|
||||
|
||||
writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
|
||||
if writeErr != nil {
|
||||
pgConn.asyncClose()
|
||||
return CommandTag{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Abort loop if there was a read error.
|
||||
if readErr != nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Read messages until error or none available.
|
||||
for pgErr == nil {
|
||||
var pgErr error
|
||||
var copyErr error
|
||||
for copyErr == nil && pgErr == nil {
|
||||
select {
|
||||
case copyErr = <-copyErrChan:
|
||||
case <-signalMessageChan:
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
// if errors.Is(err, nbconn.ErrWouldBlock) {
|
||||
// break
|
||||
// }
|
||||
pgConn.asyncClose()
|
||||
return CommandTag{}, normalizeTimeoutError(ctx, err)
|
||||
}
|
||||
|
@ -1216,22 +1259,19 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
switch msg := msg.(type) {
|
||||
case *pgproto3.ErrorResponse:
|
||||
pgErr = ErrorResponseToPgError(msg)
|
||||
break
|
||||
default:
|
||||
signalMessageChan = pgConn.signalMessage()
|
||||
}
|
||||
}
|
||||
}
|
||||
close(abortCopyChan)
|
||||
// Make sure io goroutine finishes before writing.
|
||||
wg.Wait()
|
||||
|
||||
err = pgConn.conn.SetReadDeadline(time.Time{})
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
return CommandTag{}, err
|
||||
}
|
||||
nonblocking = false
|
||||
|
||||
if readErr == io.EOF || pgErr != nil {
|
||||
if copyErr == io.EOF || pgErr != nil {
|
||||
pgConn.frontend.Send(&pgproto3.CopyDone{})
|
||||
} else {
|
||||
pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()})
|
||||
pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
|
||||
}
|
||||
err = pgConn.frontend.Flush()
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in New Issue