Steps toward TLS

non-blocking
Jack Christensen 2022-06-04 17:46:00 -05:00
parent 2b80beb1ed
commit 51655bf8f4
2 changed files with 24 additions and 8 deletions

View File

@ -31,7 +31,9 @@ var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
//
// The third is to efficiently check if a connection has been closed via a non-blocking read.
type Conn struct {
netConn net.Conn
netConn net.Conn
tlsConn *tls.Conn
maybeTLSConn net.Conn
readQueue bufferQueue
writeQueue bufferQueue
@ -51,13 +53,15 @@ type Conn struct {
func New(conn net.Conn) *Conn {
return &Conn{
netConn: conn,
netConn: conn,
maybeTLSConn: conn,
}
}
// StartTLS starts using TLS. It must not be called concurrently with any other method and must only be called once.
func (c *Conn) StartTLS(config *tls.Config) {
c.netConn = tls.Client(c.netConn, config)
c.tlsConn = tls.Client(c.netConn, config)
c.maybeTLSConn = c.tlsConn
}
// Read implements io.Reader.
@ -102,7 +106,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
if readNonblocking {
readN, err = c.nonblockingRead(b[n:])
} else {
readN, err = c.netConn.Read(b[n:])
readN, err = c.maybeTLSConn.Read(b[n:])
}
n += readN
return n, err
@ -128,7 +132,7 @@ func (c *Conn) Close() (err error) {
}
defer func() {
closeErr := c.netConn.Close()
closeErr := c.maybeTLSConn.Close()
if err == nil {
err = closeErr
}
@ -145,11 +149,11 @@ func (c *Conn) Close() (err error) {
}
func (c *Conn) LocalAddr() net.Addr {
return c.netConn.LocalAddr()
return c.maybeTLSConn.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr {
return c.netConn.RemoteAddr()
return c.maybeTLSConn.RemoteAddr()
}
// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t).
@ -179,7 +183,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
c.readDeadline = t
return c.netConn.SetReadDeadline(t)
return c.maybeTLSConn.SetReadDeadline(t)
}
func (c *Conn) SetWriteDeadline(t time.Time) error {

View File

@ -168,6 +168,18 @@ func TestWriteIsBuffered(t *testing.T) {
})
}
func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) {
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
err := conn.SetWriteDeadline(time.Now())
require.NoError(t, err)
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
})
}
func TestReadFlushesWriteBuffer(t *testing.T) {
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
writeBuf := []byte("test")