From dd9d960ba37478e4201697fe46add55c58a69208 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 May 2015 11:57:36 -0500 Subject: [PATCH] Add fallback TLS ConnConfig option This is in preparation for supporting libpq style SSL options. --- conn.go | 56 +++++++++++++++++++++++++++++++++------------------- conn_test.go | 30 ++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 20 deletions(-) diff --git a/conn.go b/conn.go index 9ac9d184..c9937f6e 100644 --- a/conn.go +++ b/conn.go @@ -24,14 +24,16 @@ type DialFunc func(network, addr string) (net.Conn, error) // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 // default: 5432 - Database string - User string // default: OS user name - Password string - TLSConfig *tls.Config // config for TLS connection -- nil disables TLS - Logger Logger - Dial DialFunc + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 // default: 5432 + Database string + User string // default: OS user name + Password string + TLSConfig *tls.Config // config for TLS connection -- nil disables TLS + UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa + FallbackTLSConfig *tls.Config // config for fallback TLS connection (only used if UseFallBackTLS is true)-- nil disables TLS + Logger Logger + Dial DialFunc } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. @@ -140,11 +142,25 @@ func Connect(config ConnConfig) (c *Conn, err error) { if c.config.Dial == nil { c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial } + + err = c.connect(config, network, address, config.TLSConfig) + if err != nil && config.UseFallbackTLS { + err = c.connect(config, network, address, config.FallbackTLSConfig) + } + + if err != nil { + return nil, err + } + + return c, nil +} + +func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address)) c.conn, err = c.config.Dial(network, address) if err != nil { c.logger.Error(fmt.Sprintf("Connection failed: %v", err)) - return nil, err + return err } defer func() { if c != nil && err != nil { @@ -159,11 +175,11 @@ func Connect(config ConnConfig) (c *Conn, err error) { c.alive = true c.lastActivityTime = time.Now() - if config.TLSConfig != nil { + if tlsConfig != nil { c.logger.Debug("Starting TLS handshake") - if err = c.startTLS(); err != nil { + if err := c.startTLS(tlsConfig); err != nil { c.logger.Error(fmt.Sprintf("TLS failed: %v", err)) - return + return err } } @@ -176,7 +192,7 @@ func Connect(config ConnConfig) (c *Conn, err error) { msg.options["database"] = c.config.Database } if err = c.txStartupMessage(msg); err != nil { - return + return err } for { @@ -184,7 +200,7 @@ func Connect(config ConnConfig) (c *Conn, err error) { var r *msgReader t, r, err = c.rxMsg() if err != nil { - return nil, err + return err } switch t { @@ -192,7 +208,7 @@ func Connect(config ConnConfig) (c *Conn, err error) { c.rxBackendKeyData(r) case authenticationX: if err = c.rxAuthenticationX(r); err != nil { - return nil, err + return err } case readyForQuery: c.rxReadyForQuery(r) @@ -203,13 +219,13 @@ func Connect(config ConnConfig) (c *Conn, err error) { err = c.loadPgTypes() if err != nil { - return nil, err + return err } - return c, nil + return nil default: if err = c.processContextFreeMsg(t, r); err != nil { - return nil, err + return err } } } @@ -905,7 +921,7 @@ func (c *Conn) rxNotificationResponse(r *msgReader) { c.notifications = append(c.notifications, n) } -func (c *Conn) startTLS() (err error) { +func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) { err = binary.Write(c.conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return @@ -920,7 +936,7 @@ func (c *Conn) startTLS() (err error) { return ErrTLSRefused } - c.conn = tls.Client(c.conn, c.config.TLSConfig) + c.conn = tls.Client(c.conn, tlsConfig) return nil } diff --git a/conn_test.go b/conn_test.go index 72d8174b..2a89c3a7 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,6 +1,7 @@ package pgx_test import ( + "crypto/tls" "fmt" "github.com/jackc/pgx" "net" @@ -184,6 +185,35 @@ func TestConnectWithMD5Password(t *testing.T) { } } +func TestConnectWithTLSFallback(t *testing.T) { + t.Parallel() + + if tlsConnConfig == nil { + return + } + + connConfig := *tlsConnConfig + connConfig.TLSConfig = &tls.Config{ServerName: "bogus.local"} // bogus ServerName should ensure certificate validation failure + + conn, err := pgx.Connect(connConfig) + if err == nil { + t.Fatal("Expected failed connection, but succeeded") + } + + connConfig.UseFallbackTLS = true + connConfig.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true} + + conn, err = pgx.Connect(connConfig) + if err != nil { + t.Fatal("Unable to establish connection: " + err.Error()) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + func TestConnectWithConnectionRefused(t *testing.T) { t.Parallel()