mirror of https://github.com/jackc/pgx.git
Add fallback TLS ConnConfig option
This is in preparation for supporting libpq style SSL options.pull/78/head
parent
0c0a426d18
commit
dd9d960ba3
56
conn.go
56
conn.go
|
@ -24,14 +24,16 @@ type DialFunc func(network, addr string) (net.Conn, error)
|
|||
|
||||
// ConnConfig contains all the options used to establish a connection.
|
||||
type ConnConfig struct {
|
||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16 // default: 5432
|
||||
Database string
|
||||
User string // default: OS user name
|
||||
Password string
|
||||
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
|
||||
Logger Logger
|
||||
Dial DialFunc
|
||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16 // default: 5432
|
||||
Database string
|
||||
User string // default: OS user name
|
||||
Password string
|
||||
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
|
||||
UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa
|
||||
FallbackTLSConfig *tls.Config // config for fallback TLS connection (only used if UseFallBackTLS is true)-- nil disables TLS
|
||||
Logger Logger
|
||||
Dial DialFunc
|
||||
}
|
||||
|
||||
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||
|
@ -140,11 +142,25 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||
if c.config.Dial == nil {
|
||||
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
|
||||
}
|
||||
|
||||
err = c.connect(config, network, address, config.TLSConfig)
|
||||
if err != nil && config.UseFallbackTLS {
|
||||
err = c.connect(config, network, address, config.FallbackTLSConfig)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) {
|
||||
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
|
||||
c.conn, err = c.config.Dial(network, address)
|
||||
if err != nil {
|
||||
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if c != nil && err != nil {
|
||||
|
@ -159,11 +175,11 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||
c.alive = true
|
||||
c.lastActivityTime = time.Now()
|
||||
|
||||
if config.TLSConfig != nil {
|
||||
if tlsConfig != nil {
|
||||
c.logger.Debug("Starting TLS handshake")
|
||||
if err = c.startTLS(); err != nil {
|
||||
if err := c.startTLS(tlsConfig); err != nil {
|
||||
c.logger.Error(fmt.Sprintf("TLS failed: %v", err))
|
||||
return
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -176,7 +192,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||
msg.options["database"] = c.config.Database
|
||||
}
|
||||
if err = c.txStartupMessage(msg); err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
|
@ -184,7 +200,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||
var r *msgReader
|
||||
t, r, err = c.rxMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
switch t {
|
||||
|
@ -192,7 +208,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||
c.rxBackendKeyData(r)
|
||||
case authenticationX:
|
||||
if err = c.rxAuthenticationX(r); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
case readyForQuery:
|
||||
c.rxReadyForQuery(r)
|
||||
|
@ -203,13 +219,13 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||
|
||||
err = c.loadPgTypes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
return nil
|
||||
default:
|
||||
if err = c.processContextFreeMsg(t, r); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -905,7 +921,7 @@ func (c *Conn) rxNotificationResponse(r *msgReader) {
|
|||
c.notifications = append(c.notifications, n)
|
||||
}
|
||||
|
||||
func (c *Conn) startTLS() (err error) {
|
||||
func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) {
|
||||
err = binary.Write(c.conn, binary.BigEndian, []int32{8, 80877103})
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -920,7 +936,7 @@ func (c *Conn) startTLS() (err error) {
|
|||
return ErrTLSRefused
|
||||
}
|
||||
|
||||
c.conn = tls.Client(c.conn, c.config.TLSConfig)
|
||||
c.conn = tls.Client(c.conn, tlsConfig)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
30
conn_test.go
30
conn_test.go
|
@ -1,6 +1,7 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/jackc/pgx"
|
||||
"net"
|
||||
|
@ -184,6 +185,35 @@ func TestConnectWithMD5Password(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConnectWithTLSFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if tlsConnConfig == nil {
|
||||
return
|
||||
}
|
||||
|
||||
connConfig := *tlsConnConfig
|
||||
connConfig.TLSConfig = &tls.Config{ServerName: "bogus.local"} // bogus ServerName should ensure certificate validation failure
|
||||
|
||||
conn, err := pgx.Connect(connConfig)
|
||||
if err == nil {
|
||||
t.Fatal("Expected failed connection, but succeeded")
|
||||
}
|
||||
|
||||
connConfig.UseFallbackTLS = true
|
||||
connConfig.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
conn, err = pgx.Connect(connConfig)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to establish connection: " + err.Error())
|
||||
}
|
||||
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatal("Unable to close connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectWithConnectionRefused(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Reference in New Issue