mirror of https://github.com/jackc/pgx.git
Add AcceptConnFunc for filtering HA connections
parent
afd3583558
commit
28ee40f347
|
@ -20,6 +20,8 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type AcceptConnFunc func(pgconn *PgConn) bool
|
||||
|
||||
// Config is the settings used to establish a connection to a PostgreSQL server.
|
||||
type Config struct {
|
||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||
|
@ -32,6 +34,11 @@ type Config struct {
|
|||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||
|
||||
Fallbacks []*FallbackConfig
|
||||
|
||||
// AcceptConnFunc is called after successful connection allow custom logic for determining if the connection is
|
||||
// acceptable. If AcceptConnFunc returns false the connection is closed and the next fallback config is tried. This
|
||||
// allows implementing high availability behavior such as libpq does with target_session_attrs.
|
||||
AcceptConnFunc AcceptConnFunc
|
||||
}
|
||||
|
||||
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
|
||||
|
|
|
@ -137,6 +137,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
|
||||
if config.TLSConfig != nil {
|
||||
if err := pgConn.startTLS(config.TLSConfig); err != nil {
|
||||
pgConn.NetConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
@ -162,6 +163,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
}
|
||||
|
||||
if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil {
|
||||
pgConn.NetConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -177,13 +179,19 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
pgConn.SecretKey = msg.SecretKey
|
||||
case *pgproto3.Authentication:
|
||||
if err = pgConn.rxAuthenticationX(msg); err != nil {
|
||||
pgConn.NetConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
case *pgproto3.ReadyForQuery:
|
||||
if config.AcceptConnFunc == nil || config.AcceptConnFunc(pgConn) {
|
||||
return pgConn, nil
|
||||
}
|
||||
pgConn.NetConn.Close()
|
||||
return nil, errors.New("AcceptConnFunc rejected connection")
|
||||
case *pgproto3.ParameterStatus:
|
||||
// handled by ReceiveMessage
|
||||
case *pgproto3.ErrorResponse:
|
||||
pgConn.NetConn.Close()
|
||||
return nil, PgError{
|
||||
Severity: msg.Severity,
|
||||
Code: msg.Code,
|
||||
|
@ -204,6 +212,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
Routine: msg.Routine,
|
||||
}
|
||||
default:
|
||||
pgConn.NetConn.Close()
|
||||
return nil, errors.New("unexpected message")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -188,6 +188,40 @@ func TestConnectWithFallback(t *testing.T) {
|
|||
closeConn(t, conn)
|
||||
}
|
||||
|
||||
func TestConnectWithAcceptConnFunc(t *testing.T) {
|
||||
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(t, err)
|
||||
|
||||
dialCount := 0
|
||||
config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
dialCount += 1
|
||||
return net.Dial(network, address)
|
||||
}
|
||||
|
||||
acceptConnCount := 0
|
||||
config.AcceptConnFunc = func(conn *pgconn.PgConn) bool {
|
||||
acceptConnCount += 1
|
||||
return acceptConnCount > 1
|
||||
}
|
||||
|
||||
// Append current primary config to fallbacks
|
||||
config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{
|
||||
Host: config.Host,
|
||||
Port: config.Port,
|
||||
TLSConfig: config.TLSConfig,
|
||||
})
|
||||
|
||||
// Repeat fallbacks
|
||||
config.Fallbacks = append(config.Fallbacks, config.Fallbacks...)
|
||||
|
||||
conn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||
require.Nil(t, err)
|
||||
closeConn(t, conn)
|
||||
|
||||
assert.True(t, dialCount > 1)
|
||||
assert.True(t, acceptConnCount > 1)
|
||||
}
|
||||
|
||||
func TestSimple(t *testing.T) {
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(t, err)
|
||||
|
|
Loading…
Reference in New Issue