Add AcceptConnFunc for filtering HA connections

pull/483/head
Jack Christensen 2018-12-31 11:39:22 -06:00
parent afd3583558
commit 28ee40f347
3 changed files with 51 additions and 1 deletions

View File

@ -20,6 +20,8 @@ import (
"github.com/pkg/errors"
)
type AcceptConnFunc func(pgconn *PgConn) bool
// Config is the settings used to establish a connection to a PostgreSQL server.
type Config struct {
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
@ -32,6 +34,11 @@ type Config struct {
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
Fallbacks []*FallbackConfig
// AcceptConnFunc is called after successful connection allow custom logic for determining if the connection is
// acceptable. If AcceptConnFunc returns false the connection is closed and the next fallback config is tried. This
// allows implementing high availability behavior such as libpq does with target_session_attrs.
AcceptConnFunc AcceptConnFunc
}
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a

View File

@ -137,6 +137,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if config.TLSConfig != nil {
if err := pgConn.startTLS(config.TLSConfig); err != nil {
pgConn.NetConn.Close()
return nil, err
}
}
@ -162,6 +163,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
}
if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil {
pgConn.NetConn.Close()
return nil, err
}
@ -177,13 +179,19 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.SecretKey = msg.SecretKey
case *pgproto3.Authentication:
if err = pgConn.rxAuthenticationX(msg); err != nil {
pgConn.NetConn.Close()
return nil, err
}
case *pgproto3.ReadyForQuery:
if config.AcceptConnFunc == nil || config.AcceptConnFunc(pgConn) {
return pgConn, nil
}
pgConn.NetConn.Close()
return nil, errors.New("AcceptConnFunc rejected connection")
case *pgproto3.ParameterStatus:
// handled by ReceiveMessage
case *pgproto3.ErrorResponse:
pgConn.NetConn.Close()
return nil, PgError{
Severity: msg.Severity,
Code: msg.Code,
@ -204,6 +212,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
Routine: msg.Routine,
}
default:
pgConn.NetConn.Close()
return nil, errors.New("unexpected message")
}
}

View File

@ -188,6 +188,40 @@ func TestConnectWithFallback(t *testing.T) {
closeConn(t, conn)
}
func TestConnectWithAcceptConnFunc(t *testing.T) {
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
dialCount := 0
config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
dialCount += 1
return net.Dial(network, address)
}
acceptConnCount := 0
config.AcceptConnFunc = func(conn *pgconn.PgConn) bool {
acceptConnCount += 1
return acceptConnCount > 1
}
// Append current primary config to fallbacks
config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{
Host: config.Host,
Port: config.Port,
TLSConfig: config.TLSConfig,
})
// Repeat fallbacks
config.Fallbacks = append(config.Fallbacks, config.Fallbacks...)
conn, err := pgconn.ConnectConfig(context.Background(), config)
require.Nil(t, err)
closeConn(t, conn)
assert.True(t, dialCount > 1)
assert.True(t, acceptConnCount > 1)
}
func TestSimple(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)