From 2b80beb1ed75f0f58db8188b87753dbc26b62098 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Jun 2022 13:06:29 -0500 Subject: [PATCH] Litle more TLS support --- internal/nbconn/nbconn.go | 7 ++++ internal/nbconn/nbconn_test.go | 61 ++++++++++++---------------------- 2 files changed, 29 insertions(+), 39 deletions(-) diff --git a/internal/nbconn/nbconn.go b/internal/nbconn/nbconn.go index bb27d0ec..5051f52b 100644 --- a/internal/nbconn/nbconn.go +++ b/internal/nbconn/nbconn.go @@ -2,6 +2,7 @@ package nbconn import ( + "crypto/tls" "errors" "net" "os" @@ -54,6 +55,12 @@ func New(conn net.Conn) *Conn { } } +// StartTLS starts using TLS. It must not be called concurrently with any other method and must only be called once. +func (c *Conn) StartTLS(config *tls.Config) { + c.netConn = tls.Client(c.netConn, config) +} + +// Read implements io.Reader. func (c *Conn) Read(b []byte) (n int, err error) { if c.isClosed() { return 0, errClosed diff --git a/internal/nbconn/nbconn_test.go b/internal/nbconn/nbconn_test.go index f05117d5..f99258e0 100644 --- a/internal/nbconn/nbconn_test.go +++ b/internal/nbconn/nbconn_test.go @@ -2,6 +2,7 @@ package nbconn_test import ( "crypto/tls" + "io" "net" "strings" "testing" @@ -68,24 +69,38 @@ func testVariants(t *testing.T, f func(t *testing.T, local *nbconn.Conn, remote for _, tt := range []struct { name string makeConns func(t *testing.T) (local, remote net.Conn) + useTLS bool }{ { name: "Pipe", makeConns: makePipeConns, + useTLS: false, }, { name: "TCP", makeConns: makeTCPConns, + useTLS: false, }, { name: "TLS over TCP", - makeConns: makeTLSOverTCPConns, + makeConns: makeTCPConns, + useTLS: true, }, } { t.Run(tt.name, func(t *testing.T) { local, remote := tt.makeConns(t) - conn := nbconn.New(local) + + if tt.useTLS { + cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey) + require.NoError(t, err) + + remote = tls.Server(remote, &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + conn.StartTLS(&tls.Config{InsecureSkipVerify: true}) + } + f(t, conn, remote) }) } @@ -131,42 +146,6 @@ func makeTCPConns(t *testing.T) (local, remote net.Conn) { return local, remote } -// makeTLSOverTCPConns returns a connected pair of net.Conns running over TCP on localhost with TLS encryption. -func makeTLSOverTCPConns(t *testing.T) (local, remote net.Conn) { - ln, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - defer ln.Close() - - type acceptResultT struct { - conn net.Conn - err error - } - acceptChan := make(chan acceptResultT) - - go func() { - conn, err := ln.Accept() - acceptChan <- acceptResultT{conn: conn, err: err} - }() - - localConn, err := net.Dial("tcp", ln.Addr().String()) - require.NoError(t, err) - - acceptResult := <-acceptChan - require.NoError(t, acceptResult.err) - - remoteConn := acceptResult.conn - - cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey) - require.NoError(t, err) - - localTLS := tls.Client(localConn, &tls.Config{InsecureSkipVerify: true}) - remoteTLS := tls.Server(remoteConn, &tls.Config{ - Certificates: []tls.Certificate{cert}, - }) - - return localTLS, remoteTLS -} - func TestWriteIsBuffered(t *testing.T) { testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) { // net.Pipe is synchronous so the Write would block if not buffered. @@ -423,9 +402,13 @@ func TestReadPreviouslyBufferedAndReadMore(t *testing.T) { close(flushCompleteChan) readBuf := make([]byte, 9) - n, err := conn.Read(readBuf) + + n, err := io.ReadFull(conn, readBuf) require.NoError(t, err) require.EqualValues(t, 9, n) require.Equal(t, []byte("alphabeta"), readBuf) + + err = <-errChan + require.NoError(t, err) }) }