mirror of https://github.com/jackc/pgx.git
stdlib: add OptionBeforeConnect and randomizer
Fixes https://github.com/jackc/pgconn/issues/71pull/1010/head
parent
e722ca608c
commit
a8020a21e8
|
@ -56,6 +56,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
|
"math/rand"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -110,16 +111,55 @@ var (
|
||||||
// OptionOpenDB options for configuring the driver when opening a new db pool.
|
// OptionOpenDB options for configuring the driver when opening a new db pool.
|
||||||
type OptionOpenDB func(*connector)
|
type OptionOpenDB func(*connector)
|
||||||
|
|
||||||
// OptionAfterConnect provide a callback for after connect.
|
// OptionBeforeConnect provides a callback for before connect. It is passed a shallow copy of the ConnConfig that will
|
||||||
|
// be used to connect, so only its immediate members should be modified.
|
||||||
|
func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB {
|
||||||
|
return func(dc *connector) {
|
||||||
|
dc.BeforeConnect = bc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OptionAfterConnect provides a callback for after connect.
|
||||||
func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB {
|
func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB {
|
||||||
return func(dc *connector) {
|
return func(dc *connector) {
|
||||||
dc.AfterConnect = ac
|
dc.AfterConnect = ac
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RandomizeHostOrderFunc is a BeforeConnect hook that randomizes the host order in the provided connConfig, so that a
|
||||||
|
// new host becomes primary each time. This is useful to distribute connections for multi-master databases like
|
||||||
|
// CockroachDB. If you use this you likely should set https://golang.org/pkg/database/sql/#DB.SetConnMaxLifetime as well
|
||||||
|
// to ensure that connections are periodically rebalanced across your nodes.
|
||||||
|
func RandomizeHostOrderFunc(ctx context.Context, connConfig *pgx.ConnConfig) error {
|
||||||
|
if len(connConfig.Fallbacks) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newFallbacks := append([]*pgconn.FallbackConfig(nil),
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: connConfig.Host,
|
||||||
|
Port: connConfig.Port,
|
||||||
|
TLSConfig: connConfig.TLSConfig,
|
||||||
|
})
|
||||||
|
newFallbacks = append(newFallbacks, connConfig.Fallbacks...)
|
||||||
|
|
||||||
|
rand.Shuffle(len(newFallbacks), func(i, j int) {
|
||||||
|
newFallbacks[i], newFallbacks[j] = newFallbacks[j], newFallbacks[i]
|
||||||
|
})
|
||||||
|
|
||||||
|
// Use the one that sorted last as the primary and keep the rest as the fallbacks
|
||||||
|
newPrimary := newFallbacks[len(newFallbacks)-1]
|
||||||
|
connConfig.Host = newPrimary.Host
|
||||||
|
connConfig.Port = newPrimary.Port
|
||||||
|
connConfig.TLSConfig = newPrimary.TLSConfig
|
||||||
|
connConfig.Fallbacks = newFallbacks[:len(newFallbacks)-1]
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
|
func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
|
||||||
c := connector{
|
c := connector{
|
||||||
ConnConfig: config,
|
ConnConfig: config,
|
||||||
|
BeforeConnect: func(context.Context, *pgx.ConnConfig) error { return nil }, // noop before connect by default
|
||||||
AfterConnect: func(context.Context, *pgx.Conn) error { return nil }, // noop after connect by default
|
AfterConnect: func(context.Context, *pgx.Conn) error { return nil }, // noop after connect by default
|
||||||
driver: pgxDriver,
|
driver: pgxDriver,
|
||||||
}
|
}
|
||||||
|
@ -133,7 +173,8 @@ func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
|
||||||
|
|
||||||
type connector struct {
|
type connector struct {
|
||||||
pgx.ConnConfig
|
pgx.ConnConfig
|
||||||
AfterConnect func(context.Context, *pgx.Conn) error // function to call on every new connection
|
BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection
|
||||||
|
AfterConnect func(context.Context, *pgx.Conn) error // function to call after creation of every new connection
|
||||||
driver *Driver
|
driver *Driver
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -144,7 +185,13 @@ func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||||
conn *pgx.Conn
|
conn *pgx.Conn
|
||||||
)
|
)
|
||||||
|
|
||||||
if conn, err = pgx.ConnectConfig(ctx, &c.ConnConfig); err != nil {
|
// Create a shallow copy of the config, so that BeforeConnect can safely modify it
|
||||||
|
connConfig := c.ConnConfig
|
||||||
|
if err = c.BeforeConnect(ctx, &connConfig); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,7 +199,7 @@ func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Conn{conn: conn, driver: c.driver, connConfig: c.ConnConfig}, nil
|
return &Conn{conn: conn, driver: c.driver, connConfig: connConfig}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Driver implement driver.Connector interface
|
// Driver implement driver.Connector interface
|
||||||
|
|
|
@ -1129,3 +1129,76 @@ func TestConnQueryRowConstraintErrors(t *testing.T) {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOptionBeforeAfterConnect(t *testing.T) {
|
||||||
|
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var beforeConnConfigs []*pgx.ConnConfig
|
||||||
|
var afterConns []*pgx.Conn
|
||||||
|
db := stdlib.OpenDB(*config,
|
||||||
|
stdlib.OptionBeforeConnect(func(ctx context.Context, connConfig *pgx.ConnConfig) error {
|
||||||
|
beforeConnConfigs = append(beforeConnConfigs, connConfig)
|
||||||
|
return nil
|
||||||
|
}),
|
||||||
|
stdlib.OptionAfterConnect(func(ctx context.Context, conn *pgx.Conn) error {
|
||||||
|
afterConns = append(afterConns, conn)
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
defer closeDB(t, db)
|
||||||
|
|
||||||
|
// Force it to close and reopen a new connection after each query
|
||||||
|
db.SetMaxIdleConns(0)
|
||||||
|
|
||||||
|
_, err = db.Exec("select 1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = db.Exec("select 1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Len(t, beforeConnConfigs, 2)
|
||||||
|
require.Len(t, afterConns, 2)
|
||||||
|
|
||||||
|
// Note: BeforeConnect creates a shallow copy, so the config contents will be the same but we wean to ensure they
|
||||||
|
// are different objects, so can't use require.NotEqual
|
||||||
|
require.False(t, config == beforeConnConfigs[0])
|
||||||
|
require.False(t, beforeConnConfigs[0] == beforeConnConfigs[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRandomizeHostOrderFunc(t *testing.T) {
|
||||||
|
config, err := pgx.ParseConfig("postgres://host1,host2,host3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test that at some point we connect to all 3 hosts
|
||||||
|
hostsNotSeenYet := map[string]struct{}{
|
||||||
|
"host1": struct{}{},
|
||||||
|
"host2": struct{}{},
|
||||||
|
"host3": struct{}{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we don't succeed within this many iterations, something is certainly wrong
|
||||||
|
for i := 0; i < 100000; i++ {
|
||||||
|
connCopy := *config
|
||||||
|
stdlib.RandomizeHostOrderFunc(context.Background(), &connCopy)
|
||||||
|
|
||||||
|
delete(hostsNotSeenYet, connCopy.Host)
|
||||||
|
if len(hostsNotSeenYet) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hostCheckLoop:
|
||||||
|
for _, h := range []string{"host1", "host2", "host3"} {
|
||||||
|
if connCopy.Host == h {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, f := range connCopy.Fallbacks {
|
||||||
|
if f.Host == h {
|
||||||
|
continue hostCheckLoop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.Failf(t, "got configuration from RandomizeHostOrderFunc that did not have all the hosts", "%+v", connCopy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Fail(t, "did not get all hosts as primaries after many randomizations")
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue