diff --git a/pgconn/config.go b/pgconn/config.go index d2001dc5..a07fa533 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -20,6 +20,8 @@ import ( "github.com/pkg/errors" ) +type AcceptConnFunc func(pgconn *PgConn) bool + // Config is the settings used to establish a connection to a PostgreSQL server. type Config struct { Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) @@ -32,6 +34,11 @@ type Config struct { RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) Fallbacks []*FallbackConfig + + // AcceptConnFunc is called after successful connection allow custom logic for determining if the connection is + // acceptable. If AcceptConnFunc returns false the connection is closed and the next fallback config is tried. This + // allows implementing high availability behavior such as libpq does with target_session_attrs. + AcceptConnFunc AcceptConnFunc } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 09860eb2..ac48f870 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -137,6 +137,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if config.TLSConfig != nil { if err := pgConn.startTLS(config.TLSConfig); err != nil { + pgConn.NetConn.Close() return nil, err } } @@ -162,6 +163,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil { + pgConn.NetConn.Close() return nil, err } @@ -177,13 +179,19 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.SecretKey = msg.SecretKey case *pgproto3.Authentication: if err = pgConn.rxAuthenticationX(msg); err != nil { + pgConn.NetConn.Close() return nil, err } case *pgproto3.ReadyForQuery: - return pgConn, nil + if config.AcceptConnFunc == nil || config.AcceptConnFunc(pgConn) { + return pgConn, nil + } + pgConn.NetConn.Close() + return nil, errors.New("AcceptConnFunc rejected connection") case *pgproto3.ParameterStatus: // handled by ReceiveMessage case *pgproto3.ErrorResponse: + pgConn.NetConn.Close() return nil, PgError{ Severity: msg.Severity, Code: msg.Code, @@ -204,6 +212,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig Routine: msg.Routine, } default: + pgConn.NetConn.Close() return nil, errors.New("unexpected message") } } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index d53bbc09..ad06ae7b 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -188,6 +188,40 @@ func TestConnectWithFallback(t *testing.T) { closeConn(t, conn) } +func TestConnectWithAcceptConnFunc(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + dialCount := 0 + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialCount += 1 + return net.Dial(network, address) + } + + acceptConnCount := 0 + config.AcceptConnFunc = func(conn *pgconn.PgConn) bool { + acceptConnCount += 1 + return acceptConnCount > 1 + } + + // Append current primary config to fallbacks + config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }) + + // Repeat fallbacks + config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + closeConn(t, conn) + + assert.True(t, dialCount > 1) + assert.True(t, acceptConnCount > 1) +} + func TestSimple(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err)