mirror of https://github.com/jackc/pgx.git
More TLS support
parent
51655bf8f4
commit
afa702213f
|
@ -1,4 +1,13 @@
|
|||
// Package nbconn implements a non-blocking net.Conn wrapper.
|
||||
//
|
||||
// It is designed to solve three problems.
|
||||
//
|
||||
// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all
|
||||
// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion.
|
||||
//
|
||||
// The second is the inability to use a write deadline with a TLS.Conn without killing the connection.
|
||||
//
|
||||
// The third is to efficiently check if a connection has been closed via a non-blocking read.
|
||||
package nbconn
|
||||
|
||||
import (
|
||||
|
@ -14,26 +23,38 @@ import (
|
|||
)
|
||||
|
||||
var errClosed = errors.New("closed")
|
||||
var ErrWouldBlock = errors.New("would block")
|
||||
var ErrWouldBlock = new(wouldBlockError)
|
||||
|
||||
const fakeNonblockingWaitDuration = 100 * time.Millisecond
|
||||
|
||||
var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read
|
||||
// mode.
|
||||
var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC)
|
||||
|
||||
// Conn is a non-blocking net.Conn wrapper. It implements net.Conn.
|
||||
//
|
||||
// It is designed to solve three problems.
|
||||
//
|
||||
// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all
|
||||
// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion.
|
||||
//
|
||||
// The second is the inability to use a write deadline with a TLS.Conn without killing the connection.
|
||||
//
|
||||
// The third is to efficiently check if a connection has been closed via a non-blocking read.
|
||||
type Conn struct {
|
||||
netConn net.Conn
|
||||
tlsConn *tls.Conn
|
||||
maybeTLSConn net.Conn
|
||||
// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to
|
||||
// ignore all future calls.
|
||||
var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC)
|
||||
|
||||
// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error.
|
||||
type wouldBlockError struct{}
|
||||
|
||||
func (*wouldBlockError) Error() string {
|
||||
return "would block"
|
||||
}
|
||||
|
||||
func (*wouldBlockError) Timeout() bool { return true }
|
||||
func (*wouldBlockError) Temporary() bool { return true }
|
||||
|
||||
// Conn is a net.Conn where Write never blocks and always succeeds. Flush must be called to actually write to the
|
||||
// underlying connection.
|
||||
type Conn interface {
|
||||
net.Conn
|
||||
Flush() error
|
||||
}
|
||||
|
||||
// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn.
|
||||
type NetConn struct {
|
||||
conn net.Conn
|
||||
|
||||
readQueue bufferQueue
|
||||
writeQueue bufferQueue
|
||||
|
@ -51,21 +72,14 @@ type Conn struct {
|
|||
closed int64 // 0 = not closed, 1 = closed
|
||||
}
|
||||
|
||||
func New(conn net.Conn) *Conn {
|
||||
return &Conn{
|
||||
netConn: conn,
|
||||
maybeTLSConn: conn,
|
||||
func NewNetConn(conn net.Conn) *NetConn {
|
||||
return &NetConn{
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
// StartTLS starts using TLS. It must not be called concurrently with any other method and must only be called once.
|
||||
func (c *Conn) StartTLS(config *tls.Config) {
|
||||
c.tlsConn = tls.Client(c.netConn, config)
|
||||
c.maybeTLSConn = c.tlsConn
|
||||
}
|
||||
|
||||
// Read implements io.Reader.
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
func (c *NetConn) Read(b []byte) (n int, err error) {
|
||||
if c.isClosed() {
|
||||
return 0, errClosed
|
||||
}
|
||||
|
@ -106,7 +120,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
|||
if readNonblocking {
|
||||
readN, err = c.nonblockingRead(b[n:])
|
||||
} else {
|
||||
readN, err = c.maybeTLSConn.Read(b[n:])
|
||||
readN, err = c.conn.Read(b[n:])
|
||||
}
|
||||
n += readN
|
||||
return n, err
|
||||
|
@ -114,7 +128,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
|||
|
||||
// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is
|
||||
// closed. Call Flush to actually write to the underlying connection.
|
||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
func (c *NetConn) Write(b []byte) (n int, err error) {
|
||||
if c.isClosed() {
|
||||
return 0, errClosed
|
||||
}
|
||||
|
@ -125,14 +139,14 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
|||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() (err error) {
|
||||
func (c *NetConn) Close() (err error) {
|
||||
swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1)
|
||||
if !swapped {
|
||||
return errClosed
|
||||
}
|
||||
|
||||
defer func() {
|
||||
closeErr := c.maybeTLSConn.Close()
|
||||
closeErr := c.conn.Close()
|
||||
if err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
|
@ -148,16 +162,16 @@ func (c *Conn) Close() (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.maybeTLSConn.LocalAddr()
|
||||
func (c *NetConn) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.maybeTLSConn.RemoteAddr()
|
||||
func (c *NetConn) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t).
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
func (c *NetConn) SetDeadline(t time.Time) error {
|
||||
err := c.SetReadDeadline(t)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -166,13 +180,20 @@ func (c *Conn) SetDeadline(t time.Time) error {
|
|||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking.
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
func (c *NetConn) SetReadDeadline(t time.Time) error {
|
||||
if c.isClosed() {
|
||||
return errClosed
|
||||
}
|
||||
|
||||
c.readDeadlineLock.Lock()
|
||||
defer c.readDeadlineLock.Unlock()
|
||||
if c.readDeadline == disableSetDeadlineDeadline {
|
||||
return nil
|
||||
}
|
||||
if t == disableSetDeadlineDeadline {
|
||||
c.readDeadline = t
|
||||
return nil
|
||||
}
|
||||
|
||||
if t == NonBlockingDeadline {
|
||||
c.readNonblocking = true
|
||||
|
@ -183,22 +204,30 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
|
|||
|
||||
c.readDeadline = t
|
||||
|
||||
return c.maybeTLSConn.SetReadDeadline(t)
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
func (c *NetConn) SetWriteDeadline(t time.Time) error {
|
||||
if c.isClosed() {
|
||||
return errClosed
|
||||
}
|
||||
|
||||
c.writeDeadlineLock.Lock()
|
||||
defer c.writeDeadlineLock.Unlock()
|
||||
if c.writeDeadline == disableSetDeadlineDeadline {
|
||||
return nil
|
||||
}
|
||||
if t == disableSetDeadlineDeadline {
|
||||
c.writeDeadline = t
|
||||
return nil
|
||||
}
|
||||
|
||||
c.writeDeadline = t
|
||||
|
||||
return c.netConn.SetWriteDeadline(t)
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) Flush() error {
|
||||
func (c *NetConn) Flush() error {
|
||||
if c.isClosed() {
|
||||
return errClosed
|
||||
}
|
||||
|
@ -209,7 +238,7 @@ func (c *Conn) Flush() error {
|
|||
}
|
||||
|
||||
// flush does the actual work of flushing the writeQueue. readFlushLock must already be held.
|
||||
func (c *Conn) flush() error {
|
||||
func (c *NetConn) flush() error {
|
||||
var stopChan chan struct{}
|
||||
var errChan chan error
|
||||
|
||||
|
@ -255,7 +284,7 @@ func (c *Conn) flush() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) {
|
||||
func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) {
|
||||
stopChan = make(chan struct{})
|
||||
errChan = make(chan error, 1)
|
||||
|
||||
|
@ -286,28 +315,28 @@ func (c *Conn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan err
|
|||
return stopChan, errChan
|
||||
}
|
||||
|
||||
func (c *Conn) isClosed() bool {
|
||||
func (c *NetConn) isClosed() bool {
|
||||
closed := atomic.LoadInt64(&c.closed)
|
||||
return closed == 1
|
||||
}
|
||||
|
||||
func (c *Conn) nonblockingWrite(b []byte) (n int, err error) {
|
||||
func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) {
|
||||
return c.fakeNonblockingWrite(b)
|
||||
}
|
||||
|
||||
func (c *Conn) fakeNonblockingWrite(b []byte) (n int, err error) {
|
||||
func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
|
||||
c.writeDeadlineLock.Lock()
|
||||
defer c.writeDeadlineLock.Unlock()
|
||||
|
||||
deadline := time.Now().Add(fakeNonblockingWaitDuration)
|
||||
if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) {
|
||||
err = c.netConn.SetWriteDeadline(deadline)
|
||||
err = c.conn.SetWriteDeadline(deadline)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
|
||||
c.netConn.SetWriteDeadline(c.writeDeadline)
|
||||
c.conn.SetWriteDeadline(c.writeDeadline)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
|
@ -317,26 +346,26 @@ func (c *Conn) fakeNonblockingWrite(b []byte) (n int, err error) {
|
|||
}()
|
||||
}
|
||||
|
||||
return c.netConn.Write(b)
|
||||
return c.conn.Write(b)
|
||||
}
|
||||
|
||||
func (c *Conn) nonblockingRead(b []byte) (n int, err error) {
|
||||
func (c *NetConn) nonblockingRead(b []byte) (n int, err error) {
|
||||
return c.fakeNonblockingRead(b)
|
||||
}
|
||||
|
||||
func (c *Conn) fakeNonblockingRead(b []byte) (n int, err error) {
|
||||
func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) {
|
||||
c.readDeadlineLock.Lock()
|
||||
defer c.readDeadlineLock.Unlock()
|
||||
|
||||
deadline := time.Now().Add(fakeNonblockingWaitDuration)
|
||||
if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) {
|
||||
err = c.netConn.SetReadDeadline(deadline)
|
||||
err = c.conn.SetReadDeadline(deadline)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
|
||||
c.netConn.SetReadDeadline(c.readDeadline)
|
||||
c.conn.SetReadDeadline(c.readDeadline)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
|
@ -346,7 +375,58 @@ func (c *Conn) fakeNonblockingRead(b []byte) (n int, err error) {
|
|||
}()
|
||||
}
|
||||
|
||||
return c.netConn.Read(b)
|
||||
return c.conn.Read(b)
|
||||
}
|
||||
|
||||
// syscall.Conn is interface
|
||||
|
||||
// TLSClient establishes a TLS connection as a client over conn using config.
|
||||
//
|
||||
// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby
|
||||
// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the
|
||||
// *TLSConn is returned.
|
||||
func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) {
|
||||
tc := tls.Client(conn, config)
|
||||
err := tc.Handshake()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure last written part of Handshake is actually sent.
|
||||
err = conn.Flush()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TLSConn{
|
||||
tlsConn: tc,
|
||||
nbConn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a
|
||||
// tls.Conn.
|
||||
type TLSConn struct {
|
||||
tlsConn *tls.Conn
|
||||
nbConn *NetConn
|
||||
}
|
||||
|
||||
func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) }
|
||||
func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) }
|
||||
func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() }
|
||||
func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() }
|
||||
func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() }
|
||||
|
||||
func (tc *TLSConn) Close() error {
|
||||
// tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then
|
||||
// sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our
|
||||
// own 5 second deadline then make all set deadlines no-op.
|
||||
tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5))
|
||||
tc.tlsConn.SetDeadline(disableSetDeadlineDeadline)
|
||||
|
||||
return tc.tlsConn.Close()
|
||||
}
|
||||
|
||||
func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) }
|
||||
func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) }
|
||||
func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) }
|
||||
|
|
|
@ -65,7 +65,7 @@ pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG
|
|||
2qWm8jTPeDC3sq+67s2oojHf+Q==
|
||||
-----END TESTING KEY-----`, "TESTING KEY", "PRIVATE KEY"))
|
||||
|
||||
func testVariants(t *testing.T, f func(t *testing.T, local *nbconn.Conn, remote net.Conn)) {
|
||||
func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
makeConns func(t *testing.T) (local, remote net.Conn)
|
||||
|
@ -89,16 +89,32 @@ func testVariants(t *testing.T, f func(t *testing.T, local *nbconn.Conn, remote
|
|||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
local, remote := tt.makeConns(t)
|
||||
conn := nbconn.New(local)
|
||||
|
||||
var conn nbconn.Conn
|
||||
netConn := nbconn.NewNetConn(local)
|
||||
|
||||
if tt.useTLS {
|
||||
cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
remote = tls.Server(remote, &tls.Config{
|
||||
tlsServer := tls.Server(remote, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
conn.StartTLS(&tls.Config{InsecureSkipVerify: true})
|
||||
serverTLSHandshakeChan := make(chan error)
|
||||
go func() {
|
||||
err := tlsServer.Handshake()
|
||||
serverTLSHandshakeChan <- err
|
||||
}()
|
||||
|
||||
tlsConn, err := nbconn.TLSClient(netConn, &tls.Config{InsecureSkipVerify: true})
|
||||
require.NoError(t, err)
|
||||
conn = tlsConn
|
||||
|
||||
err = <-serverTLSHandshakeChan
|
||||
require.NoError(t, err)
|
||||
remote = tlsServer
|
||||
} else {
|
||||
conn = netConn
|
||||
}
|
||||
|
||||
f(t, conn, remote)
|
||||
|
@ -147,7 +163,7 @@ func makeTCPConns(t *testing.T) (local, remote net.Conn) {
|
|||
}
|
||||
|
||||
func TestWriteIsBuffered(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
// net.Pipe is synchronous so the Write would block if not buffered.
|
||||
writeBuf := []byte("test")
|
||||
n, err := conn.Write(writeBuf)
|
||||
|
@ -169,7 +185,7 @@ func TestWriteIsBuffered(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
err := conn.SetWriteDeadline(time.Now())
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -181,7 +197,7 @@ func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestReadFlushesWriteBuffer(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := []byte("test")
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
|
@ -208,7 +224,7 @@ func TestReadFlushesWriteBuffer(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCloseFlushesWriteBuffer(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := []byte("test")
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
|
@ -229,7 +245,7 @@ func TestCloseFlushesWriteBuffer(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestNonBlockingRead(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
err := conn.SetReadDeadline(nbconn.NonBlockingDeadline)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -254,7 +270,7 @@ func TestNonBlockingRead(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestReadPreviouslyBuffered(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
|
@ -291,7 +307,7 @@ func TestReadPreviouslyBuffered(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestReadPreviouslyBufferedPartialRead(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
|
@ -334,7 +350,7 @@ func TestReadPreviouslyBufferedPartialRead(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestReadMultiplePreviouslyBuffered(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := func() error {
|
||||
|
@ -367,7 +383,7 @@ func TestReadMultiplePreviouslyBuffered(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
readBuf := make([]byte, 9)
|
||||
n, err := conn.Read(readBuf)
|
||||
n, err := io.ReadFull(conn, readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 9, n)
|
||||
require.Equal(t, []byte("alphabeta"), readBuf)
|
||||
|
@ -375,7 +391,7 @@ func TestReadMultiplePreviouslyBuffered(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestReadPreviouslyBufferedAndReadMore(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
|
||||
flushCompleteChan := make(chan struct{})
|
||||
errChan := make(chan error, 1)
|
||||
|
|
|
@ -230,13 +230,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
}
|
||||
return nil, &connectError{config: config, msg: "dial error", err: err}
|
||||
}
|
||||
netConn = nbconn.NewNetConn(netConn)
|
||||
|
||||
pgConn.conn = netConn
|
||||
pgConn.contextWatcher = newContextWatcher(netConn)
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
|
||||
if fallbackConfig.TLSConfig != nil {
|
||||
tlsConn, err := startTLS(netConn, fallbackConfig.TLSConfig)
|
||||
tlsConn, err := startTLS(netConn.(*nbconn.NetConn), fallbackConfig.TLSConfig)
|
||||
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
|
||||
if err != nil {
|
||||
netConn.Close()
|
||||
|
@ -248,11 +249,6 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
pgConn.contextWatcher.Watch(ctx)
|
||||
}
|
||||
|
||||
pgConn.conn = nbconn.New(pgConn.conn)
|
||||
pgConn.contextWatcher.Unwatch() // context watcher should watch nbconn
|
||||
pgConn.contextWatcher = newContextWatcher(pgConn.conn)
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
|
||||
defer pgConn.contextWatcher.Unwatch()
|
||||
|
||||
pgConn.parameterStatuses = make(map[string]string)
|
||||
|
@ -357,7 +353,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {
|
|||
)
|
||||
}
|
||||
|
||||
func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
|
||||
func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (net.Conn, error) {
|
||||
err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -372,7 +368,12 @@ func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
|
|||
return nil, errors.New("server refused TLS connection")
|
||||
}
|
||||
|
||||
return tls.Client(conn, tlsConfig), nil
|
||||
tlsConn, err := nbconn.TLSClient(conn, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
|
||||
|
|
Loading…
Reference in New Issue