Steps toward TLS

non-blocking
Jack Christensen 2022-06-04 17:46:00 -05:00
parent 2b80beb1ed
commit 51655bf8f4
2 changed files with 24 additions and 8 deletions

View File

@ -32,6 +32,8 @@ var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
// The third is to efficiently check if a connection has been closed via a non-blocking read. // The third is to efficiently check if a connection has been closed via a non-blocking read.
type Conn struct { type Conn struct {
netConn net.Conn netConn net.Conn
tlsConn *tls.Conn
maybeTLSConn net.Conn
readQueue bufferQueue readQueue bufferQueue
writeQueue bufferQueue writeQueue bufferQueue
@ -52,12 +54,14 @@ type Conn struct {
func New(conn net.Conn) *Conn { func New(conn net.Conn) *Conn {
return &Conn{ return &Conn{
netConn: conn, netConn: conn,
maybeTLSConn: conn,
} }
} }
// StartTLS starts using TLS. It must not be called concurrently with any other method and must only be called once. // StartTLS starts using TLS. It must not be called concurrently with any other method and must only be called once.
func (c *Conn) StartTLS(config *tls.Config) { func (c *Conn) StartTLS(config *tls.Config) {
c.netConn = tls.Client(c.netConn, config) c.tlsConn = tls.Client(c.netConn, config)
c.maybeTLSConn = c.tlsConn
} }
// Read implements io.Reader. // Read implements io.Reader.
@ -102,7 +106,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
if readNonblocking { if readNonblocking {
readN, err = c.nonblockingRead(b[n:]) readN, err = c.nonblockingRead(b[n:])
} else { } else {
readN, err = c.netConn.Read(b[n:]) readN, err = c.maybeTLSConn.Read(b[n:])
} }
n += readN n += readN
return n, err return n, err
@ -128,7 +132,7 @@ func (c *Conn) Close() (err error) {
} }
defer func() { defer func() {
closeErr := c.netConn.Close() closeErr := c.maybeTLSConn.Close()
if err == nil { if err == nil {
err = closeErr err = closeErr
} }
@ -145,11 +149,11 @@ func (c *Conn) Close() (err error) {
} }
func (c *Conn) LocalAddr() net.Addr { func (c *Conn) LocalAddr() net.Addr {
return c.netConn.LocalAddr() return c.maybeTLSConn.LocalAddr()
} }
func (c *Conn) RemoteAddr() net.Addr { func (c *Conn) RemoteAddr() net.Addr {
return c.netConn.RemoteAddr() return c.maybeTLSConn.RemoteAddr()
} }
// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t). // SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t).
@ -179,7 +183,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
c.readDeadline = t c.readDeadline = t
return c.netConn.SetReadDeadline(t) return c.maybeTLSConn.SetReadDeadline(t)
} }
func (c *Conn) SetWriteDeadline(t time.Time) error { func (c *Conn) SetWriteDeadline(t time.Time) error {

View File

@ -168,6 +168,18 @@ func TestWriteIsBuffered(t *testing.T) {
}) })
} }
func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) {
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
err := conn.SetWriteDeadline(time.Now())
require.NoError(t, err)
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
})
}
func TestReadFlushesWriteBuffer(t *testing.T) { func TestReadFlushesWriteBuffer(t *testing.T) {
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) { testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
writeBuf := []byte("test") writeBuf := []byte("test")