mirror of https://github.com/jackc/pgx.git
wip
parent
2e7b46d5d7
commit
ca22396789
|
@ -0,0 +1,76 @@
|
||||||
|
package nbbconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||||
|
)
|
||||||
|
|
||||||
|
const minBufferQueueLen = 8
|
||||||
|
|
||||||
|
type bufferQueue struct {
|
||||||
|
lock sync.Mutex
|
||||||
|
queue [][]byte
|
||||||
|
r, w int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bq *bufferQueue) pushBack(buf []byte) {
|
||||||
|
bq.lock.Lock()
|
||||||
|
defer bq.lock.Unlock()
|
||||||
|
|
||||||
|
if bq.w >= len(bq.queue) {
|
||||||
|
bq.growQueue()
|
||||||
|
}
|
||||||
|
bq.queue[bq.w] = buf
|
||||||
|
bq.w++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bq *bufferQueue) pushFront(buf []byte) {
|
||||||
|
bq.lock.Lock()
|
||||||
|
defer bq.lock.Unlock()
|
||||||
|
|
||||||
|
if bq.w >= len(bq.queue) {
|
||||||
|
bq.growQueue()
|
||||||
|
}
|
||||||
|
copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w])
|
||||||
|
bq.queue[bq.r] = buf
|
||||||
|
bq.w++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bq *bufferQueue) popFront() []byte {
|
||||||
|
bq.lock.Lock()
|
||||||
|
defer bq.lock.Unlock()
|
||||||
|
|
||||||
|
if bq.r == bq.w {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := bq.queue[bq.r]
|
||||||
|
bq.queue[bq.r] = nil // Clear reference so it can be garbage collected.
|
||||||
|
bq.r++
|
||||||
|
|
||||||
|
if bq.r == bq.w {
|
||||||
|
bq.r = 0
|
||||||
|
bq.w = 0
|
||||||
|
if len(bq.queue) > minBufferQueueLen {
|
||||||
|
bq.queue = make([][]byte, minBufferQueueLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bq *bufferQueue) growQueue() {
|
||||||
|
desiredLen := (len(bq.queue) + 1) * 3 / 2
|
||||||
|
if desiredLen < minBufferQueueLen {
|
||||||
|
desiredLen = minBufferQueueLen
|
||||||
|
}
|
||||||
|
|
||||||
|
newQueue := make([][]byte, desiredLen)
|
||||||
|
copy(newQueue, bq.queue)
|
||||||
|
bq.queue = newQueue
|
||||||
|
}
|
||||||
|
|
||||||
|
func releaseBuf(buf []byte) {
|
||||||
|
iobufpool.Put(buf[:cap(buf)])
|
||||||
|
}
|
|
@ -13,10 +13,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var errClosed = errors.New("closed")
|
var errClosed = errors.New("closed")
|
||||||
var errWouldBlock = errors.New("would block")
|
var ErrWouldBlock = errors.New("would block")
|
||||||
|
|
||||||
const fakeNonblockingWaitDuration = 100 * time.Millisecond
|
const fakeNonblockingWaitDuration = 100 * time.Millisecond
|
||||||
|
|
||||||
|
var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
// Conn is a non-blocking, buffered net.Conn wrapper. It implements net.Conn.
|
// Conn is a non-blocking, buffered net.Conn wrapper. It implements net.Conn.
|
||||||
//
|
//
|
||||||
// It is designed to solve three problems.
|
// It is designed to solve three problems.
|
||||||
|
@ -37,6 +39,7 @@ type Conn struct {
|
||||||
|
|
||||||
readDeadlineLock sync.Mutex
|
readDeadlineLock sync.Mutex
|
||||||
readDeadline time.Time
|
readDeadline time.Time
|
||||||
|
readNonblocking bool
|
||||||
|
|
||||||
writeDeadlineLock sync.Mutex
|
writeDeadlineLock sync.Mutex
|
||||||
writeDeadline time.Time
|
writeDeadline time.Time
|
||||||
|
@ -74,9 +77,19 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||||
releaseBuf(buf)
|
releaseBuf(buf)
|
||||||
}
|
}
|
||||||
return n, nil
|
return n, nil
|
||||||
|
// TODO - must return error if n != len(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.netConn.Read(b)
|
var readNonblocking bool
|
||||||
|
c.readDeadlineLock.Lock()
|
||||||
|
readNonblocking = c.readNonblocking
|
||||||
|
c.readDeadlineLock.Unlock()
|
||||||
|
|
||||||
|
if readNonblocking {
|
||||||
|
return c.nonblockingRead(b)
|
||||||
|
} else {
|
||||||
|
return c.netConn.Read(b)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is
|
// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is
|
||||||
|
@ -123,22 +136,16 @@ func (c *Conn) RemoteAddr() net.Addr {
|
||||||
return c.netConn.RemoteAddr()
|
return c.netConn.RemoteAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t).
|
||||||
func (c *Conn) SetDeadline(t time.Time) error {
|
func (c *Conn) SetDeadline(t time.Time) error {
|
||||||
if c.isClosed() {
|
err := c.SetReadDeadline(t)
|
||||||
return errClosed
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
return c.SetWriteDeadline(t)
|
||||||
c.readDeadlineLock.Lock()
|
|
||||||
defer c.readDeadlineLock.Unlock()
|
|
||||||
c.readDeadline = t
|
|
||||||
|
|
||||||
c.writeDeadlineLock.Lock()
|
|
||||||
defer c.writeDeadlineLock.Unlock()
|
|
||||||
c.writeDeadline = t
|
|
||||||
|
|
||||||
return c.netConn.SetDeadline(t)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking.
|
||||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||||
if c.isClosed() {
|
if c.isClosed() {
|
||||||
return errClosed
|
return errClosed
|
||||||
|
@ -146,6 +153,14 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||||
|
|
||||||
c.readDeadlineLock.Lock()
|
c.readDeadlineLock.Lock()
|
||||||
defer c.readDeadlineLock.Unlock()
|
defer c.readDeadlineLock.Unlock()
|
||||||
|
|
||||||
|
if t == NonBlockingDeadline {
|
||||||
|
c.readNonblocking = true
|
||||||
|
t = time.Time{}
|
||||||
|
} else {
|
||||||
|
c.readNonblocking = false
|
||||||
|
}
|
||||||
|
|
||||||
c.readDeadline = t
|
c.readDeadline = t
|
||||||
|
|
||||||
return c.netConn.SetReadDeadline(t)
|
return c.netConn.SetReadDeadline(t)
|
||||||
|
@ -193,7 +208,7 @@ func (c *Conn) flush() error {
|
||||||
n, err := c.nonblockingWrite(remainingBuf)
|
n, err := c.nonblockingWrite(remainingBuf)
|
||||||
remainingBuf = remainingBuf[n:]
|
remainingBuf = remainingBuf[n:]
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, errWouldBlock) {
|
if !errors.Is(err, ErrWouldBlock) {
|
||||||
buf = buf[:len(remainingBuf)]
|
buf = buf[:len(remainingBuf)]
|
||||||
copy(buf, remainingBuf)
|
copy(buf, remainingBuf)
|
||||||
c.writeQueue.pushFront(buf)
|
c.writeQueue.pushFront(buf)
|
||||||
|
@ -234,7 +249,7 @@ func (c *Conn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, errWouldBlock) {
|
if !errors.Is(err, ErrWouldBlock) {
|
||||||
errChan <- err
|
errChan <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -276,7 +291,7 @@ func (c *Conn) fakeNonblockingWrite(b []byte) (n int, err error) {
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
err = errWouldBlock
|
err = ErrWouldBlock
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -305,7 +320,7 @@ func (c *Conn) fakeNonblockingRead(b []byte) (n int, err error) {
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
err = errWouldBlock
|
err = ErrWouldBlock
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
|
@ -0,0 +1,129 @@
|
||||||
|
package nbbconn_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/nbbconn"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWriteIsBuffered(t *testing.T) {
|
||||||
|
local, remote := net.Pipe()
|
||||||
|
defer func() {
|
||||||
|
local.Close()
|
||||||
|
remote.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn := nbbconn.New(local)
|
||||||
|
|
||||||
|
// net.Pipe is synchronous so the Write would block if not buffered.
|
||||||
|
writeBuf := []byte("test")
|
||||||
|
n, err := conn.Write(writeBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 4, n)
|
||||||
|
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
err := conn.Flush()
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
readBuf := make([]byte, len(writeBuf))
|
||||||
|
_, err = remote.Read(readBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, <-errChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadFlushesWriteBuffer(t *testing.T) {
|
||||||
|
local, remote := net.Pipe()
|
||||||
|
defer func() {
|
||||||
|
local.Close()
|
||||||
|
remote.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn := nbbconn.New(local)
|
||||||
|
|
||||||
|
writeBuf := []byte("test")
|
||||||
|
n, err := conn.Write(writeBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 4, n)
|
||||||
|
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
go func() {
|
||||||
|
readBuf := make([]byte, len(writeBuf))
|
||||||
|
_, err := remote.Read(readBuf)
|
||||||
|
errChan <- err
|
||||||
|
|
||||||
|
_, err = remote.Write([]byte("okay"))
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
readBuf := make([]byte, 4)
|
||||||
|
_, err = conn.Read(readBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []byte("okay"), readBuf)
|
||||||
|
|
||||||
|
require.NoError(t, <-errChan)
|
||||||
|
require.NoError(t, <-errChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloseFlushesWriteBuffer(t *testing.T) {
|
||||||
|
local, remote := net.Pipe()
|
||||||
|
defer func() {
|
||||||
|
local.Close()
|
||||||
|
remote.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn := nbbconn.New(local)
|
||||||
|
|
||||||
|
writeBuf := []byte("test")
|
||||||
|
n, err := conn.Write(writeBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 4, n)
|
||||||
|
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
readBuf := make([]byte, len(writeBuf))
|
||||||
|
_, err := remote.Read(readBuf)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = conn.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, <-errChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNonBlockingRead(t *testing.T) {
|
||||||
|
local, remote := net.Pipe()
|
||||||
|
defer func() {
|
||||||
|
local.Close()
|
||||||
|
remote.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn := nbbconn.New(local)
|
||||||
|
|
||||||
|
err := conn.SetReadDeadline(nbbconn.NonBlockingDeadline)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
buf := make([]byte, 4)
|
||||||
|
n, err := conn.Read(buf)
|
||||||
|
require.ErrorIs(t, err, nbbconn.ErrWouldBlock)
|
||||||
|
require.EqualValues(t, 0, n)
|
||||||
|
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := remote.Write([]byte("okay"))
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = conn.SetReadDeadline(time.Time{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
n, err = conn.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 4, n)
|
||||||
|
}
|
|
@ -0,0 +1,75 @@
|
||||||
|
package nbbconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const minQueueLen = 8
|
||||||
|
|
||||||
|
type queue[T any] struct {
|
||||||
|
lock sync.Mutex
|
||||||
|
queue []T
|
||||||
|
r, w int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *queue[T]) pushBack(item T) {
|
||||||
|
q.lock.Lock()
|
||||||
|
defer q.lock.Unlock()
|
||||||
|
|
||||||
|
if q.w >= len(q.queue) {
|
||||||
|
q.growQueue()
|
||||||
|
}
|
||||||
|
q.queue[q.w] = item
|
||||||
|
q.w++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *queue[T]) pushFront(item T) {
|
||||||
|
q.lock.Lock()
|
||||||
|
defer q.lock.Unlock()
|
||||||
|
|
||||||
|
if q.w >= len(q.queue) {
|
||||||
|
q.growQueue()
|
||||||
|
}
|
||||||
|
copy(q.queue[q.r+1:q.w+1], q.queue[q.r:q.w])
|
||||||
|
q.queue[q.r] = item
|
||||||
|
q.w++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *queue[T]) popFront() (T, bool) {
|
||||||
|
q.lock.Lock()
|
||||||
|
defer q.lock.Unlock()
|
||||||
|
|
||||||
|
if q.r == q.w {
|
||||||
|
var zero T
|
||||||
|
return zero, false
|
||||||
|
}
|
||||||
|
|
||||||
|
item := q.queue[q.r]
|
||||||
|
|
||||||
|
// Clear reference so it can be garbage collected.
|
||||||
|
var zero T
|
||||||
|
q.queue[q.r] = zero
|
||||||
|
|
||||||
|
q.r++
|
||||||
|
|
||||||
|
if q.r == q.w {
|
||||||
|
q.r = 0
|
||||||
|
q.w = 0
|
||||||
|
if len(q.queue) > minQueueLen {
|
||||||
|
q.queue = make([]T, minQueueLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return item, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *queue[T]) growQueue() {
|
||||||
|
desiredLen := (len(q.queue) + 1) * 3 / 2
|
||||||
|
if desiredLen < minQueueLen {
|
||||||
|
desiredLen = minQueueLen
|
||||||
|
}
|
||||||
|
|
||||||
|
newQueue := make([]T, desiredLen)
|
||||||
|
copy(newQueue, q.queue)
|
||||||
|
q.queue = newQueue
|
||||||
|
}
|
157
pgconn/pgconn.go
157
pgconn/pgconn.go
|
@ -13,9 +13,9 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||||
"github.com/jackc/pgx/v5/internal/nbbconn"
|
"github.com/jackc/pgx/v5/internal/nbbconn"
|
||||||
"github.com/jackc/pgx/v5/internal/pgio"
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
|
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
|
||||||
|
@ -76,11 +76,6 @@ type PgConn struct {
|
||||||
|
|
||||||
status byte // One of connStatus* constants
|
status byte // One of connStatus* constants
|
||||||
|
|
||||||
bufferingReceive bool
|
|
||||||
bufferingReceiveMux sync.Mutex
|
|
||||||
bufferingReceiveMsg pgproto3.BackendMessage
|
|
||||||
bufferingReceiveErr error
|
|
||||||
|
|
||||||
peekedMsg pgproto3.BackendMessage
|
peekedMsg pgproto3.BackendMessage
|
||||||
|
|
||||||
// Reusable / preallocated resources
|
// Reusable / preallocated resources
|
||||||
|
@ -254,6 +249,8 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.conn = nbbconn.New(pgConn.conn)
|
pgConn.conn = nbbconn.New(pgConn.conn)
|
||||||
|
pgConn.contextWatcher.Unwatch() // context watcher should watch nbbconn
|
||||||
|
pgConn.contextWatcher = newContextWatcher(pgConn.conn)
|
||||||
|
|
||||||
defer pgConn.contextWatcher.Unwatch()
|
defer pgConn.contextWatcher.Unwatch()
|
||||||
|
|
||||||
|
@ -388,24 +385,6 @@ func hexMD5(s string) string {
|
||||||
return hex.EncodeToString(hash.Sum(nil))
|
return hex.EncodeToString(hash.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pgConn *PgConn) signalMessage() chan struct{} {
|
|
||||||
if pgConn.bufferingReceive {
|
|
||||||
panic("BUG: signalMessage when already in progress")
|
|
||||||
}
|
|
||||||
|
|
||||||
pgConn.bufferingReceive = true
|
|
||||||
pgConn.bufferingReceiveMux.Lock()
|
|
||||||
|
|
||||||
ch := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
|
|
||||||
pgConn.bufferingReceiveMux.Unlock()
|
|
||||||
close(ch)
|
|
||||||
}()
|
|
||||||
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
|
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
|
||||||
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
|
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
|
||||||
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
|
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
|
||||||
|
@ -445,25 +424,13 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
|
||||||
return pgConn.peekedMsg, nil
|
return pgConn.peekedMsg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var msg pgproto3.BackendMessage
|
msg, err := pgConn.frontend.Receive()
|
||||||
var err error
|
|
||||||
if pgConn.bufferingReceive {
|
|
||||||
pgConn.bufferingReceiveMux.Lock()
|
|
||||||
msg = pgConn.bufferingReceiveMsg
|
|
||||||
err = pgConn.bufferingReceiveErr
|
|
||||||
pgConn.bufferingReceiveMux.Unlock()
|
|
||||||
pgConn.bufferingReceive = false
|
|
||||||
|
|
||||||
// If a timeout error happened in the background try the read again.
|
|
||||||
var netErr net.Error
|
|
||||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
|
||||||
msg, err = pgConn.frontend.Receive()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
msg, err = pgConn.frontend.Receive()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, nbbconn.ErrWouldBlock) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Close on anything other than timeout error - everything else is fatal
|
// Close on anything other than timeout error - everything else is fatal
|
||||||
var netErr net.Error
|
var netErr net.Error
|
||||||
isNetErr := errors.As(err, &netErr)
|
isNetErr := errors.As(err, &netErr)
|
||||||
|
@ -482,13 +449,6 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
|
||||||
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
|
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
|
||||||
msg, err := pgConn.peekMessage()
|
msg, err := pgConn.peekMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Close on anything other than timeout error - everything else is fatal
|
|
||||||
var netErr net.Error
|
|
||||||
isNetErr := errors.As(err, &netErr)
|
|
||||||
if !(isNetErr && netErr.Timeout()) {
|
|
||||||
pgConn.asyncClose()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
pgConn.peekedMsg = nil
|
pgConn.peekedMsg = nil
|
||||||
|
@ -1176,62 +1136,57 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||||
|
|
||||||
// Send copy to command
|
// Send copy to command
|
||||||
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
|
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
|
||||||
|
|
||||||
err := pgConn.frontend.Flush()
|
err := pgConn.frontend.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.asyncClose()
|
pgConn.asyncClose()
|
||||||
return CommandTag{}, err
|
return CommandTag{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send copy data
|
err = pgConn.conn.SetReadDeadline(nbbconn.NonBlockingDeadline)
|
||||||
abortCopyChan := make(chan struct{})
|
if err != nil {
|
||||||
copyErrChan := make(chan error, 1)
|
pgConn.asyncClose()
|
||||||
signalMessageChan := pgConn.signalMessage()
|
return CommandTag{}, err
|
||||||
senderDoneChan := make(chan struct{})
|
}
|
||||||
|
nonblocking := true
|
||||||
go func() {
|
defer func() {
|
||||||
defer close(senderDoneChan)
|
if nonblocking {
|
||||||
|
pgConn.conn.SetReadDeadline(time.Time{})
|
||||||
buf := make([]byte, 0, 65536)
|
|
||||||
buf = append(buf, 'd')
|
|
||||||
sp := len(buf)
|
|
||||||
|
|
||||||
for {
|
|
||||||
n, readErr := r.Read(buf[5:cap(buf)])
|
|
||||||
if n > 0 {
|
|
||||||
buf = buf[0 : n+5]
|
|
||||||
pgio.SetInt32(buf[sp:], int32(n+4))
|
|
||||||
|
|
||||||
writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(buf)
|
|
||||||
if writeErr != nil {
|
|
||||||
// Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine.
|
|
||||||
pgConn.conn.Close()
|
|
||||||
|
|
||||||
copyErrChan <- writeErr
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if readErr != nil {
|
|
||||||
copyErrChan <- readErr
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-abortCopyChan:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var pgErr error
|
buf := iobufpool.Get(65536)
|
||||||
var copyErr error
|
buf[0] = 'd'
|
||||||
for copyErr == nil && pgErr == nil {
|
|
||||||
select {
|
var readErr, pgErr error
|
||||||
case copyErr = <-copyErrChan:
|
for {
|
||||||
case <-signalMessageChan:
|
// Read chunk from r.
|
||||||
|
var n int
|
||||||
|
n, readErr = r.Read(buf[5:cap(buf)])
|
||||||
|
|
||||||
|
// Send chunk to PostgreSQL.
|
||||||
|
if n > 0 {
|
||||||
|
buf = buf[0 : n+5]
|
||||||
|
pgio.SetInt32(buf[1:], int32(n+4))
|
||||||
|
|
||||||
|
writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(buf)
|
||||||
|
if writeErr != nil {
|
||||||
|
pgConn.asyncClose()
|
||||||
|
return CommandTag{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Abort loop if there was a read error.
|
||||||
|
if readErr != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read messages until error or none available.
|
||||||
|
for {
|
||||||
msg, err := pgConn.receiveMessage()
|
msg, err := pgConn.receiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, nbbconn.ErrWouldBlock) {
|
||||||
|
break
|
||||||
|
}
|
||||||
pgConn.asyncClose()
|
pgConn.asyncClose()
|
||||||
return CommandTag{}, preferContextOverNetTimeoutError(ctx, err)
|
return CommandTag{}, preferContextOverNetTimeoutError(ctx, err)
|
||||||
}
|
}
|
||||||
|
@ -1239,18 +1194,22 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
case *pgproto3.ErrorResponse:
|
case *pgproto3.ErrorResponse:
|
||||||
pgErr = ErrorResponseToPgError(msg)
|
pgErr = ErrorResponseToPgError(msg)
|
||||||
default:
|
break
|
||||||
signalMessageChan = pgConn.signalMessage()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
close(abortCopyChan)
|
|
||||||
<-senderDoneChan
|
|
||||||
|
|
||||||
if copyErr == io.EOF || pgErr != nil {
|
err = pgConn.conn.SetReadDeadline(time.Time{})
|
||||||
|
if err != nil {
|
||||||
|
pgConn.asyncClose()
|
||||||
|
return CommandTag{}, err
|
||||||
|
}
|
||||||
|
nonblocking = false
|
||||||
|
|
||||||
|
if readErr == io.EOF || pgErr != nil {
|
||||||
pgConn.frontend.Send(&pgproto3.CopyDone{})
|
pgConn.frontend.Send(&pgproto3.CopyDone{})
|
||||||
} else {
|
} else {
|
||||||
pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
|
pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()})
|
||||||
}
|
}
|
||||||
err = pgConn.frontend.Flush()
|
err = pgConn.frontend.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in New Issue