mirror of https://github.com/jackc/pgx.git
Steps toward TLS
parent
2b80beb1ed
commit
51655bf8f4
|
@ -31,7 +31,9 @@ var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
|
|||
//
|
||||
// The third is to efficiently check if a connection has been closed via a non-blocking read.
|
||||
type Conn struct {
|
||||
netConn net.Conn
|
||||
netConn net.Conn
|
||||
tlsConn *tls.Conn
|
||||
maybeTLSConn net.Conn
|
||||
|
||||
readQueue bufferQueue
|
||||
writeQueue bufferQueue
|
||||
|
@ -51,13 +53,15 @@ type Conn struct {
|
|||
|
||||
func New(conn net.Conn) *Conn {
|
||||
return &Conn{
|
||||
netConn: conn,
|
||||
netConn: conn,
|
||||
maybeTLSConn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
// StartTLS starts using TLS. It must not be called concurrently with any other method and must only be called once.
|
||||
func (c *Conn) StartTLS(config *tls.Config) {
|
||||
c.netConn = tls.Client(c.netConn, config)
|
||||
c.tlsConn = tls.Client(c.netConn, config)
|
||||
c.maybeTLSConn = c.tlsConn
|
||||
}
|
||||
|
||||
// Read implements io.Reader.
|
||||
|
@ -102,7 +106,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
|||
if readNonblocking {
|
||||
readN, err = c.nonblockingRead(b[n:])
|
||||
} else {
|
||||
readN, err = c.netConn.Read(b[n:])
|
||||
readN, err = c.maybeTLSConn.Read(b[n:])
|
||||
}
|
||||
n += readN
|
||||
return n, err
|
||||
|
@ -128,7 +132,7 @@ func (c *Conn) Close() (err error) {
|
|||
}
|
||||
|
||||
defer func() {
|
||||
closeErr := c.netConn.Close()
|
||||
closeErr := c.maybeTLSConn.Close()
|
||||
if err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
|
@ -145,11 +149,11 @@ func (c *Conn) Close() (err error) {
|
|||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.netConn.LocalAddr()
|
||||
return c.maybeTLSConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.netConn.RemoteAddr()
|
||||
return c.maybeTLSConn.RemoteAddr()
|
||||
}
|
||||
|
||||
// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t).
|
||||
|
@ -179,7 +183,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
|
|||
|
||||
c.readDeadline = t
|
||||
|
||||
return c.netConn.SetReadDeadline(t)
|
||||
return c.maybeTLSConn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
|
|
|
@ -168,6 +168,18 @@ func TestWriteIsBuffered(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
||||
err := conn.SetWriteDeadline(time.Now())
|
||||
require.NoError(t, err)
|
||||
|
||||
writeBuf := []byte("test")
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadFlushesWriteBuffer(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := []byte("test")
|
||||
|
|
Loading…
Reference in New Issue