mirror of https://github.com/jackc/pgx.git
Restore pgx v4 style CopyFrom implementation
This approach uses an extra goroutine to write while the main goroutine continues to read. This avoids the need to use non-blocking I/O.pull/1644/head
parent
4410fc0a65
commit
85136a8efe
122
pgconn/pgconn.go
122
pgconn/pgconn.go
|
@ -13,6 +13,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||||
|
@ -75,6 +76,11 @@ type PgConn struct {
|
||||||
|
|
||||||
status byte // One of connStatus* constants
|
status byte // One of connStatus* constants
|
||||||
|
|
||||||
|
bufferingReceive bool
|
||||||
|
bufferingReceiveMux sync.Mutex
|
||||||
|
bufferingReceiveMsg pgproto3.BackendMessage
|
||||||
|
bufferingReceiveErr error
|
||||||
|
|
||||||
peekedMsg pgproto3.BackendMessage
|
peekedMsg pgproto3.BackendMessage
|
||||||
|
|
||||||
// Reusable / preallocated resources
|
// Reusable / preallocated resources
|
||||||
|
@ -419,6 +425,24 @@ func hexMD5(s string) string {
|
||||||
return hex.EncodeToString(hash.Sum(nil))
|
return hex.EncodeToString(hash.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pgConn *PgConn) signalMessage() chan struct{} {
|
||||||
|
if pgConn.bufferingReceive {
|
||||||
|
panic("BUG: signalMessage when already in progress")
|
||||||
|
}
|
||||||
|
|
||||||
|
pgConn.bufferingReceive = true
|
||||||
|
pgConn.bufferingReceiveMux.Lock()
|
||||||
|
|
||||||
|
ch := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
|
||||||
|
pgConn.bufferingReceiveMux.Unlock()
|
||||||
|
close(ch)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
|
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
|
||||||
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
|
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
|
||||||
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
|
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
|
||||||
|
@ -458,7 +482,23 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
|
||||||
return pgConn.peekedMsg, nil
|
return pgConn.peekedMsg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, err := pgConn.frontend.Receive()
|
var msg pgproto3.BackendMessage
|
||||||
|
var err error
|
||||||
|
if pgConn.bufferingReceive {
|
||||||
|
pgConn.bufferingReceiveMux.Lock()
|
||||||
|
msg = pgConn.bufferingReceiveMsg
|
||||||
|
err = pgConn.bufferingReceiveErr
|
||||||
|
pgConn.bufferingReceiveMux.Unlock()
|
||||||
|
pgConn.bufferingReceive = false
|
||||||
|
|
||||||
|
// If a timeout error happened in the background try the read again.
|
||||||
|
var netErr net.Error
|
||||||
|
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||||
|
msg, err = pgConn.frontend.Receive()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
msg, err = pgConn.frontend.Receive()
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Close on anything other than timeout error - everything else is fatal
|
// Close on anything other than timeout error - everything else is fatal
|
||||||
|
@ -1155,7 +1195,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||||
defer pgConn.contextWatcher.Unwatch()
|
defer pgConn.contextWatcher.Unwatch()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send copy to command
|
// Send copy from query
|
||||||
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
|
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
|
||||||
err := pgConn.frontend.Flush()
|
err := pgConn.frontend.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1163,52 +1203,55 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||||
return CommandTag{}, err
|
return CommandTag{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline)
|
// Send copy data
|
||||||
// if err != nil {
|
abortCopyChan := make(chan struct{})
|
||||||
// pgConn.asyncClose()
|
copyErrChan := make(chan error, 1)
|
||||||
// return CommandTag{}, err
|
signalMessageChan := pgConn.signalMessage()
|
||||||
// }
|
var wg sync.WaitGroup
|
||||||
nonblocking := true
|
wg.Add(1)
|
||||||
defer func() {
|
|
||||||
if nonblocking {
|
|
||||||
pgConn.conn.SetReadDeadline(time.Time{})
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
buf := iobufpool.Get(65536)
|
buf := iobufpool.Get(65536)
|
||||||
defer iobufpool.Put(buf)
|
defer iobufpool.Put(buf)
|
||||||
(*buf)[0] = 'd'
|
(*buf)[0] = 'd'
|
||||||
|
|
||||||
var readErr, pgErr error
|
for {
|
||||||
for pgErr == nil {
|
n, readErr := r.Read((*buf)[5:cap(*buf)])
|
||||||
// Read chunk from r.
|
|
||||||
var n int
|
|
||||||
n, readErr = r.Read((*buf)[5:cap(*buf)])
|
|
||||||
|
|
||||||
// Send chunk to PostgreSQL.
|
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
*buf = (*buf)[0 : n+5]
|
*buf = (*buf)[0 : n+5]
|
||||||
pgio.SetInt32((*buf)[1:], int32(n+4))
|
pgio.SetInt32((*buf)[1:], int32(n+4))
|
||||||
|
|
||||||
writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
|
writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
|
||||||
if writeErr != nil {
|
if writeErr != nil {
|
||||||
pgConn.asyncClose()
|
// Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine.
|
||||||
return CommandTag{}, err
|
pgConn.conn.Close()
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Abort loop if there was a read error.
|
copyErrChan <- writeErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
if readErr != nil {
|
if readErr != nil {
|
||||||
break
|
copyErrChan <- readErr
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read messages until error or none available.
|
select {
|
||||||
for pgErr == nil {
|
case <-abortCopyChan:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var pgErr error
|
||||||
|
var copyErr error
|
||||||
|
for copyErr == nil && pgErr == nil {
|
||||||
|
select {
|
||||||
|
case copyErr = <-copyErrChan:
|
||||||
|
case <-signalMessageChan:
|
||||||
msg, err := pgConn.receiveMessage()
|
msg, err := pgConn.receiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// if errors.Is(err, nbconn.ErrWouldBlock) {
|
|
||||||
// break
|
|
||||||
// }
|
|
||||||
pgConn.asyncClose()
|
pgConn.asyncClose()
|
||||||
return CommandTag{}, normalizeTimeoutError(ctx, err)
|
return CommandTag{}, normalizeTimeoutError(ctx, err)
|
||||||
}
|
}
|
||||||
|
@ -1216,22 +1259,19 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
case *pgproto3.ErrorResponse:
|
case *pgproto3.ErrorResponse:
|
||||||
pgErr = ErrorResponseToPgError(msg)
|
pgErr = ErrorResponseToPgError(msg)
|
||||||
break
|
default:
|
||||||
|
signalMessageChan = pgConn.signalMessage()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
close(abortCopyChan)
|
||||||
|
// Make sure io goroutine finishes before writing.
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
err = pgConn.conn.SetReadDeadline(time.Time{})
|
if copyErr == io.EOF || pgErr != nil {
|
||||||
if err != nil {
|
|
||||||
pgConn.asyncClose()
|
|
||||||
return CommandTag{}, err
|
|
||||||
}
|
|
||||||
nonblocking = false
|
|
||||||
|
|
||||||
if readErr == io.EOF || pgErr != nil {
|
|
||||||
pgConn.frontend.Send(&pgproto3.CopyDone{})
|
pgConn.frontend.Send(&pgproto3.CopyDone{})
|
||||||
} else {
|
} else {
|
||||||
pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()})
|
pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
|
||||||
}
|
}
|
||||||
err = pgConn.frontend.Flush()
|
err = pgConn.frontend.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in New Issue