diff --git a/internal/nbconn/nbconn.go b/internal/nbconn/nbconn.go index 5051f52b..6a92e8fb 100644 --- a/internal/nbconn/nbconn.go +++ b/internal/nbconn/nbconn.go @@ -31,7 +31,9 @@ var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC) // // The third is to efficiently check if a connection has been closed via a non-blocking read. type Conn struct { - netConn net.Conn + netConn net.Conn + tlsConn *tls.Conn + maybeTLSConn net.Conn readQueue bufferQueue writeQueue bufferQueue @@ -51,13 +53,15 @@ type Conn struct { func New(conn net.Conn) *Conn { return &Conn{ - netConn: conn, + netConn: conn, + maybeTLSConn: conn, } } // StartTLS starts using TLS. It must not be called concurrently with any other method and must only be called once. func (c *Conn) StartTLS(config *tls.Config) { - c.netConn = tls.Client(c.netConn, config) + c.tlsConn = tls.Client(c.netConn, config) + c.maybeTLSConn = c.tlsConn } // Read implements io.Reader. @@ -102,7 +106,7 @@ func (c *Conn) Read(b []byte) (n int, err error) { if readNonblocking { readN, err = c.nonblockingRead(b[n:]) } else { - readN, err = c.netConn.Read(b[n:]) + readN, err = c.maybeTLSConn.Read(b[n:]) } n += readN return n, err @@ -128,7 +132,7 @@ func (c *Conn) Close() (err error) { } defer func() { - closeErr := c.netConn.Close() + closeErr := c.maybeTLSConn.Close() if err == nil { err = closeErr } @@ -145,11 +149,11 @@ func (c *Conn) Close() (err error) { } func (c *Conn) LocalAddr() net.Addr { - return c.netConn.LocalAddr() + return c.maybeTLSConn.LocalAddr() } func (c *Conn) RemoteAddr() net.Addr { - return c.netConn.RemoteAddr() + return c.maybeTLSConn.RemoteAddr() } // SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t). @@ -179,7 +183,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error { c.readDeadline = t - return c.netConn.SetReadDeadline(t) + return c.maybeTLSConn.SetReadDeadline(t) } func (c *Conn) SetWriteDeadline(t time.Time) error { diff --git a/internal/nbconn/nbconn_test.go b/internal/nbconn/nbconn_test.go index f99258e0..d3b56e36 100644 --- a/internal/nbconn/nbconn_test.go +++ b/internal/nbconn/nbconn_test.go @@ -168,6 +168,18 @@ func TestWriteIsBuffered(t *testing.T) { }) } +func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) { + testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) { + err := conn.SetWriteDeadline(time.Now()) + require.NoError(t, err) + + writeBuf := []byte("test") + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + }) +} + func TestReadFlushesWriteBuffer(t *testing.T) { testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) { writeBuf := []byte("test")