mirror of https://github.com/jackc/pgx.git
Add non-blocking IO
This eliminates an edge case that can cause a deadlock and is a prerequisite to cheaply testing connection liveness and to recoving a connection after a timeout. https://github.com/jackc/pgconn/issues/27 Squashed commit of the following: commitpull/1281/head0d7b0dddea
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat Jun 25 13:15:05 2022 -0500 Add test for non-blocking IO preventing deadlock commit79d68d23d3
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat Jun 18 18:23:24 2022 -0500 Release CopyFrom buf when done commit95a43139c7
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat Jun 18 18:22:32 2022 -0500 Avoid allocations with non-blocking write commit6b63ceee07
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat Jun 18 17:46:49 2022 -0500 Simplify iobufpool usage commit60ecdda02e
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat Jun 18 11:51:59 2022 -0500 Add true non-blocking IO commit7dd26a34a1
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat Jun 4 20:28:23 2022 -0500 Fix block when reading more than buffered commitafa702213f
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat Jun 4 20:10:23 2022 -0500 More TLS support commit51655bf8f4
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat Jun 4 17:46:00 2022 -0500 Steps toward TLS commit2b80beb1ed
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat Jun 4 13:06:29 2022 -0500 Litle more TLS support commit765b2c6e7b
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat Jun 4 12:29:30 2022 -0500 Add testing of TLS commit5b64432afb
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat Jun 4 09:48:19 2022 -0500 Introduce testVariants in prep for TLS commitecebd7b103
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat Jun 4 09:32:14 2022 -0500 Handle and test read of previously buffered data commit09c64d8cf3
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat Jun 4 09:04:48 2022 -0500 Rename nbbconn to nbconn commit73398bc67a
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat Jun 4 08:59:53 2022 -0500 Remove backup files commitf1df39a29d
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat Jun 4 08:58:05 2022 -0500 Initial passing tests commitea3cdab234
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat Jun 4 08:38:57 2022 -0500 Fix connect timeout commitca22396789
Author: Jack Christensen <jack@jackchristensen.com> Date: Thu Jun 2 19:32:55 2022 -0500 wip commit2e7b46d5d7
Author: Jack Christensen <jack@jackchristensen.com> Date: Mon May 30 08:32:43 2022 -0500 Update comments commit7d04dc5caa
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat May 28 19:43:23 2022 -0500 Fix broken test commitbf1edc77d7
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat May 28 19:40:33 2022 -0500 fixed putting wrong size bufs commit1f7a855b2e
Author: Jack Christensen <jack@jackchristensen.com> Date: Sat May 28 18:13:47 2022 -0500 initial not quite working non-blocking conn
parent
c0a4d1b9ce
commit
811d855a35
internal
pgconn
|
@ -14,26 +14,16 @@ func init() {
|
|||
}
|
||||
}
|
||||
|
||||
// Get gets a []byte with len >= size and len <= size*2.
|
||||
// Get gets a []byte of len size with cap <= size*2.
|
||||
func Get(size int) []byte {
|
||||
i := poolIdx(size)
|
||||
i := getPoolIdx(size)
|
||||
if i >= len(pools) {
|
||||
return make([]byte, size)
|
||||
}
|
||||
return pools[i].Get().([]byte)
|
||||
return pools[i].Get().([]byte)[:size]
|
||||
}
|
||||
|
||||
// Put returns buf to the pool.
|
||||
func Put(buf []byte) {
|
||||
i := poolIdx(len(buf))
|
||||
if i >= len(pools) {
|
||||
return
|
||||
}
|
||||
|
||||
pools[i].Put(buf)
|
||||
}
|
||||
|
||||
func poolIdx(size int) int {
|
||||
func getPoolIdx(size int) int {
|
||||
size--
|
||||
size >>= minPoolExpOf2
|
||||
i := 0
|
||||
|
@ -44,3 +34,24 @@ func poolIdx(size int) int {
|
|||
|
||||
return i
|
||||
}
|
||||
|
||||
// Put returns buf to the pool.
|
||||
func Put(buf []byte) {
|
||||
i := putPoolIdx(cap(buf))
|
||||
if i < 0 {
|
||||
return
|
||||
}
|
||||
|
||||
pools[i].Put(buf)
|
||||
}
|
||||
|
||||
func putPoolIdx(size int) int {
|
||||
minPoolSize := 1 << minPoolExpOf2
|
||||
for i := range pools {
|
||||
if size == minPoolSize<<i {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ func TestPoolIdx(t *testing.T) {
|
|||
{size: 8388609, expected: 16},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
idx := poolIdx(tt.size)
|
||||
idx := getPoolIdx(tt.size)
|
||||
assert.Equalf(t, tt.expected, idx, "size: %d", tt.size)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,31 +5,74 @@ import (
|
|||
|
||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
func TestGetCap(t *testing.T) {
|
||||
tests := []struct {
|
||||
requestedLen int
|
||||
expectedLen int
|
||||
expectedCap int
|
||||
}{
|
||||
{requestedLen: 0, expectedLen: 256},
|
||||
{requestedLen: 128, expectedLen: 256},
|
||||
{requestedLen: 255, expectedLen: 256},
|
||||
{requestedLen: 256, expectedLen: 256},
|
||||
{requestedLen: 257, expectedLen: 512},
|
||||
{requestedLen: 511, expectedLen: 512},
|
||||
{requestedLen: 512, expectedLen: 512},
|
||||
{requestedLen: 513, expectedLen: 1024},
|
||||
{requestedLen: 1023, expectedLen: 1024},
|
||||
{requestedLen: 1024, expectedLen: 1024},
|
||||
{requestedLen: 33554431, expectedLen: 33554432},
|
||||
{requestedLen: 33554432, expectedLen: 33554432},
|
||||
{requestedLen: 0, expectedCap: 256},
|
||||
{requestedLen: 128, expectedCap: 256},
|
||||
{requestedLen: 255, expectedCap: 256},
|
||||
{requestedLen: 256, expectedCap: 256},
|
||||
{requestedLen: 257, expectedCap: 512},
|
||||
{requestedLen: 511, expectedCap: 512},
|
||||
{requestedLen: 512, expectedCap: 512},
|
||||
{requestedLen: 513, expectedCap: 1024},
|
||||
{requestedLen: 1023, expectedCap: 1024},
|
||||
{requestedLen: 1024, expectedCap: 1024},
|
||||
{requestedLen: 33554431, expectedCap: 33554432},
|
||||
{requestedLen: 33554432, expectedCap: 33554432},
|
||||
|
||||
// Above 32 MiB skip the pool and allocate exactly the requested size.
|
||||
{requestedLen: 33554433, expectedLen: 33554433},
|
||||
{requestedLen: 33554433, expectedCap: 33554433},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
buf := iobufpool.Get(tt.requestedLen)
|
||||
assert.Equalf(t, tt.expectedLen, len(buf), "requestedLen: %d", tt.requestedLen)
|
||||
assert.Equalf(t, tt.requestedLen, len(buf), "bad len for requestedLen: %d", len(buf), tt.requestedLen)
|
||||
assert.Equalf(t, tt.expectedCap, cap(buf), "bad cap for requestedLen: %d", tt.requestedLen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutHandlesWrongSizedBuffers(t *testing.T) {
|
||||
for putBufSize := range []int{0, 1, 128, 250, 256, 257, 1023, 1024, 1025, 1 << 28} {
|
||||
putBuf := make([]byte, putBufSize)
|
||||
iobufpool.Put(putBuf)
|
||||
|
||||
tests := []struct {
|
||||
requestedLen int
|
||||
expectedCap int
|
||||
}{
|
||||
{requestedLen: 0, expectedCap: 256},
|
||||
{requestedLen: 128, expectedCap: 256},
|
||||
{requestedLen: 255, expectedCap: 256},
|
||||
{requestedLen: 256, expectedCap: 256},
|
||||
{requestedLen: 257, expectedCap: 512},
|
||||
{requestedLen: 511, expectedCap: 512},
|
||||
{requestedLen: 512, expectedCap: 512},
|
||||
{requestedLen: 513, expectedCap: 1024},
|
||||
{requestedLen: 1023, expectedCap: 1024},
|
||||
{requestedLen: 1024, expectedCap: 1024},
|
||||
{requestedLen: 33554431, expectedCap: 33554432},
|
||||
{requestedLen: 33554432, expectedCap: 33554432},
|
||||
|
||||
// Above 32 MiB skip the pool and allocate exactly the requested size.
|
||||
{requestedLen: 33554433, expectedCap: 33554433},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
getBuf := iobufpool.Get(tt.requestedLen)
|
||||
assert.Equalf(t, tt.requestedLen, len(getBuf), "len(putBuf): %d, requestedLen: %d", len(putBuf), tt.requestedLen)
|
||||
assert.Equalf(t, tt.expectedCap, cap(getBuf), "cap(putBuf): %d, requestedLen: %d", cap(putBuf), tt.requestedLen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutGetBufferReuse(t *testing.T) {
|
||||
buf := iobufpool.Get(4)
|
||||
buf[0] = 1
|
||||
iobufpool.Put(buf)
|
||||
buf = iobufpool.Get(4)
|
||||
require.Equal(t, byte(1), buf[0])
|
||||
}
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
package nbconn
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
const minBufferQueueLen = 8
|
||||
|
||||
type bufferQueue struct {
|
||||
lock sync.Mutex
|
||||
queue [][]byte
|
||||
r, w int
|
||||
}
|
||||
|
||||
func (bq *bufferQueue) pushBack(buf []byte) {
|
||||
bq.lock.Lock()
|
||||
defer bq.lock.Unlock()
|
||||
|
||||
if bq.w >= len(bq.queue) {
|
||||
bq.growQueue()
|
||||
}
|
||||
bq.queue[bq.w] = buf
|
||||
bq.w++
|
||||
}
|
||||
|
||||
func (bq *bufferQueue) pushFront(buf []byte) {
|
||||
bq.lock.Lock()
|
||||
defer bq.lock.Unlock()
|
||||
|
||||
if bq.w >= len(bq.queue) {
|
||||
bq.growQueue()
|
||||
}
|
||||
copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w])
|
||||
bq.queue[bq.r] = buf
|
||||
bq.w++
|
||||
}
|
||||
|
||||
func (bq *bufferQueue) popFront() []byte {
|
||||
bq.lock.Lock()
|
||||
defer bq.lock.Unlock()
|
||||
|
||||
if bq.r == bq.w {
|
||||
return nil
|
||||
}
|
||||
|
||||
buf := bq.queue[bq.r]
|
||||
bq.queue[bq.r] = nil // Clear reference so it can be garbage collected.
|
||||
bq.r++
|
||||
|
||||
if bq.r == bq.w {
|
||||
bq.r = 0
|
||||
bq.w = 0
|
||||
if len(bq.queue) > minBufferQueueLen {
|
||||
bq.queue = make([][]byte, minBufferQueueLen)
|
||||
}
|
||||
}
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
func (bq *bufferQueue) growQueue() {
|
||||
desiredLen := (len(bq.queue) + 1) * 3 / 2
|
||||
if desiredLen < minBufferQueueLen {
|
||||
desiredLen = minBufferQueueLen
|
||||
}
|
||||
|
||||
newQueue := make([][]byte, desiredLen)
|
||||
copy(newQueue, bq.queue)
|
||||
bq.queue = newQueue
|
||||
}
|
|
@ -0,0 +1,513 @@
|
|||
// Package nbconn implements a non-blocking net.Conn wrapper.
|
||||
//
|
||||
// It is designed to solve three problems.
|
||||
//
|
||||
// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all
|
||||
// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion.
|
||||
//
|
||||
// The second is the inability to use a write deadline with a TLS.Conn without killing the connection.
|
||||
//
|
||||
// The third is to efficiently check if a connection has been closed via a non-blocking read.
|
||||
package nbconn
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||
)
|
||||
|
||||
var errClosed = errors.New("closed")
|
||||
var ErrWouldBlock = new(wouldBlockError)
|
||||
|
||||
const fakeNonblockingWaitDuration = 100 * time.Millisecond
|
||||
|
||||
// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read
|
||||
// mode.
|
||||
var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC)
|
||||
|
||||
// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to
|
||||
// ignore all future calls.
|
||||
var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC)
|
||||
|
||||
// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error.
|
||||
type wouldBlockError struct{}
|
||||
|
||||
func (*wouldBlockError) Error() string {
|
||||
return "would block"
|
||||
}
|
||||
|
||||
func (*wouldBlockError) Timeout() bool { return true }
|
||||
func (*wouldBlockError) Temporary() bool { return true }
|
||||
|
||||
// Conn is a net.Conn where Write never blocks and always succeeds. Flush must be called to actually write to the
|
||||
// underlying connection.
|
||||
type Conn interface {
|
||||
net.Conn
|
||||
Flush() error
|
||||
}
|
||||
|
||||
// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn.
|
||||
type NetConn struct {
|
||||
conn net.Conn
|
||||
rawConn syscall.RawConn
|
||||
|
||||
readQueue bufferQueue
|
||||
writeQueue bufferQueue
|
||||
|
||||
readFlushLock sync.Mutex
|
||||
// non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the
|
||||
// callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations.
|
||||
nonblockWriteBuf []byte
|
||||
nonblockWriteErr error
|
||||
nonblockWriteN int
|
||||
|
||||
readDeadlineLock sync.Mutex
|
||||
readDeadline time.Time
|
||||
readNonblocking bool
|
||||
|
||||
writeDeadlineLock sync.Mutex
|
||||
writeDeadline time.Time
|
||||
|
||||
// Only access with atomics
|
||||
closed int64 // 0 = not closed, 1 = closed
|
||||
}
|
||||
|
||||
func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn {
|
||||
nc := &NetConn{
|
||||
conn: conn,
|
||||
}
|
||||
|
||||
if !fakeNonBlockingIO {
|
||||
if sc, ok := conn.(syscall.Conn); ok {
|
||||
if rawConn, err := sc.SyscallConn(); err == nil {
|
||||
nc.rawConn = rawConn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nc
|
||||
}
|
||||
|
||||
// Read implements io.Reader.
|
||||
func (c *NetConn) Read(b []byte) (n int, err error) {
|
||||
if c.isClosed() {
|
||||
return 0, errClosed
|
||||
}
|
||||
|
||||
c.readFlushLock.Lock()
|
||||
defer c.readFlushLock.Unlock()
|
||||
|
||||
err = c.flush()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for n < len(b) {
|
||||
buf := c.readQueue.popFront()
|
||||
if buf == nil {
|
||||
break
|
||||
}
|
||||
copiedN := copy(b[n:], buf)
|
||||
if copiedN < len(buf) {
|
||||
buf = buf[copiedN:]
|
||||
c.readQueue.pushFront(buf)
|
||||
} else {
|
||||
iobufpool.Put(buf)
|
||||
}
|
||||
n += copiedN
|
||||
}
|
||||
|
||||
// If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to
|
||||
// Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block.
|
||||
if n > 0 {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
var readNonblocking bool
|
||||
c.readDeadlineLock.Lock()
|
||||
readNonblocking = c.readNonblocking
|
||||
c.readDeadlineLock.Unlock()
|
||||
|
||||
var readN int
|
||||
if readNonblocking {
|
||||
readN, err = c.nonblockingRead(b[n:])
|
||||
} else {
|
||||
readN, err = c.conn.Read(b[n:])
|
||||
}
|
||||
n += readN
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is
|
||||
// closed. Call Flush to actually write to the underlying connection.
|
||||
func (c *NetConn) Write(b []byte) (n int, err error) {
|
||||
if c.isClosed() {
|
||||
return 0, errClosed
|
||||
}
|
||||
|
||||
buf := iobufpool.Get(len(b))
|
||||
copy(buf, b)
|
||||
c.writeQueue.pushBack(buf)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *NetConn) Close() (err error) {
|
||||
swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1)
|
||||
if !swapped {
|
||||
return errClosed
|
||||
}
|
||||
|
||||
defer func() {
|
||||
closeErr := c.conn.Close()
|
||||
if err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
|
||||
c.readFlushLock.Lock()
|
||||
defer c.readFlushLock.Unlock()
|
||||
err = c.flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *NetConn) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *NetConn) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t).
|
||||
func (c *NetConn) SetDeadline(t time.Time) error {
|
||||
err := c.SetReadDeadline(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking.
|
||||
func (c *NetConn) SetReadDeadline(t time.Time) error {
|
||||
if c.isClosed() {
|
||||
return errClosed
|
||||
}
|
||||
|
||||
c.readDeadlineLock.Lock()
|
||||
defer c.readDeadlineLock.Unlock()
|
||||
if c.readDeadline == disableSetDeadlineDeadline {
|
||||
return nil
|
||||
}
|
||||
if t == disableSetDeadlineDeadline {
|
||||
c.readDeadline = t
|
||||
return nil
|
||||
}
|
||||
|
||||
if t == NonBlockingDeadline {
|
||||
c.readNonblocking = true
|
||||
t = time.Time{}
|
||||
} else {
|
||||
c.readNonblocking = false
|
||||
}
|
||||
|
||||
c.readDeadline = t
|
||||
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *NetConn) SetWriteDeadline(t time.Time) error {
|
||||
if c.isClosed() {
|
||||
return errClosed
|
||||
}
|
||||
|
||||
c.writeDeadlineLock.Lock()
|
||||
defer c.writeDeadlineLock.Unlock()
|
||||
if c.writeDeadline == disableSetDeadlineDeadline {
|
||||
return nil
|
||||
}
|
||||
if t == disableSetDeadlineDeadline {
|
||||
c.writeDeadline = t
|
||||
return nil
|
||||
}
|
||||
|
||||
c.writeDeadline = t
|
||||
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *NetConn) Flush() error {
|
||||
if c.isClosed() {
|
||||
return errClosed
|
||||
}
|
||||
|
||||
c.readFlushLock.Lock()
|
||||
defer c.readFlushLock.Unlock()
|
||||
return c.flush()
|
||||
}
|
||||
|
||||
// flush does the actual work of flushing the writeQueue. readFlushLock must already be held.
|
||||
func (c *NetConn) flush() error {
|
||||
var stopChan chan struct{}
|
||||
var errChan chan error
|
||||
|
||||
defer func() {
|
||||
if stopChan != nil {
|
||||
select {
|
||||
case stopChan <- struct{}{}:
|
||||
case <-errChan:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() {
|
||||
remainingBuf := buf
|
||||
for len(remainingBuf) > 0 {
|
||||
n, err := c.nonblockingWrite(remainingBuf)
|
||||
remainingBuf = remainingBuf[n:]
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrWouldBlock) {
|
||||
buf = buf[:len(remainingBuf)]
|
||||
copy(buf, remainingBuf)
|
||||
c.writeQueue.pushFront(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// Writing was blocked. Reading might unblock it.
|
||||
if stopChan == nil {
|
||||
stopChan, errChan = c.bufferNonblockingRead()
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
stopChan = nil
|
||||
return err
|
||||
default:
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
iobufpool.Put(buf)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) {
|
||||
stopChan = make(chan struct{})
|
||||
errChan = make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
buf := iobufpool.Get(8 * 1024)
|
||||
n, err := c.nonblockingRead(buf)
|
||||
if n > 0 {
|
||||
buf = buf[:n]
|
||||
c.readQueue.pushBack(buf)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrWouldBlock) {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-stopChan:
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return stopChan, errChan
|
||||
}
|
||||
|
||||
func (c *NetConn) isClosed() bool {
|
||||
closed := atomic.LoadInt64(&c.closed)
|
||||
return closed == 1
|
||||
}
|
||||
|
||||
func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) {
|
||||
if c.rawConn == nil {
|
||||
return c.fakeNonblockingWrite(b)
|
||||
} else {
|
||||
return c.realNonblockingWrite(b)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
|
||||
c.writeDeadlineLock.Lock()
|
||||
defer c.writeDeadlineLock.Unlock()
|
||||
|
||||
deadline := time.Now().Add(fakeNonblockingWaitDuration)
|
||||
if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) {
|
||||
err = c.conn.SetWriteDeadline(deadline)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
|
||||
c.conn.SetWriteDeadline(c.writeDeadline)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
err = ErrWouldBlock
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return c.conn.Write(b)
|
||||
}
|
||||
|
||||
// realNonblockingWrite does a non-blocking write. readFlushLock must already be held.
|
||||
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
|
||||
c.nonblockWriteBuf = b
|
||||
c.nonblockWriteN = 0
|
||||
c.nonblockWriteErr = nil
|
||||
err = c.rawConn.Write(func(fd uintptr) (done bool) {
|
||||
c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf)
|
||||
return true
|
||||
})
|
||||
n = c.nonblockWriteN
|
||||
if err == nil && c.nonblockWriteErr != nil {
|
||||
if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) {
|
||||
err = ErrWouldBlock
|
||||
} else {
|
||||
err = c.nonblockWriteErr
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// n may be -1 when an error occurs.
|
||||
if n < 0 {
|
||||
n = 0
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *NetConn) nonblockingRead(b []byte) (n int, err error) {
|
||||
if c.rawConn == nil {
|
||||
return c.fakeNonblockingRead(b)
|
||||
} else {
|
||||
return c.realNonblockingRead(b)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) {
|
||||
c.readDeadlineLock.Lock()
|
||||
defer c.readDeadlineLock.Unlock()
|
||||
|
||||
deadline := time.Now().Add(fakeNonblockingWaitDuration)
|
||||
if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) {
|
||||
err = c.conn.SetReadDeadline(deadline)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
|
||||
c.conn.SetReadDeadline(c.readDeadline)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
err = ErrWouldBlock
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return c.conn.Read(b)
|
||||
}
|
||||
|
||||
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
|
||||
var funcErr error
|
||||
err = c.rawConn.Read(func(fd uintptr) (done bool) {
|
||||
n, funcErr = syscall.Read(int(fd), b)
|
||||
return true
|
||||
})
|
||||
if err == nil && funcErr != nil {
|
||||
if errors.Is(funcErr, syscall.EWOULDBLOCK) {
|
||||
err = ErrWouldBlock
|
||||
} else {
|
||||
err = funcErr
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// n may be -1 when an error occurs.
|
||||
if n < 0 {
|
||||
n = 0
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// syscall.Conn is interface
|
||||
|
||||
// TLSClient establishes a TLS connection as a client over conn using config.
|
||||
//
|
||||
// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby
|
||||
// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the
|
||||
// *TLSConn is returned.
|
||||
func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) {
|
||||
tc := tls.Client(conn, config)
|
||||
err := tc.Handshake()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure last written part of Handshake is actually sent.
|
||||
err = conn.Flush()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TLSConn{
|
||||
tlsConn: tc,
|
||||
nbConn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a
|
||||
// tls.Conn.
|
||||
type TLSConn struct {
|
||||
tlsConn *tls.Conn
|
||||
nbConn *NetConn
|
||||
}
|
||||
|
||||
func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) }
|
||||
func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) }
|
||||
func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() }
|
||||
func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() }
|
||||
func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() }
|
||||
|
||||
func (tc *TLSConn) Close() error {
|
||||
// tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then
|
||||
// sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our
|
||||
// own 5 second deadline then make all set deadlines no-op.
|
||||
tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5))
|
||||
tc.tlsConn.SetDeadline(disableSetDeadlineDeadline)
|
||||
|
||||
return tc.tlsConn.Close()
|
||||
}
|
||||
|
||||
func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) }
|
||||
func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) }
|
||||
func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) }
|
|
@ -0,0 +1,554 @@
|
|||
package nbconn_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/nbconn"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test keys generated with:
|
||||
//
|
||||
// $ openssl req -x509 -newkey rsa:2048 -keyout key.pem -out cert.pem -sha256 -nodes -days 20000 -subj '/CN=localhost'
|
||||
|
||||
var testTLSPublicKey = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIICpjCCAY4CCQCjQKYdUDQzKDANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls
|
||||
b2NhbGhvc3QwIBcNMjIwNjA0MTY1MzE2WhgPMjA3NzAzMDcxNjUzMTZaMBQxEjAQ
|
||||
BgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
|
||||
ALHbOu80cfSPufKTZsKf3E5rCXHeIHjaIbgHEXA2SW/n77U8oZX518s+27FO0sK5
|
||||
yA0WnEIwY34PU359sNR5KelARGnaeh3HdaGm1nuyyxBtwwAqIuM0UxGAMF/mQ4lT
|
||||
caZPxG+7WlYDqnE3eVXUtG4c+T7t5qKAB3MtfbzKFSjczkWkroi6cTypmHArGghT
|
||||
0VWWVu0s9oNp5q8iWchY2o9f0aIjmKv6FgtilO+geev+4U+QvtvrziR5BO3/3EgW
|
||||
c5TUVcf+lwkvp8ziXvargmjjnNTyeF37y4KpFcex0v7z7hSrUK4zU0+xRn7Bp17v
|
||||
7gzj0xN+HCsUW1cjPFNezX0CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAbEBzewzg
|
||||
Z5F+BqMSxP3HkMCkLLH0N9q0/DkZaVyZ38vrjcjaDYuabq28kA2d5dc5jxsQpvTw
|
||||
HTGqSv1ZxJP3pBFv6jLSh8xaM6tUkk482Q6DnZGh97CD4yup/yJzkn5nv9OHtZ9g
|
||||
TnaQeeXgOz0o5Zq9IpzHJb19ysya3UCIK8oKXbSO4Qd168seCq75V2BFHDpmejjk
|
||||
D92eT6WODlzzvZbhzA1F3/cUilZdhbQtJMqdecKvD+yrBpzGVqzhWQsXwsRAU1fB
|
||||
hShx+D14zUGM2l4wlVzOAuGh4ZL7x3AjJsc86TsCavTspS0Xl51j+mRbiULq7G7Y
|
||||
E7ZYmaKTMOhvkg==
|
||||
-----END CERTIFICATE-----`)
|
||||
|
||||
// The strings.ReplaceAll is used to placate any secret scanners that would squawk if they saw a private key embedded in
|
||||
// source code.
|
||||
var testTLSPrivateKey = []byte(strings.ReplaceAll(`-----BEGIN TESTING KEY-----
|
||||
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCx2zrvNHH0j7ny
|
||||
k2bCn9xOawlx3iB42iG4BxFwNklv5++1PKGV+dfLPtuxTtLCucgNFpxCMGN+D1N+
|
||||
fbDUeSnpQERp2nodx3WhptZ7sssQbcMAKiLjNFMRgDBf5kOJU3GmT8Rvu1pWA6px
|
||||
N3lV1LRuHPk+7eaigAdzLX28yhUo3M5FpK6IunE8qZhwKxoIU9FVllbtLPaDaeav
|
||||
IlnIWNqPX9GiI5ir+hYLYpTvoHnr/uFPkL7b684keQTt/9xIFnOU1FXH/pcJL6fM
|
||||
4l72q4Jo45zU8nhd+8uCqRXHsdL+8+4Uq1CuM1NPsUZ+wade7+4M49MTfhwrFFtX
|
||||
IzxTXs19AgMBAAECggEBAJcHt5ARVQN8WUbobMawwX/F3QtYuPJnKWMAfYpwTwQ8
|
||||
TI32orCcrObmxeBXMxowcPTMUnzSYmpV0W0EhvimuzRbYr0Qzcoj6nwPFOuN9GpL
|
||||
CuBE58NQV4nw9SM6gfdHaKb17bWDvz5zdnUVym9cZKts5yrNEqDDX5Aq/S8n27gJ
|
||||
/qheXwSxwETVO6kMEW1ndNIWDP8DPQ0E4O//RuMZwxpnZdnjGKkdVNy8I1BpgDgn
|
||||
lwgkE3H3IciASki1GYXoyvrIiRwMQVzvYD2zcgwK9OZSjZe0TGwAGa+eQdbs3A9I
|
||||
Ir1kYn6ZMGMRFJA2XHJW3hMZdWB/t2xMBGy75Uv9sAECgYEA1o+oRUYwwQ1MwBo9
|
||||
YA6c00KjhFgrjdzyKPQrN14Q0dw5ErqRkhp2cs7BRdCDTDrjAegPc3Otg7uMa1vp
|
||||
RgU/C72jwzFLYATvn+RLGRYRyqIE+bQ22/lLnXTrp4DCfdMrqWuQbIYouGHqfQrq
|
||||
MfdtSUpQ6VZCi9zHehXOYwBMvQECgYEA1DTQFpe+tndIFmguxxaBwDltoPh5omzd
|
||||
3vA7iFct2+UYk5W9shfAekAaZk2WufKmmC3OfBWYyIaJ7QwQpuGDS3zwjy6WFMTE
|
||||
Otp2CypFCVahwHcvn2jYHmDMT0k0Pt6X2S3GAyWTyEPv7mAfKR1OWUYi7ZgdXpt0
|
||||
TtL3Z3JyhH0CgYEAwveHUGuXodUUCPvPCZo9pzrGm1wDN8WtxskY/Bbd8dTLh9lA
|
||||
riKdv3Vg6q+un3ZjETht0dsrsKib0HKUZqwdve11AcmpVHcnx4MLOqBzSk4vdzfr
|
||||
IbhGna3A9VRrZyqcYjb75aGDHwjaqwVgCkdrZ03AeEeJ8M2N9cIa6Js9IAECgYBu
|
||||
nlU24cVdspJWc9qml3ntrUITnlMxs1R5KXuvF9rk/OixzmYDV1RTpeTdHWcL6Yyk
|
||||
WYSAtHVfWpq9ggOQKpBZonh3+w3rJ6MvFsBgE5nHQ2ywOrENhQbb1xPJ5NwiRcCc
|
||||
Srsk2srNo3SIK30y3n8AFIqSljABKEIZ8Olc+JDvtQKBgQCiKz43zI6a0HscgZ77
|
||||
DCBduWP4nk8BM7QTFxs9VypjrylMDGGtTKHc5BLA5fNZw97Hb7pcicN7/IbUnQUD
|
||||
pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG
|
||||
2qWm8jTPeDC3sq+67s2oojHf+Q==
|
||||
-----END TESTING KEY-----`, "TESTING KEY", "PRIVATE KEY"))
|
||||
|
||||
func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
makeConns func(t *testing.T) (local, remote net.Conn)
|
||||
useTLS bool
|
||||
fakeNonBlockingIO bool
|
||||
}{
|
||||
{
|
||||
name: "Pipe",
|
||||
makeConns: makePipeConns,
|
||||
useTLS: false,
|
||||
fakeNonBlockingIO: true,
|
||||
},
|
||||
{
|
||||
name: "TCP with Fake Non-blocking IO",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: false,
|
||||
fakeNonBlockingIO: true,
|
||||
},
|
||||
{
|
||||
name: "TLS over TCP with Fake Non-blocking IO",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: true,
|
||||
fakeNonBlockingIO: true,
|
||||
},
|
||||
{
|
||||
name: "TCP with Real Non-blocking IO",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: false,
|
||||
fakeNonBlockingIO: false,
|
||||
},
|
||||
{
|
||||
name: "TLS over TCP with Real Non-blocking IO",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: true,
|
||||
fakeNonBlockingIO: false,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
local, remote := tt.makeConns(t)
|
||||
|
||||
// Just to be sure both ends get closed. Also, it retains a reference so one side of the connection doesn't get
|
||||
// garbage collected. This could happen when a test is testing against a non-responsive remote. Since it never
|
||||
// uses remote it may be garbage collected leading to the connection being closed.
|
||||
defer local.Close()
|
||||
defer remote.Close()
|
||||
|
||||
var conn nbconn.Conn
|
||||
netConn := nbconn.NewNetConn(local, tt.fakeNonBlockingIO)
|
||||
|
||||
if tt.useTLS {
|
||||
cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
tlsServer := tls.Server(remote, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
serverTLSHandshakeChan := make(chan error)
|
||||
go func() {
|
||||
err := tlsServer.Handshake()
|
||||
serverTLSHandshakeChan <- err
|
||||
}()
|
||||
|
||||
tlsConn, err := nbconn.TLSClient(netConn, &tls.Config{InsecureSkipVerify: true})
|
||||
require.NoError(t, err)
|
||||
conn = tlsConn
|
||||
|
||||
err = <-serverTLSHandshakeChan
|
||||
require.NoError(t, err)
|
||||
remote = tlsServer
|
||||
} else {
|
||||
conn = netConn
|
||||
}
|
||||
|
||||
f(t, conn, remote)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// makePipeConns returns a connected pair of net.Conns created with net.Pipe(). It is entirely synchronous so it is
|
||||
// useful for testing an exact sequence of reads and writes with the underlying connection blocking.
|
||||
func makePipeConns(t *testing.T) (local, remote net.Conn) {
|
||||
local, remote = net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
local.Close()
|
||||
remote.Close()
|
||||
})
|
||||
|
||||
return local, remote
|
||||
}
|
||||
|
||||
// makeTCPConns returns a connected pair of net.Conns running over TCP on localhost.
|
||||
func makeTCPConns(t *testing.T) (local, remote net.Conn) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
type acceptResultT struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
acceptChan := make(chan acceptResultT)
|
||||
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
acceptChan <- acceptResultT{conn: conn, err: err}
|
||||
}()
|
||||
|
||||
local, err = net.Dial("tcp", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
acceptResult := <-acceptChan
|
||||
require.NoError(t, acceptResult.err)
|
||||
|
||||
remote = acceptResult.conn
|
||||
|
||||
return local, remote
|
||||
}
|
||||
|
||||
func TestWriteIsBuffered(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
// net.Pipe is synchronous so the Write would block if not buffered.
|
||||
writeBuf := []byte("test")
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := conn.Flush()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
readBuf := make([]byte, len(writeBuf))
|
||||
_, err = remote.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, <-errChan)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
err := conn.SetWriteDeadline(time.Now())
|
||||
require.NoError(t, err)
|
||||
|
||||
writeBuf := []byte("test")
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadFlushesWriteBuffer(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := []byte("test")
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go func() {
|
||||
readBuf := make([]byte, len(writeBuf))
|
||||
_, err := remote.Read(readBuf)
|
||||
errChan <- err
|
||||
|
||||
_, err = remote.Write([]byte("okay"))
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("okay"), readBuf)
|
||||
|
||||
require.NoError(t, <-errChan)
|
||||
require.NoError(t, <-errChan)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCloseFlushesWriteBuffer(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := []byte("test")
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
readBuf := make([]byte, len(writeBuf))
|
||||
_, err := remote.Read(readBuf)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, <-errChan)
|
||||
})
|
||||
}
|
||||
|
||||
// This test exercises the non-blocking write path. Because writes are buffered it is difficult trigger this with
|
||||
// certainty and visibility. So this test tries to trigger what would otherwise be a deadlock by both sides writing
|
||||
// large values.
|
||||
func TestInternalNonBlockingWrite(t *testing.T) {
|
||||
const deadlockSize = 4 * 1024 * 1024
|
||||
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := make([]byte, deadlockSize)
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, deadlockSize, n)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
remoteWriteBuf := make([]byte, deadlockSize)
|
||||
_, err := remote.Write(remoteWriteBuf)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
readBuf := make([]byte, deadlockSize)
|
||||
_, err = io.ReadFull(remote, readBuf)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
readBuf := make([]byte, deadlockSize)
|
||||
_, err = conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, <-errChan)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInternalNonBlockingWriteWithDeadline(t *testing.T) {
|
||||
const deadlockSize = 4 * 1024 * 1024
|
||||
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := make([]byte, deadlockSize)
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, deadlockSize, n)
|
||||
|
||||
err = conn.SetDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Flush()
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNonBlockingRead(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
err := conn.SetReadDeadline(nbconn.NonBlockingDeadline)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 4)
|
||||
n, err := conn.Read(buf)
|
||||
require.ErrorIs(t, err, nbconn.ErrWouldBlock)
|
||||
require.EqualValues(t, 0, n)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := remote.Write([]byte("okay"))
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
err = conn.SetReadDeadline(time.Time{})
|
||||
require.NoError(t, err)
|
||||
|
||||
n, err = conn.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadPreviouslyBuffered(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := func() error {
|
||||
_, err := remote.Write([]byte("alpha"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = remote.Read(readBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
_, err := conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||
err = conn.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
readBuf := make([]byte, 5)
|
||||
n, err := conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 5, n)
|
||||
require.Equal(t, []byte("alpha"), readBuf)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadMoreThanPreviouslyBufferedDoesNotBlock(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := func() error {
|
||||
_, err := remote.Write([]byte("alpha"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = remote.Read(readBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
_, err := conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||
err = conn.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
readBuf := make([]byte, 10)
|
||||
n, err := conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 5, n)
|
||||
require.Equal(t, []byte("alpha"), readBuf[:n])
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadPreviouslyBufferedPartialRead(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := func() error {
|
||||
_, err := remote.Write([]byte("alpha"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = remote.Read(readBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
_, err := conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||
err = conn.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
readBuf := make([]byte, 2)
|
||||
n, err := conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, n)
|
||||
require.Equal(t, []byte("al"), readBuf)
|
||||
|
||||
readBuf = make([]byte, 3)
|
||||
n, err = conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 3, n)
|
||||
require.Equal(t, []byte("pha"), readBuf)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadMultiplePreviouslyBuffered(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := func() error {
|
||||
_, err := remote.Write([]byte("alpha"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = remote.Write([]byte("beta"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = remote.Read(readBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
_, err := conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||
err = conn.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
readBuf := make([]byte, 9)
|
||||
n, err := io.ReadFull(conn, readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 9, n)
|
||||
require.Equal(t, []byte("alphabeta"), readBuf)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadPreviouslyBufferedAndReadMore(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
|
||||
flushCompleteChan := make(chan struct{})
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := func() error {
|
||||
_, err := remote.Write([]byte("alpha"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = remote.Read(readBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
<-flushCompleteChan
|
||||
|
||||
_, err = remote.Write([]byte("beta"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
_, err := conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||
err = conn.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
close(flushCompleteChan)
|
||||
|
||||
readBuf := make([]byte, 9)
|
||||
|
||||
n, err := io.ReadFull(conn, readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 9, n)
|
||||
require.Equal(t, []byte("alphabeta"), readBuf)
|
||||
|
||||
err = <-errChan
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
188
pgconn/pgconn.go
188
pgconn/pgconn.go
|
@ -13,9 +13,10 @@ import (
|
|||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||
"github.com/jackc/pgx/v5/internal/nbconn"
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
|
@ -75,11 +76,6 @@ type PgConn struct {
|
|||
|
||||
status byte // One of connStatus* constants
|
||||
|
||||
bufferingReceive bool
|
||||
bufferingReceiveMux sync.Mutex
|
||||
bufferingReceiveMsg pgproto3.BackendMessage
|
||||
bufferingReceiveErr error
|
||||
|
||||
peekedMsg pgproto3.BackendMessage
|
||||
|
||||
// Reusable / preallocated resources
|
||||
|
@ -234,13 +230,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
}
|
||||
return nil, &connectError{config: config, msg: "dial error", err: err}
|
||||
}
|
||||
netConn = nbconn.NewNetConn(netConn, false)
|
||||
|
||||
pgConn.conn = netConn
|
||||
pgConn.contextWatcher = newContextWatcher(netConn)
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
|
||||
if fallbackConfig.TLSConfig != nil {
|
||||
tlsConn, err := startTLS(netConn, fallbackConfig.TLSConfig)
|
||||
tlsConn, err := startTLS(netConn.(*nbconn.NetConn), fallbackConfig.TLSConfig)
|
||||
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
|
||||
if err != nil {
|
||||
netConn.Close()
|
||||
|
@ -356,7 +353,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {
|
|||
)
|
||||
}
|
||||
|
||||
func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
|
||||
func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (net.Conn, error) {
|
||||
err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -371,7 +368,12 @@ func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
|
|||
return nil, errors.New("server refused TLS connection")
|
||||
}
|
||||
|
||||
return tls.Client(conn, tlsConfig), nil
|
||||
tlsConn, err := nbconn.TLSClient(conn, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
|
||||
|
@ -385,24 +387,6 @@ func hexMD5(s string) string {
|
|||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
func (pgConn *PgConn) signalMessage() chan struct{} {
|
||||
if pgConn.bufferingReceive {
|
||||
panic("BUG: signalMessage when already in progress")
|
||||
}
|
||||
|
||||
pgConn.bufferingReceive = true
|
||||
pgConn.bufferingReceiveMux.Lock()
|
||||
|
||||
ch := make(chan struct{})
|
||||
go func() {
|
||||
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
|
||||
pgConn.bufferingReceiveMux.Unlock()
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
|
||||
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
|
||||
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
|
||||
|
@ -442,25 +426,13 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
|
|||
return pgConn.peekedMsg, nil
|
||||
}
|
||||
|
||||
var msg pgproto3.BackendMessage
|
||||
var err error
|
||||
if pgConn.bufferingReceive {
|
||||
pgConn.bufferingReceiveMux.Lock()
|
||||
msg = pgConn.bufferingReceiveMsg
|
||||
err = pgConn.bufferingReceiveErr
|
||||
pgConn.bufferingReceiveMux.Unlock()
|
||||
pgConn.bufferingReceive = false
|
||||
|
||||
// If a timeout error happened in the background try the read again.
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
msg, err = pgConn.frontend.Receive()
|
||||
}
|
||||
} else {
|
||||
msg, err = pgConn.frontend.Receive()
|
||||
}
|
||||
msg, err := pgConn.frontend.Receive()
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, nbconn.ErrWouldBlock) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Close on anything other than timeout error - everything else is fatal
|
||||
var netErr net.Error
|
||||
isNetErr := errors.As(err, &netErr)
|
||||
|
@ -479,13 +451,6 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
|
|||
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
|
||||
msg, err := pgConn.peekMessage()
|
||||
if err != nil {
|
||||
// Close on anything other than timeout error - everything else is fatal
|
||||
var netErr net.Error
|
||||
isNetErr := errors.As(err, &netErr)
|
||||
if !(isNetErr && netErr.Timeout()) {
|
||||
pgConn.asyncClose()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
pgConn.peekedMsg = nil
|
||||
|
@ -1173,62 +1138,58 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
|
||||
// Send copy to command
|
||||
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
|
||||
|
||||
err := pgConn.frontend.Flush()
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
return CommandTag{}, err
|
||||
}
|
||||
|
||||
// Send copy data
|
||||
abortCopyChan := make(chan struct{})
|
||||
copyErrChan := make(chan error, 1)
|
||||
signalMessageChan := pgConn.signalMessage()
|
||||
senderDoneChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(senderDoneChan)
|
||||
|
||||
buf := make([]byte, 0, 65536)
|
||||
buf = append(buf, 'd')
|
||||
sp := len(buf)
|
||||
|
||||
for {
|
||||
n, readErr := r.Read(buf[5:cap(buf)])
|
||||
if n > 0 {
|
||||
buf = buf[0 : n+5]
|
||||
pgio.SetInt32(buf[sp:], int32(n+4))
|
||||
|
||||
writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(buf)
|
||||
if writeErr != nil {
|
||||
// Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine.
|
||||
pgConn.conn.Close()
|
||||
|
||||
copyErrChan <- writeErr
|
||||
return
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
copyErrChan <- readErr
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-abortCopyChan:
|
||||
return
|
||||
default:
|
||||
}
|
||||
err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline)
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
return CommandTag{}, err
|
||||
}
|
||||
nonblocking := true
|
||||
defer func() {
|
||||
if nonblocking {
|
||||
pgConn.conn.SetReadDeadline(time.Time{})
|
||||
}
|
||||
}()
|
||||
|
||||
var pgErr error
|
||||
var copyErr error
|
||||
for copyErr == nil && pgErr == nil {
|
||||
select {
|
||||
case copyErr = <-copyErrChan:
|
||||
case <-signalMessageChan:
|
||||
buf := iobufpool.Get(65536)
|
||||
defer iobufpool.Put(buf)
|
||||
buf[0] = 'd'
|
||||
|
||||
var readErr, pgErr error
|
||||
for pgErr == nil {
|
||||
// Read chunk from r.
|
||||
var n int
|
||||
n, readErr = r.Read(buf[5:cap(buf)])
|
||||
|
||||
// Send chunk to PostgreSQL.
|
||||
if n > 0 {
|
||||
buf = buf[0 : n+5]
|
||||
pgio.SetInt32(buf[1:], int32(n+4))
|
||||
|
||||
writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(buf)
|
||||
if writeErr != nil {
|
||||
pgConn.asyncClose()
|
||||
return CommandTag{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Abort loop if there was a read error.
|
||||
if readErr != nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Read messages until error or none available.
|
||||
for pgErr == nil {
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
if errors.Is(err, nbconn.ErrWouldBlock) {
|
||||
break
|
||||
}
|
||||
pgConn.asyncClose()
|
||||
return CommandTag{}, preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
@ -1236,18 +1197,22 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
switch msg := msg.(type) {
|
||||
case *pgproto3.ErrorResponse:
|
||||
pgErr = ErrorResponseToPgError(msg)
|
||||
default:
|
||||
signalMessageChan = pgConn.signalMessage()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
close(abortCopyChan)
|
||||
<-senderDoneChan
|
||||
|
||||
if copyErr == io.EOF || pgErr != nil {
|
||||
err = pgConn.conn.SetReadDeadline(time.Time{})
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
return CommandTag{}, err
|
||||
}
|
||||
nonblocking = false
|
||||
|
||||
if readErr == io.EOF || pgErr != nil {
|
||||
pgConn.frontend.Send(&pgproto3.CopyDone{})
|
||||
} else {
|
||||
pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
|
||||
pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()})
|
||||
}
|
||||
err = pgConn.frontend.Flush()
|
||||
if err != nil {
|
||||
|
@ -1603,18 +1568,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
|||
|
||||
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
|
||||
|
||||
// A large batch can deadlock without concurrent reading and writing. If the Write fails the underlying net.Conn is
|
||||
// closed. This is all that can be done without introducing a race condition or adding a concurrent safe communication
|
||||
// channel to relay the error back. The practical effect of this is that the underlying Write error is not reported.
|
||||
// The error the code reading the batch results receives will be a closed connection error.
|
||||
//
|
||||
// See https://github.com/jackc/pgx/issues/374.
|
||||
go func() {
|
||||
_, err := pgConn.conn.Write(batch.buf)
|
||||
if err != nil {
|
||||
pgConn.conn.Close()
|
||||
}
|
||||
}()
|
||||
_, err := pgConn.conn.Write(batch.buf)
|
||||
if err != nil {
|
||||
multiResult.closed = true
|
||||
multiResult.err = err
|
||||
pgConn.unlock()
|
||||
return multiResult
|
||||
}
|
||||
|
||||
return multiResult
|
||||
}
|
||||
|
|
|
@ -1849,13 +1849,14 @@ func TestConnCancelRequest(t *testing.T) {
|
|||
|
||||
multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)")
|
||||
|
||||
// This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a
|
||||
// response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a
|
||||
// few milliseconds.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
go func() {
|
||||
// The query is actually sent when multiResult.NextResult() is called. So wait to ensure it is sent.
|
||||
// Once Flush is available this could use that instead.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
err = pgConn.CancelRequest(context.Background())
|
||||
require.NoError(t, err)
|
||||
err = pgConn.CancelRequest(context.Background())
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
for multiResult.NextResult() {
|
||||
}
|
||||
|
@ -2027,6 +2028,36 @@ func TestFatalErrorReceivedAfterCommandComplete(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgconn/issues/27
|
||||
func TestConnLargeResponseWhileWritingDoesNotDeadlock(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
_, err = pgConn.Exec(context.Background(), "set client_min_messages = debug5").ReadAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
// The actual contents of this test aren't important. What's important is a large amount of data to be written and
|
||||
// because of client_min_messages = debug5 the server will return a large amount of data.
|
||||
|
||||
paramCount := math.MaxUint16
|
||||
params := make([]string, 0, paramCount)
|
||||
args := make([][]byte, 0, paramCount)
|
||||
for i := 0; i < paramCount; i++ {
|
||||
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
||||
args = append(args, []byte(strconv.Itoa(i)))
|
||||
}
|
||||
sql := "values" + strings.Join(params, ", ")
|
||||
|
||||
result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read()
|
||||
require.NoError(t, result.Err)
|
||||
require.Len(t, result.Rows, paramCount)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func Example() {
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in New Issue