diff --git a/stdlib/sql.go b/stdlib/sql.go index a0bc6182..8ad27b58 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -56,6 +56,7 @@ import ( "fmt" "io" "math" + "math/rand" "reflect" "strconv" "strings" @@ -110,18 +111,57 @@ var ( // OptionOpenDB options for configuring the driver when opening a new db pool. type OptionOpenDB func(*connector) -// OptionAfterConnect provide a callback for after connect. +// OptionBeforeConnect provides a callback for before connect. It is passed a shallow copy of the ConnConfig that will +// be used to connect, so only its immediate members should be modified. +func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB { + return func(dc *connector) { + dc.BeforeConnect = bc + } +} + +// OptionAfterConnect provides a callback for after connect. func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB { return func(dc *connector) { dc.AfterConnect = ac } } +// RandomizeHostOrderFunc is a BeforeConnect hook that randomizes the host order in the provided connConfig, so that a +// new host becomes primary each time. This is useful to distribute connections for multi-master databases like +// CockroachDB. If you use this you likely should set https://golang.org/pkg/database/sql/#DB.SetConnMaxLifetime as well +// to ensure that connections are periodically rebalanced across your nodes. +func RandomizeHostOrderFunc(ctx context.Context, connConfig *pgx.ConnConfig) error { + if len(connConfig.Fallbacks) == 0 { + return nil + } + + newFallbacks := append([]*pgconn.FallbackConfig(nil), + &pgconn.FallbackConfig{ + Host: connConfig.Host, + Port: connConfig.Port, + TLSConfig: connConfig.TLSConfig, + }) + newFallbacks = append(newFallbacks, connConfig.Fallbacks...) + + rand.Shuffle(len(newFallbacks), func(i, j int) { + newFallbacks[i], newFallbacks[j] = newFallbacks[j], newFallbacks[i] + }) + + // Use the one that sorted last as the primary and keep the rest as the fallbacks + newPrimary := newFallbacks[len(newFallbacks)-1] + connConfig.Host = newPrimary.Host + connConfig.Port = newPrimary.Port + connConfig.TLSConfig = newPrimary.TLSConfig + connConfig.Fallbacks = newFallbacks[:len(newFallbacks)-1] + return nil +} + func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB { c := connector{ - ConnConfig: config, - AfterConnect: func(context.Context, *pgx.Conn) error { return nil }, // noop after connect by default - driver: pgxDriver, + ConnConfig: config, + BeforeConnect: func(context.Context, *pgx.ConnConfig) error { return nil }, // noop before connect by default + AfterConnect: func(context.Context, *pgx.Conn) error { return nil }, // noop after connect by default + driver: pgxDriver, } for _, opt := range opts { @@ -133,8 +173,9 @@ func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB { type connector struct { pgx.ConnConfig - AfterConnect func(context.Context, *pgx.Conn) error // function to call on every new connection - driver *Driver + BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection + AfterConnect func(context.Context, *pgx.Conn) error // function to call after creation of every new connection + driver *Driver } // Connect implement driver.Connector interface @@ -144,7 +185,13 @@ func (c connector) Connect(ctx context.Context) (driver.Conn, error) { conn *pgx.Conn ) - if conn, err = pgx.ConnectConfig(ctx, &c.ConnConfig); err != nil { + // Create a shallow copy of the config, so that BeforeConnect can safely modify it + connConfig := c.ConnConfig + if err = c.BeforeConnect(ctx, &connConfig); err != nil { + return nil, err + } + + if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil { return nil, err } @@ -152,7 +199,7 @@ func (c connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } - return &Conn{conn: conn, driver: c.driver, connConfig: c.ConnConfig}, nil + return &Conn{conn: conn, driver: c.driver, connConfig: connConfig}, nil } // Driver implement driver.Connector interface diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 5131194c..eb08be7d 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -1129,3 +1129,76 @@ func TestConnQueryRowConstraintErrors(t *testing.T) { assert.Error(t, err) }) } + +func TestOptionBeforeAfterConnect(t *testing.T) { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + var beforeConnConfigs []*pgx.ConnConfig + var afterConns []*pgx.Conn + db := stdlib.OpenDB(*config, + stdlib.OptionBeforeConnect(func(ctx context.Context, connConfig *pgx.ConnConfig) error { + beforeConnConfigs = append(beforeConnConfigs, connConfig) + return nil + }), + stdlib.OptionAfterConnect(func(ctx context.Context, conn *pgx.Conn) error { + afterConns = append(afterConns, conn) + return nil + })) + defer closeDB(t, db) + + // Force it to close and reopen a new connection after each query + db.SetMaxIdleConns(0) + + _, err = db.Exec("select 1") + require.NoError(t, err) + + _, err = db.Exec("select 1") + require.NoError(t, err) + + require.Len(t, beforeConnConfigs, 2) + require.Len(t, afterConns, 2) + + // Note: BeforeConnect creates a shallow copy, so the config contents will be the same but we wean to ensure they + // are different objects, so can't use require.NotEqual + require.False(t, config == beforeConnConfigs[0]) + require.False(t, beforeConnConfigs[0] == beforeConnConfigs[1]) +} + +func TestRandomizeHostOrderFunc(t *testing.T) { + config, err := pgx.ParseConfig("postgres://host1,host2,host3") + require.NoError(t, err) + + // Test that at some point we connect to all 3 hosts + hostsNotSeenYet := map[string]struct{}{ + "host1": struct{}{}, + "host2": struct{}{}, + "host3": struct{}{}, + } + + // If we don't succeed within this many iterations, something is certainly wrong + for i := 0; i < 100000; i++ { + connCopy := *config + stdlib.RandomizeHostOrderFunc(context.Background(), &connCopy) + + delete(hostsNotSeenYet, connCopy.Host) + if len(hostsNotSeenYet) == 0 { + return + } + + hostCheckLoop: + for _, h := range []string{"host1", "host2", "host3"} { + if connCopy.Host == h { + continue + } + for _, f := range connCopy.Fallbacks { + if f.Host == h { + continue hostCheckLoop + } + } + require.Failf(t, "got configuration from RandomizeHostOrderFunc that did not have all the hosts", "%+v", connCopy) + } + } + + require.Fail(t, "did not get all hosts as primaries after many randomizations") +}