mirror of
https://github.com/jackc/pgx.git
synced 2025-04-27 13:14:32 +00:00
Add fallback TLS ConnConfig option
This is in preparation for supporting libpq style SSL options.
This commit is contained in:
parent
0c0a426d18
commit
dd9d960ba3
56
conn.go
56
conn.go
@ -24,14 +24,16 @@ type DialFunc func(network, addr string) (net.Conn, error)
|
|||||||
|
|
||||||
// ConnConfig contains all the options used to establish a connection.
|
// ConnConfig contains all the options used to establish a connection.
|
||||||
type ConnConfig struct {
|
type ConnConfig struct {
|
||||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||||
Port uint16 // default: 5432
|
Port uint16 // default: 5432
|
||||||
Database string
|
Database string
|
||||||
User string // default: OS user name
|
User string // default: OS user name
|
||||||
Password string
|
Password string
|
||||||
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
|
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
|
||||||
Logger Logger
|
UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa
|
||||||
Dial DialFunc
|
FallbackTLSConfig *tls.Config // config for fallback TLS connection (only used if UseFallBackTLS is true)-- nil disables TLS
|
||||||
|
Logger Logger
|
||||||
|
Dial DialFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||||
@ -140,11 +142,25 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||||||
if c.config.Dial == nil {
|
if c.config.Dial == nil {
|
||||||
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
|
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = c.connect(config, network, address, config.TLSConfig)
|
||||||
|
if err != nil && config.UseFallbackTLS {
|
||||||
|
err = c.connect(config, network, address, config.FallbackTLSConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) {
|
||||||
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
|
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
|
||||||
c.conn, err = c.config.Dial(network, address)
|
c.conn, err = c.config.Dial(network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
|
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if c != nil && err != nil {
|
if c != nil && err != nil {
|
||||||
@ -159,11 +175,11 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||||||
c.alive = true
|
c.alive = true
|
||||||
c.lastActivityTime = time.Now()
|
c.lastActivityTime = time.Now()
|
||||||
|
|
||||||
if config.TLSConfig != nil {
|
if tlsConfig != nil {
|
||||||
c.logger.Debug("Starting TLS handshake")
|
c.logger.Debug("Starting TLS handshake")
|
||||||
if err = c.startTLS(); err != nil {
|
if err := c.startTLS(tlsConfig); err != nil {
|
||||||
c.logger.Error(fmt.Sprintf("TLS failed: %v", err))
|
c.logger.Error(fmt.Sprintf("TLS failed: %v", err))
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -176,7 +192,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||||||
msg.options["database"] = c.config.Database
|
msg.options["database"] = c.config.Database
|
||||||
}
|
}
|
||||||
if err = c.txStartupMessage(msg); err != nil {
|
if err = c.txStartupMessage(msg); err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@ -184,7 +200,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||||||
var r *msgReader
|
var r *msgReader
|
||||||
t, r, err = c.rxMsg()
|
t, r, err = c.rxMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
switch t {
|
switch t {
|
||||||
@ -192,7 +208,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||||||
c.rxBackendKeyData(r)
|
c.rxBackendKeyData(r)
|
||||||
case authenticationX:
|
case authenticationX:
|
||||||
if err = c.rxAuthenticationX(r); err != nil {
|
if err = c.rxAuthenticationX(r); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
case readyForQuery:
|
case readyForQuery:
|
||||||
c.rxReadyForQuery(r)
|
c.rxReadyForQuery(r)
|
||||||
@ -203,13 +219,13 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||||||
|
|
||||||
err = c.loadPgTypes()
|
err = c.loadPgTypes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c, nil
|
return nil
|
||||||
default:
|
default:
|
||||||
if err = c.processContextFreeMsg(t, r); err != nil {
|
if err = c.processContextFreeMsg(t, r); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -905,7 +921,7 @@ func (c *Conn) rxNotificationResponse(r *msgReader) {
|
|||||||
c.notifications = append(c.notifications, n)
|
c.notifications = append(c.notifications, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) startTLS() (err error) {
|
func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) {
|
||||||
err = binary.Write(c.conn, binary.BigEndian, []int32{8, 80877103})
|
err = binary.Write(c.conn, binary.BigEndian, []int32{8, 80877103})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -920,7 +936,7 @@ func (c *Conn) startTLS() (err error) {
|
|||||||
return ErrTLSRefused
|
return ErrTLSRefused
|
||||||
}
|
}
|
||||||
|
|
||||||
c.conn = tls.Client(c.conn, c.config.TLSConfig)
|
c.conn = tls.Client(c.conn, tlsConfig)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
30
conn_test.go
30
conn_test.go
@ -1,6 +1,7 @@
|
|||||||
package pgx_test
|
package pgx_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/jackc/pgx"
|
"github.com/jackc/pgx"
|
||||||
"net"
|
"net"
|
||||||
@ -184,6 +185,35 @@ func TestConnectWithMD5Password(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnectWithTLSFallback(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if tlsConnConfig == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
connConfig := *tlsConnConfig
|
||||||
|
connConfig.TLSConfig = &tls.Config{ServerName: "bogus.local"} // bogus ServerName should ensure certificate validation failure
|
||||||
|
|
||||||
|
conn, err := pgx.Connect(connConfig)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected failed connection, but succeeded")
|
||||||
|
}
|
||||||
|
|
||||||
|
connConfig.UseFallbackTLS = true
|
||||||
|
connConfig.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true}
|
||||||
|
|
||||||
|
conn, err = pgx.Connect(connConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Unable to establish connection: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Unable to close connection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnectWithConnectionRefused(t *testing.T) {
|
func TestConnectWithConnectionRefused(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user