diff --git a/internal/nbconn/nbconn.go b/internal/nbconn/nbconn.go index 6a92e8fb..0ceb3c79 100644 --- a/internal/nbconn/nbconn.go +++ b/internal/nbconn/nbconn.go @@ -1,4 +1,13 @@ // Package nbconn implements a non-blocking net.Conn wrapper. +// +// It is designed to solve three problems. +// +// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all +// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion. +// +// The second is the inability to use a write deadline with a TLS.Conn without killing the connection. +// +// The third is to efficiently check if a connection has been closed via a non-blocking read. package nbconn import ( @@ -14,26 +23,38 @@ import ( ) var errClosed = errors.New("closed") -var ErrWouldBlock = errors.New("would block") +var ErrWouldBlock = new(wouldBlockError) const fakeNonblockingWaitDuration = 100 * time.Millisecond -var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC) +// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read +// mode. +var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC) -// Conn is a non-blocking net.Conn wrapper. It implements net.Conn. -// -// It is designed to solve three problems. -// -// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all -// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion. -// -// The second is the inability to use a write deadline with a TLS.Conn without killing the connection. -// -// The third is to efficiently check if a connection has been closed via a non-blocking read. -type Conn struct { - netConn net.Conn - tlsConn *tls.Conn - maybeTLSConn net.Conn +// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to +// ignore all future calls. +var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC) + +// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error. +type wouldBlockError struct{} + +func (*wouldBlockError) Error() string { + return "would block" +} + +func (*wouldBlockError) Timeout() bool { return true } +func (*wouldBlockError) Temporary() bool { return true } + +// Conn is a net.Conn where Write never blocks and always succeeds. Flush must be called to actually write to the +// underlying connection. +type Conn interface { + net.Conn + Flush() error +} + +// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn. +type NetConn struct { + conn net.Conn readQueue bufferQueue writeQueue bufferQueue @@ -51,21 +72,14 @@ type Conn struct { closed int64 // 0 = not closed, 1 = closed } -func New(conn net.Conn) *Conn { - return &Conn{ - netConn: conn, - maybeTLSConn: conn, +func NewNetConn(conn net.Conn) *NetConn { + return &NetConn{ + conn: conn, } } -// StartTLS starts using TLS. It must not be called concurrently with any other method and must only be called once. -func (c *Conn) StartTLS(config *tls.Config) { - c.tlsConn = tls.Client(c.netConn, config) - c.maybeTLSConn = c.tlsConn -} - // Read implements io.Reader. -func (c *Conn) Read(b []byte) (n int, err error) { +func (c *NetConn) Read(b []byte) (n int, err error) { if c.isClosed() { return 0, errClosed } @@ -106,7 +120,7 @@ func (c *Conn) Read(b []byte) (n int, err error) { if readNonblocking { readN, err = c.nonblockingRead(b[n:]) } else { - readN, err = c.maybeTLSConn.Read(b[n:]) + readN, err = c.conn.Read(b[n:]) } n += readN return n, err @@ -114,7 +128,7 @@ func (c *Conn) Read(b []byte) (n int, err error) { // Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is // closed. Call Flush to actually write to the underlying connection. -func (c *Conn) Write(b []byte) (n int, err error) { +func (c *NetConn) Write(b []byte) (n int, err error) { if c.isClosed() { return 0, errClosed } @@ -125,14 +139,14 @@ func (c *Conn) Write(b []byte) (n int, err error) { return len(b), nil } -func (c *Conn) Close() (err error) { +func (c *NetConn) Close() (err error) { swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1) if !swapped { return errClosed } defer func() { - closeErr := c.maybeTLSConn.Close() + closeErr := c.conn.Close() if err == nil { err = closeErr } @@ -148,16 +162,16 @@ func (c *Conn) Close() (err error) { return nil } -func (c *Conn) LocalAddr() net.Addr { - return c.maybeTLSConn.LocalAddr() +func (c *NetConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() } -func (c *Conn) RemoteAddr() net.Addr { - return c.maybeTLSConn.RemoteAddr() +func (c *NetConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() } // SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t). -func (c *Conn) SetDeadline(t time.Time) error { +func (c *NetConn) SetDeadline(t time.Time) error { err := c.SetReadDeadline(t) if err != nil { return err @@ -166,13 +180,20 @@ func (c *Conn) SetDeadline(t time.Time) error { } // SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking. -func (c *Conn) SetReadDeadline(t time.Time) error { +func (c *NetConn) SetReadDeadline(t time.Time) error { if c.isClosed() { return errClosed } c.readDeadlineLock.Lock() defer c.readDeadlineLock.Unlock() + if c.readDeadline == disableSetDeadlineDeadline { + return nil + } + if t == disableSetDeadlineDeadline { + c.readDeadline = t + return nil + } if t == NonBlockingDeadline { c.readNonblocking = true @@ -183,22 +204,30 @@ func (c *Conn) SetReadDeadline(t time.Time) error { c.readDeadline = t - return c.maybeTLSConn.SetReadDeadline(t) + return c.conn.SetReadDeadline(t) } -func (c *Conn) SetWriteDeadline(t time.Time) error { +func (c *NetConn) SetWriteDeadline(t time.Time) error { if c.isClosed() { return errClosed } c.writeDeadlineLock.Lock() defer c.writeDeadlineLock.Unlock() + if c.writeDeadline == disableSetDeadlineDeadline { + return nil + } + if t == disableSetDeadlineDeadline { + c.writeDeadline = t + return nil + } + c.writeDeadline = t - return c.netConn.SetWriteDeadline(t) + return c.conn.SetWriteDeadline(t) } -func (c *Conn) Flush() error { +func (c *NetConn) Flush() error { if c.isClosed() { return errClosed } @@ -209,7 +238,7 @@ func (c *Conn) Flush() error { } // flush does the actual work of flushing the writeQueue. readFlushLock must already be held. -func (c *Conn) flush() error { +func (c *NetConn) flush() error { var stopChan chan struct{} var errChan chan error @@ -255,7 +284,7 @@ func (c *Conn) flush() error { return nil } -func (c *Conn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) { +func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) { stopChan = make(chan struct{}) errChan = make(chan error, 1) @@ -286,28 +315,28 @@ func (c *Conn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan err return stopChan, errChan } -func (c *Conn) isClosed() bool { +func (c *NetConn) isClosed() bool { closed := atomic.LoadInt64(&c.closed) return closed == 1 } -func (c *Conn) nonblockingWrite(b []byte) (n int, err error) { +func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) { return c.fakeNonblockingWrite(b) } -func (c *Conn) fakeNonblockingWrite(b []byte) (n int, err error) { +func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) { c.writeDeadlineLock.Lock() defer c.writeDeadlineLock.Unlock() deadline := time.Now().Add(fakeNonblockingWaitDuration) if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) { - err = c.netConn.SetWriteDeadline(deadline) + err = c.conn.SetWriteDeadline(deadline) if err != nil { return 0, err } defer func() { // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. - c.netConn.SetWriteDeadline(c.writeDeadline) + c.conn.SetWriteDeadline(c.writeDeadline) if err != nil { if errors.Is(err, os.ErrDeadlineExceeded) { @@ -317,26 +346,26 @@ func (c *Conn) fakeNonblockingWrite(b []byte) (n int, err error) { }() } - return c.netConn.Write(b) + return c.conn.Write(b) } -func (c *Conn) nonblockingRead(b []byte) (n int, err error) { +func (c *NetConn) nonblockingRead(b []byte) (n int, err error) { return c.fakeNonblockingRead(b) } -func (c *Conn) fakeNonblockingRead(b []byte) (n int, err error) { +func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) { c.readDeadlineLock.Lock() defer c.readDeadlineLock.Unlock() deadline := time.Now().Add(fakeNonblockingWaitDuration) if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) { - err = c.netConn.SetReadDeadline(deadline) + err = c.conn.SetReadDeadline(deadline) if err != nil { return 0, err } defer func() { // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. - c.netConn.SetReadDeadline(c.readDeadline) + c.conn.SetReadDeadline(c.readDeadline) if err != nil { if errors.Is(err, os.ErrDeadlineExceeded) { @@ -346,7 +375,58 @@ func (c *Conn) fakeNonblockingRead(b []byte) (n int, err error) { }() } - return c.netConn.Read(b) + return c.conn.Read(b) } // syscall.Conn is interface + +// TLSClient establishes a TLS connection as a client over conn using config. +// +// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby +// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the +// *TLSConn is returned. +func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) { + tc := tls.Client(conn, config) + err := tc.Handshake() + if err != nil { + return nil, err + } + + // Ensure last written part of Handshake is actually sent. + err = conn.Flush() + if err != nil { + return nil, err + } + + return &TLSConn{ + tlsConn: tc, + nbConn: conn, + }, nil +} + +// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a +// tls.Conn. +type TLSConn struct { + tlsConn *tls.Conn + nbConn *NetConn +} + +func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) } +func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) } +func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() } +func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() } +func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() } + +func (tc *TLSConn) Close() error { + // tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then + // sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our + // own 5 second deadline then make all set deadlines no-op. + tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5)) + tc.tlsConn.SetDeadline(disableSetDeadlineDeadline) + + return tc.tlsConn.Close() +} + +func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) } +func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) } +func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) } diff --git a/internal/nbconn/nbconn_test.go b/internal/nbconn/nbconn_test.go index d3b56e36..8622fcff 100644 --- a/internal/nbconn/nbconn_test.go +++ b/internal/nbconn/nbconn_test.go @@ -65,7 +65,7 @@ pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG 2qWm8jTPeDC3sq+67s2oojHf+Q== -----END TESTING KEY-----`, "TESTING KEY", "PRIVATE KEY")) -func testVariants(t *testing.T, f func(t *testing.T, local *nbconn.Conn, remote net.Conn)) { +func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) { for _, tt := range []struct { name string makeConns func(t *testing.T) (local, remote net.Conn) @@ -89,16 +89,32 @@ func testVariants(t *testing.T, f func(t *testing.T, local *nbconn.Conn, remote } { t.Run(tt.name, func(t *testing.T) { local, remote := tt.makeConns(t) - conn := nbconn.New(local) + + var conn nbconn.Conn + netConn := nbconn.NewNetConn(local) if tt.useTLS { cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey) require.NoError(t, err) - remote = tls.Server(remote, &tls.Config{ + tlsServer := tls.Server(remote, &tls.Config{ Certificates: []tls.Certificate{cert}, }) - conn.StartTLS(&tls.Config{InsecureSkipVerify: true}) + serverTLSHandshakeChan := make(chan error) + go func() { + err := tlsServer.Handshake() + serverTLSHandshakeChan <- err + }() + + tlsConn, err := nbconn.TLSClient(netConn, &tls.Config{InsecureSkipVerify: true}) + require.NoError(t, err) + conn = tlsConn + + err = <-serverTLSHandshakeChan + require.NoError(t, err) + remote = tlsServer + } else { + conn = netConn } f(t, conn, remote) @@ -147,7 +163,7 @@ func makeTCPConns(t *testing.T) (local, remote net.Conn) { } func TestWriteIsBuffered(t *testing.T) { - testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { // net.Pipe is synchronous so the Write would block if not buffered. writeBuf := []byte("test") n, err := conn.Write(writeBuf) @@ -169,7 +185,7 @@ func TestWriteIsBuffered(t *testing.T) { } func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) { - testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { err := conn.SetWriteDeadline(time.Now()) require.NoError(t, err) @@ -181,7 +197,7 @@ func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) { } func TestReadFlushesWriteBuffer(t *testing.T) { - testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { writeBuf := []byte("test") n, err := conn.Write(writeBuf) require.NoError(t, err) @@ -208,7 +224,7 @@ func TestReadFlushesWriteBuffer(t *testing.T) { } func TestCloseFlushesWriteBuffer(t *testing.T) { - testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { writeBuf := []byte("test") n, err := conn.Write(writeBuf) require.NoError(t, err) @@ -229,7 +245,7 @@ func TestCloseFlushesWriteBuffer(t *testing.T) { } func TestNonBlockingRead(t *testing.T) { - testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { err := conn.SetReadDeadline(nbconn.NonBlockingDeadline) require.NoError(t, err) @@ -254,7 +270,7 @@ func TestNonBlockingRead(t *testing.T) { } func TestReadPreviouslyBuffered(t *testing.T) { - testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { errChan := make(chan error, 1) go func() { @@ -291,7 +307,7 @@ func TestReadPreviouslyBuffered(t *testing.T) { } func TestReadPreviouslyBufferedPartialRead(t *testing.T) { - testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { errChan := make(chan error, 1) go func() { @@ -334,7 +350,7 @@ func TestReadPreviouslyBufferedPartialRead(t *testing.T) { } func TestReadMultiplePreviouslyBuffered(t *testing.T) { - testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { errChan := make(chan error, 1) go func() { err := func() error { @@ -367,7 +383,7 @@ func TestReadMultiplePreviouslyBuffered(t *testing.T) { require.NoError(t, err) readBuf := make([]byte, 9) - n, err := conn.Read(readBuf) + n, err := io.ReadFull(conn, readBuf) require.NoError(t, err) require.EqualValues(t, 9, n) require.Equal(t, []byte("alphabeta"), readBuf) @@ -375,7 +391,7 @@ func TestReadMultiplePreviouslyBuffered(t *testing.T) { } func TestReadPreviouslyBufferedAndReadMore(t *testing.T) { - testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { flushCompleteChan := make(chan struct{}) errChan := make(chan error, 1) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index bac08db0..5a5fac82 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -230,13 +230,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } return nil, &connectError{config: config, msg: "dial error", err: err} } + netConn = nbconn.NewNetConn(netConn) pgConn.conn = netConn pgConn.contextWatcher = newContextWatcher(netConn) pgConn.contextWatcher.Watch(ctx) if fallbackConfig.TLSConfig != nil { - tlsConn, err := startTLS(netConn, fallbackConfig.TLSConfig) + tlsConn, err := startTLS(netConn.(*nbconn.NetConn), fallbackConfig.TLSConfig) pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { netConn.Close() @@ -248,11 +249,6 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.contextWatcher.Watch(ctx) } - pgConn.conn = nbconn.New(pgConn.conn) - pgConn.contextWatcher.Unwatch() // context watcher should watch nbconn - pgConn.contextWatcher = newContextWatcher(pgConn.conn) - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() pgConn.parameterStatuses = make(map[string]string) @@ -357,7 +353,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { ) } -func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { +func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (net.Conn, error) { err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return nil, err @@ -372,7 +368,12 @@ func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { return nil, errors.New("server refused TLS connection") } - return tls.Client(conn, tlsConfig), nil + tlsConn, err := nbconn.TLSClient(conn, tlsConfig) + if err != nil { + return nil, err + } + + return tlsConn, nil } func (pgConn *PgConn) txPasswordMessage(password string) (err error) {