Add fallback TLS ConnConfig option

This is in preparation for supporting libpq style SSL options.
This commit is contained in:
Jack Christensen 2015-05-23 11:57:36 -05:00
parent 0c0a426d18
commit dd9d960ba3
2 changed files with 66 additions and 20 deletions

56
conn.go
View File

@ -24,14 +24,16 @@ type DialFunc func(network, addr string) (net.Conn, error)
// ConnConfig contains all the options used to establish a connection. // ConnConfig contains all the options used to establish a connection.
type ConnConfig struct { type ConnConfig struct {
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
Port uint16 // default: 5432 Port uint16 // default: 5432
Database string Database string
User string // default: OS user name User string // default: OS user name
Password string Password string
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
Logger Logger UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa
Dial DialFunc FallbackTLSConfig *tls.Config // config for fallback TLS connection (only used if UseFallBackTLS is true)-- nil disables TLS
Logger Logger
Dial DialFunc
} }
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
@ -140,11 +142,25 @@ func Connect(config ConnConfig) (c *Conn, err error) {
if c.config.Dial == nil { if c.config.Dial == nil {
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
} }
err = c.connect(config, network, address, config.TLSConfig)
if err != nil && config.UseFallbackTLS {
err = c.connect(config, network, address, config.FallbackTLSConfig)
}
if err != nil {
return nil, err
}
return c, nil
}
func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) {
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address)) c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
c.conn, err = c.config.Dial(network, address) c.conn, err = c.config.Dial(network, address)
if err != nil { if err != nil {
c.logger.Error(fmt.Sprintf("Connection failed: %v", err)) c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
return nil, err return err
} }
defer func() { defer func() {
if c != nil && err != nil { if c != nil && err != nil {
@ -159,11 +175,11 @@ func Connect(config ConnConfig) (c *Conn, err error) {
c.alive = true c.alive = true
c.lastActivityTime = time.Now() c.lastActivityTime = time.Now()
if config.TLSConfig != nil { if tlsConfig != nil {
c.logger.Debug("Starting TLS handshake") c.logger.Debug("Starting TLS handshake")
if err = c.startTLS(); err != nil { if err := c.startTLS(tlsConfig); err != nil {
c.logger.Error(fmt.Sprintf("TLS failed: %v", err)) c.logger.Error(fmt.Sprintf("TLS failed: %v", err))
return return err
} }
} }
@ -176,7 +192,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
msg.options["database"] = c.config.Database msg.options["database"] = c.config.Database
} }
if err = c.txStartupMessage(msg); err != nil { if err = c.txStartupMessage(msg); err != nil {
return return err
} }
for { for {
@ -184,7 +200,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
var r *msgReader var r *msgReader
t, r, err = c.rxMsg() t, r, err = c.rxMsg()
if err != nil { if err != nil {
return nil, err return err
} }
switch t { switch t {
@ -192,7 +208,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
c.rxBackendKeyData(r) c.rxBackendKeyData(r)
case authenticationX: case authenticationX:
if err = c.rxAuthenticationX(r); err != nil { if err = c.rxAuthenticationX(r); err != nil {
return nil, err return err
} }
case readyForQuery: case readyForQuery:
c.rxReadyForQuery(r) c.rxReadyForQuery(r)
@ -203,13 +219,13 @@ func Connect(config ConnConfig) (c *Conn, err error) {
err = c.loadPgTypes() err = c.loadPgTypes()
if err != nil { if err != nil {
return nil, err return err
} }
return c, nil return nil
default: default:
if err = c.processContextFreeMsg(t, r); err != nil { if err = c.processContextFreeMsg(t, r); err != nil {
return nil, err return err
} }
} }
} }
@ -905,7 +921,7 @@ func (c *Conn) rxNotificationResponse(r *msgReader) {
c.notifications = append(c.notifications, n) c.notifications = append(c.notifications, n)
} }
func (c *Conn) startTLS() (err error) { func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) {
err = binary.Write(c.conn, binary.BigEndian, []int32{8, 80877103}) err = binary.Write(c.conn, binary.BigEndian, []int32{8, 80877103})
if err != nil { if err != nil {
return return
@ -920,7 +936,7 @@ func (c *Conn) startTLS() (err error) {
return ErrTLSRefused return ErrTLSRefused
} }
c.conn = tls.Client(c.conn, c.config.TLSConfig) c.conn = tls.Client(c.conn, tlsConfig)
return nil return nil
} }

View File

@ -1,6 +1,7 @@
package pgx_test package pgx_test
import ( import (
"crypto/tls"
"fmt" "fmt"
"github.com/jackc/pgx" "github.com/jackc/pgx"
"net" "net"
@ -184,6 +185,35 @@ func TestConnectWithMD5Password(t *testing.T) {
} }
} }
func TestConnectWithTLSFallback(t *testing.T) {
t.Parallel()
if tlsConnConfig == nil {
return
}
connConfig := *tlsConnConfig
connConfig.TLSConfig = &tls.Config{ServerName: "bogus.local"} // bogus ServerName should ensure certificate validation failure
conn, err := pgx.Connect(connConfig)
if err == nil {
t.Fatal("Expected failed connection, but succeeded")
}
connConfig.UseFallbackTLS = true
connConfig.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true}
conn, err = pgx.Connect(connConfig)
if err != nil {
t.Fatal("Unable to establish connection: " + err.Error())
}
err = conn.Close()
if err != nil {
t.Fatal("Unable to close connection")
}
}
func TestConnectWithConnectionRefused(t *testing.T) { func TestConnectWithConnectionRefused(t *testing.T) {
t.Parallel() t.Parallel()