diff --git a/internal/iobufpool/iobufpool.go b/internal/iobufpool/iobufpool.go index 52c52f45..9e55c435 100644 --- a/internal/iobufpool/iobufpool.go +++ b/internal/iobufpool/iobufpool.go @@ -14,26 +14,16 @@ func init() { } } -// Get gets a []byte with len >= size and len <= size*2. +// Get gets a []byte of len size with cap <= size*2. func Get(size int) []byte { - i := poolIdx(size) + i := getPoolIdx(size) if i >= len(pools) { return make([]byte, size) } - return pools[i].Get().([]byte) + return pools[i].Get().([]byte)[:size] } -// Put returns buf to the pool. -func Put(buf []byte) { - i := poolIdx(len(buf)) - if i >= len(pools) { - return - } - - pools[i].Put(buf) -} - -func poolIdx(size int) int { +func getPoolIdx(size int) int { size-- size >>= minPoolExpOf2 i := 0 @@ -44,3 +34,24 @@ func poolIdx(size int) int { return i } + +// Put returns buf to the pool. +func Put(buf []byte) { + i := putPoolIdx(cap(buf)) + if i < 0 { + return + } + + pools[i].Put(buf) +} + +func putPoolIdx(size int) int { + minPoolSize := 1 << minPoolExpOf2 + for i := range pools { + if size == minPoolSize<= len(bq.queue) { + bq.growQueue() + } + bq.queue[bq.w] = buf + bq.w++ +} + +func (bq *bufferQueue) pushFront(buf []byte) { + bq.lock.Lock() + defer bq.lock.Unlock() + + if bq.w >= len(bq.queue) { + bq.growQueue() + } + copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w]) + bq.queue[bq.r] = buf + bq.w++ +} + +func (bq *bufferQueue) popFront() []byte { + bq.lock.Lock() + defer bq.lock.Unlock() + + if bq.r == bq.w { + return nil + } + + buf := bq.queue[bq.r] + bq.queue[bq.r] = nil // Clear reference so it can be garbage collected. + bq.r++ + + if bq.r == bq.w { + bq.r = 0 + bq.w = 0 + if len(bq.queue) > minBufferQueueLen { + bq.queue = make([][]byte, minBufferQueueLen) + } + } + + return buf +} + +func (bq *bufferQueue) growQueue() { + desiredLen := (len(bq.queue) + 1) * 3 / 2 + if desiredLen < minBufferQueueLen { + desiredLen = minBufferQueueLen + } + + newQueue := make([][]byte, desiredLen) + copy(newQueue, bq.queue) + bq.queue = newQueue +} diff --git a/internal/nbconn/nbconn.go b/internal/nbconn/nbconn.go new file mode 100644 index 00000000..00d0e420 --- /dev/null +++ b/internal/nbconn/nbconn.go @@ -0,0 +1,513 @@ +// Package nbconn implements a non-blocking net.Conn wrapper. +// +// It is designed to solve three problems. +// +// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all +// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion. +// +// The second is the inability to use a write deadline with a TLS.Conn without killing the connection. +// +// The third is to efficiently check if a connection has been closed via a non-blocking read. +package nbconn + +import ( + "crypto/tls" + "errors" + "net" + "os" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/jackc/pgx/v5/internal/iobufpool" +) + +var errClosed = errors.New("closed") +var ErrWouldBlock = new(wouldBlockError) + +const fakeNonblockingWaitDuration = 100 * time.Millisecond + +// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read +// mode. +var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC) + +// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to +// ignore all future calls. 
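The two magic deadline values above drive NetConn's special read modes. Below is a minimal sketch (not part of the patch) of how a caller can use NonBlockingDeadline to poll a connection, mirroring TestNonBlockingRead later in this diff. The helper name pollOnce, the 1024-byte buffer, and the surrounding imports (errors, net, time) are assumptions; note nbconn is an internal package, so this is only callable from inside the pgx module.

```go
// pollOnce is a hypothetical helper for illustration only.
func pollOnce(raw net.Conn) ([]byte, error) {
	conn := nbconn.NewNetConn(raw, false)

	// NonBlockingDeadline switches subsequent Reads into non-blocking mode.
	if err := conn.SetReadDeadline(nbconn.NonBlockingDeadline); err != nil {
		return nil, err
	}

	buf := make([]byte, 1024)
	n, err := conn.Read(buf)
	if errors.Is(err, nbconn.ErrWouldBlock) {
		// Nothing buffered and nothing readable right now; the connection is still usable.
		err = nil
	}

	// A zero time.Time restores normal blocking reads.
	if derr := conn.SetReadDeadline(time.Time{}); derr != nil && err == nil {
		err = derr
	}

	return buf[:n], err
}
```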
+var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC) + +// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error. +type wouldBlockError struct{} + +func (*wouldBlockError) Error() string { + return "would block" +} + +func (*wouldBlockError) Timeout() bool { return true } +func (*wouldBlockError) Temporary() bool { return true } + +// Conn is a net.Conn where Write never blocks and always succeeds. Flush must be called to actually write to the +// underlying connection. +type Conn interface { + net.Conn + Flush() error +} + +// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn. +type NetConn struct { + conn net.Conn + rawConn syscall.RawConn + + readQueue bufferQueue + writeQueue bufferQueue + + readFlushLock sync.Mutex + // non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the + // callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations. + nonblockWriteBuf []byte + nonblockWriteErr error + nonblockWriteN int + + readDeadlineLock sync.Mutex + readDeadline time.Time + readNonblocking bool + + writeDeadlineLock sync.Mutex + writeDeadline time.Time + + // Only access with atomics + closed int64 // 0 = not closed, 1 = closed +} + +func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn { + nc := &NetConn{ + conn: conn, + } + + if !fakeNonBlockingIO { + if sc, ok := conn.(syscall.Conn); ok { + if rawConn, err := sc.SyscallConn(); err == nil { + nc.rawConn = rawConn + } + } + } + + return nc +} + +// Read implements io.Reader. +func (c *NetConn) Read(b []byte) (n int, err error) { + if c.isClosed() { + return 0, errClosed + } + + c.readFlushLock.Lock() + defer c.readFlushLock.Unlock() + + err = c.flush() + if err != nil { + return 0, err + } + + for n < len(b) { + buf := c.readQueue.popFront() + if buf == nil { + break + } + copiedN := copy(b[n:], buf) + if copiedN < len(buf) { + buf = buf[copiedN:] + c.readQueue.pushFront(buf) + } else { + iobufpool.Put(buf) + } + n += copiedN + } + + // If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to + // Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block. + if n > 0 { + return n, nil + } + + var readNonblocking bool + c.readDeadlineLock.Lock() + readNonblocking = c.readNonblocking + c.readDeadlineLock.Unlock() + + var readN int + if readNonblocking { + readN, err = c.nonblockingRead(b[n:]) + } else { + readN, err = c.conn.Read(b[n:]) + } + n += readN + return n, err +} + +// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is +// closed. Call Flush to actually write to the underlying connection. 
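The Write/Flush contract described above is the core of the API: Write only queues bytes, so it cannot block, and nothing reaches the wire until Flush (Read and Close also flush internally). A minimal sketch of the intended call pattern follows; the helper name sendMessage and its msg parameter are assumptions, not part of the patch.

```go
// sendMessage is a hypothetical helper for illustration only.
func sendMessage(conn nbconn.Conn, msg []byte) error {
	// Write appends msg to an internal queue, so it cannot block on I/O and
	// only fails if the connection is already closed.
	if _, err := conn.Write(msg); err != nil {
		return err
	}

	// Flush performs the actual (non-blocking, deadlock-avoiding) write to the
	// wrapped net.Conn.
	return conn.Flush()
}
```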
+func (c *NetConn) Write(b []byte) (n int, err error) { + if c.isClosed() { + return 0, errClosed + } + + buf := iobufpool.Get(len(b)) + copy(buf, b) + c.writeQueue.pushBack(buf) + return len(b), nil +} + +func (c *NetConn) Close() (err error) { + swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1) + if !swapped { + return errClosed + } + + defer func() { + closeErr := c.conn.Close() + if err == nil { + err = closeErr + } + }() + + c.readFlushLock.Lock() + defer c.readFlushLock.Unlock() + err = c.flush() + if err != nil { + return err + } + + return nil +} + +func (c *NetConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *NetConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t). +func (c *NetConn) SetDeadline(t time.Time) error { + err := c.SetReadDeadline(t) + if err != nil { + return err + } + return c.SetWriteDeadline(t) +} + +// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking. +func (c *NetConn) SetReadDeadline(t time.Time) error { + if c.isClosed() { + return errClosed + } + + c.readDeadlineLock.Lock() + defer c.readDeadlineLock.Unlock() + if c.readDeadline == disableSetDeadlineDeadline { + return nil + } + if t == disableSetDeadlineDeadline { + c.readDeadline = t + return nil + } + + if t == NonBlockingDeadline { + c.readNonblocking = true + t = time.Time{} + } else { + c.readNonblocking = false + } + + c.readDeadline = t + + return c.conn.SetReadDeadline(t) +} + +func (c *NetConn) SetWriteDeadline(t time.Time) error { + if c.isClosed() { + return errClosed + } + + c.writeDeadlineLock.Lock() + defer c.writeDeadlineLock.Unlock() + if c.writeDeadline == disableSetDeadlineDeadline { + return nil + } + if t == disableSetDeadlineDeadline { + c.writeDeadline = t + return nil + } + + c.writeDeadline = t + + return c.conn.SetWriteDeadline(t) +} + +func (c *NetConn) Flush() error { + if c.isClosed() { + return errClosed + } + + c.readFlushLock.Lock() + defer c.readFlushLock.Unlock() + return c.flush() +} + +// flush does the actual work of flushing the writeQueue. readFlushLock must already be held. +func (c *NetConn) flush() error { + var stopChan chan struct{} + var errChan chan error + + defer func() { + if stopChan != nil { + select { + case stopChan <- struct{}{}: + case <-errChan: + } + } + }() + + for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() { + remainingBuf := buf + for len(remainingBuf) > 0 { + n, err := c.nonblockingWrite(remainingBuf) + remainingBuf = remainingBuf[n:] + if err != nil { + if !errors.Is(err, ErrWouldBlock) { + buf = buf[:len(remainingBuf)] + copy(buf, remainingBuf) + c.writeQueue.pushFront(buf) + return err + } + + // Writing was blocked. Reading might unblock it. 
+ if stopChan == nil { + stopChan, errChan = c.bufferNonblockingRead() + } + + select { + case err := <-errChan: + stopChan = nil + return err + default: + } + + } + } + iobufpool.Put(buf) + } + + return nil +} + +func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) { + stopChan = make(chan struct{}) + errChan = make(chan error, 1) + + go func() { + for { + buf := iobufpool.Get(8 * 1024) + n, err := c.nonblockingRead(buf) + if n > 0 { + buf = buf[:n] + c.readQueue.pushBack(buf) + } + + if err != nil { + if !errors.Is(err, ErrWouldBlock) { + errChan <- err + return + } + } + + select { + case <-stopChan: + return + default: + } + } + }() + + return stopChan, errChan +} + +func (c *NetConn) isClosed() bool { + closed := atomic.LoadInt64(&c.closed) + return closed == 1 +} + +func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) { + if c.rawConn == nil { + return c.fakeNonblockingWrite(b) + } else { + return c.realNonblockingWrite(b) + } +} + +func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) { + c.writeDeadlineLock.Lock() + defer c.writeDeadlineLock.Unlock() + + deadline := time.Now().Add(fakeNonblockingWaitDuration) + if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) { + err = c.conn.SetWriteDeadline(deadline) + if err != nil { + return 0, err + } + defer func() { + // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. + c.conn.SetWriteDeadline(c.writeDeadline) + + if err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + err = ErrWouldBlock + } + } + }() + } + + return c.conn.Write(b) +} + +// realNonblockingWrite does a non-blocking write. readFlushLock must already be held. +func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { + c.nonblockWriteBuf = b + c.nonblockWriteN = 0 + c.nonblockWriteErr = nil + err = c.rawConn.Write(func(fd uintptr) (done bool) { + c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf) + return true + }) + n = c.nonblockWriteN + if err == nil && c.nonblockWriteErr != nil { + if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) { + err = ErrWouldBlock + } else { + err = c.nonblockWriteErr + } + } + if err != nil { + // n may be -1 when an error occurs. + if n < 0 { + n = 0 + } + + return n, err + } + + return n, nil +} + +func (c *NetConn) nonblockingRead(b []byte) (n int, err error) { + if c.rawConn == nil { + return c.fakeNonblockingRead(b) + } else { + return c.realNonblockingRead(b) + } +} + +func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) { + c.readDeadlineLock.Lock() + defer c.readDeadlineLock.Unlock() + + deadline := time.Now().Add(fakeNonblockingWaitDuration) + if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) { + err = c.conn.SetReadDeadline(deadline) + if err != nil { + return 0, err + } + defer func() { + // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. 
+ c.conn.SetReadDeadline(c.readDeadline) + + if err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + err = ErrWouldBlock + } + } + }() + } + + return c.conn.Read(b) +} + +func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { + var funcErr error + err = c.rawConn.Read(func(fd uintptr) (done bool) { + n, funcErr = syscall.Read(int(fd), b) + return true + }) + if err == nil && funcErr != nil { + if errors.Is(funcErr, syscall.EWOULDBLOCK) { + err = ErrWouldBlock + } else { + err = funcErr + } + } + if err != nil { + // n may be -1 when an error occurs. + if n < 0 { + n = 0 + } + + return n, err + } + + return n, nil +} + +// syscall.Conn is interface + +// TLSClient establishes a TLS connection as a client over conn using config. +// +// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby +// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the +// *TLSConn is returned. +func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) { + tc := tls.Client(conn, config) + err := tc.Handshake() + if err != nil { + return nil, err + } + + // Ensure last written part of Handshake is actually sent. + err = conn.Flush() + if err != nil { + return nil, err + } + + return &TLSConn{ + tlsConn: tc, + nbConn: conn, + }, nil +} + +// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a +// tls.Conn. +type TLSConn struct { + tlsConn *tls.Conn + nbConn *NetConn +} + +func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) } +func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) } +func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() } +func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() } +func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() } + +func (tc *TLSConn) Close() error { + // tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then + // sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our + // own 5 second deadline then make all set deadlines no-op. 
+ tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5)) + tc.tlsConn.SetDeadline(disableSetDeadlineDeadline) + + return tc.tlsConn.Close() +} + +func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) } +func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) } +func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) } diff --git a/internal/nbconn/nbconn_test.go b/internal/nbconn/nbconn_test.go new file mode 100644 index 00000000..2db47039 --- /dev/null +++ b/internal/nbconn/nbconn_test.go @@ -0,0 +1,554 @@ +package nbconn_test + +import ( + "crypto/tls" + "io" + "net" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5/internal/nbconn" + "github.com/stretchr/testify/require" +) + +// Test keys generated with: +// +// $ openssl req -x509 -newkey rsa:2048 -keyout key.pem -out cert.pem -sha256 -nodes -days 20000 -subj '/CN=localhost' + +var testTLSPublicKey = []byte(`-----BEGIN CERTIFICATE----- +MIICpjCCAY4CCQCjQKYdUDQzKDANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls +b2NhbGhvc3QwIBcNMjIwNjA0MTY1MzE2WhgPMjA3NzAzMDcxNjUzMTZaMBQxEjAQ +BgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +ALHbOu80cfSPufKTZsKf3E5rCXHeIHjaIbgHEXA2SW/n77U8oZX518s+27FO0sK5 +yA0WnEIwY34PU359sNR5KelARGnaeh3HdaGm1nuyyxBtwwAqIuM0UxGAMF/mQ4lT +caZPxG+7WlYDqnE3eVXUtG4c+T7t5qKAB3MtfbzKFSjczkWkroi6cTypmHArGghT +0VWWVu0s9oNp5q8iWchY2o9f0aIjmKv6FgtilO+geev+4U+QvtvrziR5BO3/3EgW +c5TUVcf+lwkvp8ziXvargmjjnNTyeF37y4KpFcex0v7z7hSrUK4zU0+xRn7Bp17v +7gzj0xN+HCsUW1cjPFNezX0CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAbEBzewzg +Z5F+BqMSxP3HkMCkLLH0N9q0/DkZaVyZ38vrjcjaDYuabq28kA2d5dc5jxsQpvTw +HTGqSv1ZxJP3pBFv6jLSh8xaM6tUkk482Q6DnZGh97CD4yup/yJzkn5nv9OHtZ9g +TnaQeeXgOz0o5Zq9IpzHJb19ysya3UCIK8oKXbSO4Qd168seCq75V2BFHDpmejjk +D92eT6WODlzzvZbhzA1F3/cUilZdhbQtJMqdecKvD+yrBpzGVqzhWQsXwsRAU1fB +hShx+D14zUGM2l4wlVzOAuGh4ZL7x3AjJsc86TsCavTspS0Xl51j+mRbiULq7G7Y +E7ZYmaKTMOhvkg== +-----END CERTIFICATE-----`) + +// The strings.ReplaceAll is used to placate any secret scanners that would squawk if they saw a private key embedded in +// source code. 
+var testTLSPrivateKey = []byte(strings.ReplaceAll(`-----BEGIN TESTING KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCx2zrvNHH0j7ny +k2bCn9xOawlx3iB42iG4BxFwNklv5++1PKGV+dfLPtuxTtLCucgNFpxCMGN+D1N+ +fbDUeSnpQERp2nodx3WhptZ7sssQbcMAKiLjNFMRgDBf5kOJU3GmT8Rvu1pWA6px +N3lV1LRuHPk+7eaigAdzLX28yhUo3M5FpK6IunE8qZhwKxoIU9FVllbtLPaDaeav +IlnIWNqPX9GiI5ir+hYLYpTvoHnr/uFPkL7b684keQTt/9xIFnOU1FXH/pcJL6fM +4l72q4Jo45zU8nhd+8uCqRXHsdL+8+4Uq1CuM1NPsUZ+wade7+4M49MTfhwrFFtX +IzxTXs19AgMBAAECggEBAJcHt5ARVQN8WUbobMawwX/F3QtYuPJnKWMAfYpwTwQ8 +TI32orCcrObmxeBXMxowcPTMUnzSYmpV0W0EhvimuzRbYr0Qzcoj6nwPFOuN9GpL +CuBE58NQV4nw9SM6gfdHaKb17bWDvz5zdnUVym9cZKts5yrNEqDDX5Aq/S8n27gJ +/qheXwSxwETVO6kMEW1ndNIWDP8DPQ0E4O//RuMZwxpnZdnjGKkdVNy8I1BpgDgn +lwgkE3H3IciASki1GYXoyvrIiRwMQVzvYD2zcgwK9OZSjZe0TGwAGa+eQdbs3A9I +Ir1kYn6ZMGMRFJA2XHJW3hMZdWB/t2xMBGy75Uv9sAECgYEA1o+oRUYwwQ1MwBo9 +YA6c00KjhFgrjdzyKPQrN14Q0dw5ErqRkhp2cs7BRdCDTDrjAegPc3Otg7uMa1vp +RgU/C72jwzFLYATvn+RLGRYRyqIE+bQ22/lLnXTrp4DCfdMrqWuQbIYouGHqfQrq +MfdtSUpQ6VZCi9zHehXOYwBMvQECgYEA1DTQFpe+tndIFmguxxaBwDltoPh5omzd +3vA7iFct2+UYk5W9shfAekAaZk2WufKmmC3OfBWYyIaJ7QwQpuGDS3zwjy6WFMTE +Otp2CypFCVahwHcvn2jYHmDMT0k0Pt6X2S3GAyWTyEPv7mAfKR1OWUYi7ZgdXpt0 +TtL3Z3JyhH0CgYEAwveHUGuXodUUCPvPCZo9pzrGm1wDN8WtxskY/Bbd8dTLh9lA +riKdv3Vg6q+un3ZjETht0dsrsKib0HKUZqwdve11AcmpVHcnx4MLOqBzSk4vdzfr +IbhGna3A9VRrZyqcYjb75aGDHwjaqwVgCkdrZ03AeEeJ8M2N9cIa6Js9IAECgYBu +nlU24cVdspJWc9qml3ntrUITnlMxs1R5KXuvF9rk/OixzmYDV1RTpeTdHWcL6Yyk +WYSAtHVfWpq9ggOQKpBZonh3+w3rJ6MvFsBgE5nHQ2ywOrENhQbb1xPJ5NwiRcCc +Srsk2srNo3SIK30y3n8AFIqSljABKEIZ8Olc+JDvtQKBgQCiKz43zI6a0HscgZ77 +DCBduWP4nk8BM7QTFxs9VypjrylMDGGtTKHc5BLA5fNZw97Hb7pcicN7/IbUnQUD +pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG +2qWm8jTPeDC3sq+67s2oojHf+Q== +-----END TESTING KEY-----`, "TESTING KEY", "PRIVATE KEY")) + +func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) { + for _, tt := range []struct { + name string + makeConns func(t *testing.T) (local, remote net.Conn) + useTLS bool + fakeNonBlockingIO bool + }{ + { + name: "Pipe", + makeConns: makePipeConns, + useTLS: false, + fakeNonBlockingIO: true, + }, + { + name: "TCP with Fake Non-blocking IO", + makeConns: makeTCPConns, + useTLS: false, + fakeNonBlockingIO: true, + }, + { + name: "TLS over TCP with Fake Non-blocking IO", + makeConns: makeTCPConns, + useTLS: true, + fakeNonBlockingIO: true, + }, + { + name: "TCP with Real Non-blocking IO", + makeConns: makeTCPConns, + useTLS: false, + fakeNonBlockingIO: false, + }, + { + name: "TLS over TCP with Real Non-blocking IO", + makeConns: makeTCPConns, + useTLS: true, + fakeNonBlockingIO: false, + }, + } { + t.Run(tt.name, func(t *testing.T) { + local, remote := tt.makeConns(t) + + // Just to be sure both ends get closed. Also, it retains a reference so one side of the connection doesn't get + // garbage collected. This could happen when a test is testing against a non-responsive remote. Since it never + // uses remote it may be garbage collected leading to the connection being closed. 
+ defer local.Close() + defer remote.Close() + + var conn nbconn.Conn + netConn := nbconn.NewNetConn(local, tt.fakeNonBlockingIO) + + if tt.useTLS { + cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey) + require.NoError(t, err) + + tlsServer := tls.Server(remote, &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + serverTLSHandshakeChan := make(chan error) + go func() { + err := tlsServer.Handshake() + serverTLSHandshakeChan <- err + }() + + tlsConn, err := nbconn.TLSClient(netConn, &tls.Config{InsecureSkipVerify: true}) + require.NoError(t, err) + conn = tlsConn + + err = <-serverTLSHandshakeChan + require.NoError(t, err) + remote = tlsServer + } else { + conn = netConn + } + + f(t, conn, remote) + }) + } +} + +// makePipeConns returns a connected pair of net.Conns created with net.Pipe(). It is entirely synchronous so it is +// useful for testing an exact sequence of reads and writes with the underlying connection blocking. +func makePipeConns(t *testing.T) (local, remote net.Conn) { + local, remote = net.Pipe() + t.Cleanup(func() { + local.Close() + remote.Close() + }) + + return local, remote +} + +// makeTCPConns returns a connected pair of net.Conns running over TCP on localhost. +func makeTCPConns(t *testing.T) (local, remote net.Conn) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + type acceptResultT struct { + conn net.Conn + err error + } + acceptChan := make(chan acceptResultT) + + go func() { + conn, err := ln.Accept() + acceptChan <- acceptResultT{conn: conn, err: err} + }() + + local, err = net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + + acceptResult := <-acceptChan + require.NoError(t, acceptResult.err) + + remote = acceptResult.conn + + return local, remote +} + +func TestWriteIsBuffered(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + // net.Pipe is synchronous so the Write would block if not buffered. 
+ writeBuf := []byte("test") + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + + errChan := make(chan error, 1) + go func() { + err := conn.Flush() + errChan <- err + }() + + readBuf := make([]byte, len(writeBuf)) + _, err = remote.Read(readBuf) + require.NoError(t, err) + + require.NoError(t, <-errChan) + }) +} + +func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + err := conn.SetWriteDeadline(time.Now()) + require.NoError(t, err) + + writeBuf := []byte("test") + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + }) +} + +func TestReadFlushesWriteBuffer(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + writeBuf := []byte("test") + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + + errChan := make(chan error, 2) + go func() { + readBuf := make([]byte, len(writeBuf)) + _, err := remote.Read(readBuf) + errChan <- err + + _, err = remote.Write([]byte("okay")) + errChan <- err + }() + + readBuf := make([]byte, 4) + _, err = conn.Read(readBuf) + require.NoError(t, err) + require.Equal(t, []byte("okay"), readBuf) + + require.NoError(t, <-errChan) + require.NoError(t, <-errChan) + }) +} + +func TestCloseFlushesWriteBuffer(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + writeBuf := []byte("test") + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + + errChan := make(chan error, 1) + go func() { + readBuf := make([]byte, len(writeBuf)) + _, err := remote.Read(readBuf) + errChan <- err + }() + + err = conn.Close() + require.NoError(t, err) + + require.NoError(t, <-errChan) + }) +} + +// This test exercises the non-blocking write path. Because writes are buffered it is difficult trigger this with +// certainty and visibility. So this test tries to trigger what would otherwise be a deadlock by both sides writing +// large values. 
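The test that follows drives both ends of the connection at once. As background, here is a standard-library-only sketch of the blocked-write behaviour it guards against; the function name demonstrateBlockedWrite, the 64 KiB payload, and the 250 ms timeout are illustrative choices, not anything in the patch.

```go
// With a plain net.Conn, Write blocks until the peer reads. If both ends write
// large payloads without reading, neither side makes progress
// (https://github.com/jackc/pgconn/issues/27). net.Pipe is fully synchronous,
// so a single unread Write is enough to show the effect.
func demonstrateBlockedWrite() {
	local, remote := net.Pipe()
	defer remote.Close()
	defer local.Close()

	done := make(chan struct{})
	go func() {
		_, _ = local.Write(make([]byte, 64*1024)) // blocks: nothing ever reads from remote
		close(done)
	}()

	select {
	case <-done:
		fmt.Println("write completed")
	case <-time.After(250 * time.Millisecond):
		fmt.Println("write still blocked after 250ms")
	}
}
```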
+func TestInternalNonBlockingWrite(t *testing.T) { + const deadlockSize = 4 * 1024 * 1024 + + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + writeBuf := make([]byte, deadlockSize) + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, deadlockSize, n) + + errChan := make(chan error, 1) + go func() { + remoteWriteBuf := make([]byte, deadlockSize) + _, err := remote.Write(remoteWriteBuf) + if err != nil { + errChan <- err + return + } + + readBuf := make([]byte, deadlockSize) + _, err = io.ReadFull(remote, readBuf) + errChan <- err + }() + + readBuf := make([]byte, deadlockSize) + _, err = conn.Read(readBuf) + require.NoError(t, err) + + err = conn.Close() + require.NoError(t, err) + + require.NoError(t, <-errChan) + }) +} + +func TestInternalNonBlockingWriteWithDeadline(t *testing.T) { + const deadlockSize = 4 * 1024 * 1024 + + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + writeBuf := make([]byte, deadlockSize) + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, deadlockSize, n) + + err = conn.SetDeadline(time.Now().Add(100 * time.Millisecond)) + require.NoError(t, err) + + err = conn.Flush() + require.Error(t, err) + }) +} + +func TestNonBlockingRead(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + err := conn.SetReadDeadline(nbconn.NonBlockingDeadline) + require.NoError(t, err) + + buf := make([]byte, 4) + n, err := conn.Read(buf) + require.ErrorIs(t, err, nbconn.ErrWouldBlock) + require.EqualValues(t, 0, n) + + errChan := make(chan error, 1) + go func() { + _, err := remote.Write([]byte("okay")) + errChan <- err + }() + + err = conn.SetReadDeadline(time.Time{}) + require.NoError(t, err) + + n, err = conn.Read(buf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + }) +} + +func TestReadPreviouslyBuffered(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flush must buffer a read. + err = conn.Flush() + require.NoError(t, err) + + readBuf := make([]byte, 5) + n, err := conn.Read(readBuf) + require.NoError(t, err) + require.EqualValues(t, 5, n) + require.Equal(t, []byte("alpha"), readBuf) + }) +} + +func TestReadMoreThanPreviouslyBufferedDoesNotBlock(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flush must buffer a read. 
+ err = conn.Flush() + require.NoError(t, err) + + readBuf := make([]byte, 10) + n, err := conn.Read(readBuf) + require.NoError(t, err) + require.EqualValues(t, 5, n) + require.Equal(t, []byte("alpha"), readBuf[:n]) + }) +} + +func TestReadPreviouslyBufferedPartialRead(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flush must buffer a read. + err = conn.Flush() + require.NoError(t, err) + + readBuf := make([]byte, 2) + n, err := conn.Read(readBuf) + require.NoError(t, err) + require.EqualValues(t, 2, n) + require.Equal(t, []byte("al"), readBuf) + + readBuf = make([]byte, 3) + n, err = conn.Read(readBuf) + require.NoError(t, err) + require.EqualValues(t, 3, n) + require.Equal(t, []byte("pha"), readBuf) + }) +} + +func TestReadMultiplePreviouslyBuffered(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + _, err = remote.Write([]byte("beta")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flush must buffer a read. + err = conn.Flush() + require.NoError(t, err) + + readBuf := make([]byte, 9) + n, err := io.ReadFull(conn, readBuf) + require.NoError(t, err) + require.EqualValues(t, 9, n) + require.Equal(t, []byte("alphabeta"), readBuf) + }) +} + +func TestReadPreviouslyBufferedAndReadMore(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + + flushCompleteChan := make(chan struct{}) + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + <-flushCompleteChan + + _, err = remote.Write([]byte("beta")) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flush must buffer a read. 
+ err = conn.Flush() + require.NoError(t, err) + + close(flushCompleteChan) + + readBuf := make([]byte, 9) + + n, err := io.ReadFull(conn, readBuf) + require.NoError(t, err) + require.EqualValues(t, 9, n) + require.Equal(t, []byte("alphabeta"), readBuf) + + err = <-errChan + require.NoError(t, err) + }) +} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index c8b41f84..002db39a 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -13,9 +13,10 @@ import ( "net" "strconv" "strings" - "sync" "time" + "github.com/jackc/pgx/v5/internal/iobufpool" + "github.com/jackc/pgx/v5/internal/nbconn" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" "github.com/jackc/pgx/v5/pgproto3" @@ -75,11 +76,6 @@ type PgConn struct { status byte // One of connStatus* constants - bufferingReceive bool - bufferingReceiveMux sync.Mutex - bufferingReceiveMsg pgproto3.BackendMessage - bufferingReceiveErr error - peekedMsg pgproto3.BackendMessage // Reusable / preallocated resources @@ -234,13 +230,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } return nil, &connectError{config: config, msg: "dial error", err: err} } + netConn = nbconn.NewNetConn(netConn, false) pgConn.conn = netConn pgConn.contextWatcher = newContextWatcher(netConn) pgConn.contextWatcher.Watch(ctx) if fallbackConfig.TLSConfig != nil { - tlsConn, err := startTLS(netConn, fallbackConfig.TLSConfig) + tlsConn, err := startTLS(netConn.(*nbconn.NetConn), fallbackConfig.TLSConfig) pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { netConn.Close() @@ -356,7 +353,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { ) } -func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { +func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (net.Conn, error) { err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return nil, err @@ -371,7 +368,12 @@ func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { return nil, errors.New("server refused TLS connection") } - return tls.Client(conn, tlsConfig), nil + tlsConn, err := nbconn.TLSClient(conn, tlsConfig) + if err != nil { + return nil, err + } + + return tlsConn, nil } func (pgConn *PgConn) txPasswordMessage(password string) (err error) { @@ -385,24 +387,6 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } -func (pgConn *PgConn) signalMessage() chan struct{} { - if pgConn.bufferingReceive { - panic("BUG: signalMessage when already in progress") - } - - pgConn.bufferingReceive = true - pgConn.bufferingReceiveMux.Lock() - - ch := make(chan struct{}) - go func() { - pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive() - pgConn.bufferingReceiveMux.Unlock() - close(ch) - }() - - return ch -} - // ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the // connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. 
The messages // are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger @@ -442,25 +426,13 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { return pgConn.peekedMsg, nil } - var msg pgproto3.BackendMessage - var err error - if pgConn.bufferingReceive { - pgConn.bufferingReceiveMux.Lock() - msg = pgConn.bufferingReceiveMsg - err = pgConn.bufferingReceiveErr - pgConn.bufferingReceiveMux.Unlock() - pgConn.bufferingReceive = false - - // If a timeout error happened in the background try the read again. - var netErr net.Error - if errors.As(err, &netErr) && netErr.Timeout() { - msg, err = pgConn.frontend.Receive() - } - } else { - msg, err = pgConn.frontend.Receive() - } + msg, err := pgConn.frontend.Receive() if err != nil { + if errors.Is(err, nbconn.ErrWouldBlock) { + return nil, err + } + // Close on anything other than timeout error - everything else is fatal var netErr net.Error isNetErr := errors.As(err, &netErr) @@ -479,13 +451,6 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { msg, err := pgConn.peekMessage() if err != nil { - // Close on anything other than timeout error - everything else is fatal - var netErr net.Error - isNetErr := errors.As(err, &netErr) - if !(isNetErr && netErr.Timeout()) { - pgConn.asyncClose() - } - return nil, err } pgConn.peekedMsg = nil @@ -1173,62 +1138,58 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // Send copy to command pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() return CommandTag{}, err } - // Send copy data - abortCopyChan := make(chan struct{}) - copyErrChan := make(chan error, 1) - signalMessageChan := pgConn.signalMessage() - senderDoneChan := make(chan struct{}) - - go func() { - defer close(senderDoneChan) - - buf := make([]byte, 0, 65536) - buf = append(buf, 'd') - sp := len(buf) - - for { - n, readErr := r.Read(buf[5:cap(buf)]) - if n > 0 { - buf = buf[0 : n+5] - pgio.SetInt32(buf[sp:], int32(n+4)) - - writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(buf) - if writeErr != nil { - // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. - pgConn.conn.Close() - - copyErrChan <- writeErr - return - } - } - if readErr != nil { - copyErrChan <- readErr - return - } - - select { - case <-abortCopyChan: - return - default: - } + err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline) + if err != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + nonblocking := true + defer func() { + if nonblocking { + pgConn.conn.SetReadDeadline(time.Time{}) } }() - var pgErr error - var copyErr error - for copyErr == nil && pgErr == nil { - select { - case copyErr = <-copyErrChan: - case <-signalMessageChan: + buf := iobufpool.Get(65536) + defer iobufpool.Put(buf) + buf[0] = 'd' + + var readErr, pgErr error + for pgErr == nil { + // Read chunk from r. + var n int + n, readErr = r.Read(buf[5:cap(buf)]) + + // Send chunk to PostgreSQL. + if n > 0 { + buf = buf[0 : n+5] + pgio.SetInt32(buf[1:], int32(n+4)) + + writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(buf) + if writeErr != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + } + + // Abort loop if there was a read error. + if readErr != nil { + break + } + + // Read messages until error or none available. 
+ for pgErr == nil { msg, err := pgConn.receiveMessage() if err != nil { + if errors.Is(err, nbconn.ErrWouldBlock) { + break + } pgConn.asyncClose() return CommandTag{}, preferContextOverNetTimeoutError(ctx, err) } @@ -1236,18 +1197,22 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co switch msg := msg.(type) { case *pgproto3.ErrorResponse: pgErr = ErrorResponseToPgError(msg) - default: - signalMessageChan = pgConn.signalMessage() + break } } } - close(abortCopyChan) - <-senderDoneChan - if copyErr == io.EOF || pgErr != nil { + err = pgConn.conn.SetReadDeadline(time.Time{}) + if err != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + nonblocking = false + + if readErr == io.EOF || pgErr != nil { pgConn.frontend.Send(&pgproto3.CopyDone{}) } else { - pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()}) + pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()}) } err = pgConn.frontend.Flush() if err != nil { @@ -1603,18 +1568,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) - // A large batch can deadlock without concurrent reading and writing. If the Write fails the underlying net.Conn is - // closed. This is all that can be done without introducing a race condition or adding a concurrent safe communication - // channel to relay the error back. The practical effect of this is that the underlying Write error is not reported. - // The error the code reading the batch results receives will be a closed connection error. - // - // See https://github.com/jackc/pgx/issues/374. - go func() { - _, err := pgConn.conn.Write(batch.buf) - if err != nil { - pgConn.conn.Close() - } - }() + _, err := pgConn.conn.Write(batch.buf) + if err != nil { + multiResult.closed = true + multiResult.err = err + pgConn.unlock() + return multiResult + } return multiResult } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index fdce6e7d..07b68995 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -1849,13 +1849,14 @@ func TestConnCancelRequest(t *testing.T) { multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)") - // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a - // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a - // few milliseconds. - time.Sleep(50 * time.Millisecond) + go func() { + // The query is actually sent when multiResult.NextResult() is called. So wait to ensure it is sent. + // Once Flush is available this could use that instead. + time.Sleep(500 * time.Millisecond) - err = pgConn.CancelRequest(context.Background()) - require.NoError(t, err) + err = pgConn.CancelRequest(context.Background()) + require.NoError(t, err) + }() for multiResult.NextResult() { } @@ -2027,6 +2028,36 @@ func TestFatalErrorReceivedAfterCommandComplete(t *testing.T) { require.Error(t, err) } +// https://github.com/jackc/pgconn/issues/27 +func TestConnLargeResponseWhileWritingDoesNotDeadlock(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), "set client_min_messages = debug5").ReadAll() + require.NoError(t, err) + + // The actual contents of this test aren't important. 
What's important is a large amount of data to be written and
+	// because of client_min_messages = debug5 the server will return a large amount of data.
+
+	paramCount := math.MaxUint16
+	params := make([]string, 0, paramCount)
+	args := make([][]byte, 0, paramCount)
+	for i := 0; i < paramCount; i++ {
+		params = append(params, fmt.Sprintf("($%d::text)", i+1))
+		args = append(args, []byte(strconv.Itoa(i)))
+	}
+	sql := "values" + strings.Join(params, ", ")
+
+	result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read()
+	require.NoError(t, result.Err)
+	require.Len(t, result.Rows, paramCount)
+
+	ensureConnValid(t, pgConn)
+}
+
 func Example() {
 	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
 	if err != nil {
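To close, a self-contained sketch of how pgconn now wires the new pieces together: wrap the raw connection in a non-blocking NetConn, then optionally layer TLS on top with nbconn.TLSClient, which completes and flushes the handshake before returning. The package name example, the function dialWithNBConn, and its addr parameter are hypothetical; only the nbconn and crypto/tls calls shown come from the patch, and nbconn is internal to the pgx module.

```go
package example

import (
	"crypto/tls"
	"net"

	"github.com/jackc/pgx/v5/internal/nbconn"
)

// dialWithNBConn mirrors what pgconn's connect/startTLS now do.
func dialWithNBConn(addr string, tlsConfig *tls.Config) (net.Conn, error) {
	raw, err := net.Dial("tcp", addr)
	if err != nil {
		return nil, err
	}

	// false = use real non-blocking I/O via syscall.RawConn when available;
	// otherwise NetConn falls back to deadline-based "fake" non-blocking I/O.
	conn := nbconn.NewNetConn(raw, false)

	if tlsConfig == nil {
		return conn, nil
	}

	// TLSClient performs and flushes the TLS handshake before returning, so the
	// first Read on the returned connection does not trigger a surprise Write.
	tlsConn, err := nbconn.TLSClient(conn, tlsConfig)
	if err != nil {
		conn.Close()
		return nil, err
	}
	return tlsConn, nil
}
```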