Add fallback TLS ConnConfig option

This is in preparation for supporting libpq style SSL options.
pull/78/head
Jack Christensen 2015-05-23 11:57:36 -05:00
parent 0c0a426d18
commit dd9d960ba3
2 changed files with 66 additions and 20 deletions

40
conn.go
View File

@ -30,6 +30,8 @@ type ConnConfig struct {
User string // default: OS user name
Password string
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa
FallbackTLSConfig *tls.Config // config for fallback TLS connection (only used if UseFallBackTLS is true)-- nil disables TLS
Logger Logger
Dial DialFunc
}
@ -140,11 +142,25 @@ func Connect(config ConnConfig) (c *Conn, err error) {
if c.config.Dial == nil {
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
}
err = c.connect(config, network, address, config.TLSConfig)
if err != nil && config.UseFallbackTLS {
err = c.connect(config, network, address, config.FallbackTLSConfig)
}
if err != nil {
return nil, err
}
return c, nil
}
func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) {
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
c.conn, err = c.config.Dial(network, address)
if err != nil {
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
return nil, err
return err
}
defer func() {
if c != nil && err != nil {
@ -159,11 +175,11 @@ func Connect(config ConnConfig) (c *Conn, err error) {
c.alive = true
c.lastActivityTime = time.Now()
if config.TLSConfig != nil {
if tlsConfig != nil {
c.logger.Debug("Starting TLS handshake")
if err = c.startTLS(); err != nil {
if err := c.startTLS(tlsConfig); err != nil {
c.logger.Error(fmt.Sprintf("TLS failed: %v", err))
return
return err
}
}
@ -176,7 +192,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
msg.options["database"] = c.config.Database
}
if err = c.txStartupMessage(msg); err != nil {
return
return err
}
for {
@ -184,7 +200,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
var r *msgReader
t, r, err = c.rxMsg()
if err != nil {
return nil, err
return err
}
switch t {
@ -192,7 +208,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
c.rxBackendKeyData(r)
case authenticationX:
if err = c.rxAuthenticationX(r); err != nil {
return nil, err
return err
}
case readyForQuery:
c.rxReadyForQuery(r)
@ -203,13 +219,13 @@ func Connect(config ConnConfig) (c *Conn, err error) {
err = c.loadPgTypes()
if err != nil {
return nil, err
return err
}
return c, nil
return nil
default:
if err = c.processContextFreeMsg(t, r); err != nil {
return nil, err
return err
}
}
}
@ -905,7 +921,7 @@ func (c *Conn) rxNotificationResponse(r *msgReader) {
c.notifications = append(c.notifications, n)
}
func (c *Conn) startTLS() (err error) {
func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) {
err = binary.Write(c.conn, binary.BigEndian, []int32{8, 80877103})
if err != nil {
return
@ -920,7 +936,7 @@ func (c *Conn) startTLS() (err error) {
return ErrTLSRefused
}
c.conn = tls.Client(c.conn, c.config.TLSConfig)
c.conn = tls.Client(c.conn, tlsConfig)
return nil
}

View File

@ -1,6 +1,7 @@
package pgx_test
import (
"crypto/tls"
"fmt"
"github.com/jackc/pgx"
"net"
@ -184,6 +185,35 @@ func TestConnectWithMD5Password(t *testing.T) {
}
}
func TestConnectWithTLSFallback(t *testing.T) {
t.Parallel()
if tlsConnConfig == nil {
return
}
connConfig := *tlsConnConfig
connConfig.TLSConfig = &tls.Config{ServerName: "bogus.local"} // bogus ServerName should ensure certificate validation failure
conn, err := pgx.Connect(connConfig)
if err == nil {
t.Fatal("Expected failed connection, but succeeded")
}
connConfig.UseFallbackTLS = true
connConfig.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true}
conn, err = pgx.Connect(connConfig)
if err != nil {
t.Fatal("Unable to establish connection: " + err.Error())
}
err = conn.Close()
if err != nil {
t.Fatal("Unable to close connection")
}
}
func TestConnectWithConnectionRefused(t *testing.T) {
t.Parallel()