mirror of https://github.com/jackc/pgx.git
ResetSession hook is called before a connection is reused from pool for another query.
parent
00feeaa5c9
commit
cabb58cc40
|
@ -126,6 +126,15 @@ func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB
|
|||
}
|
||||
}
|
||||
|
||||
// OptionResetSession provides a callback that can be used to add custom logic prior to executing a query on the
|
||||
// connection if the connection has been used before.
|
||||
// If ResetSessionFunc returns ErrBadConn error the connection will be discarded.
|
||||
func OptionResetSession(rs func(context.Context, *pgx.Conn) error) OptionOpenDB {
|
||||
return func(dc *connector) {
|
||||
dc.ResetSession = rs
|
||||
}
|
||||
}
|
||||
|
||||
// RandomizeHostOrderFunc is a BeforeConnect hook that randomizes the host order in the provided connConfig, so that a
|
||||
// new host becomes primary each time. This is useful to distribute connections for multi-master databases like
|
||||
// CockroachDB. If you use this you likely should set https://golang.org/pkg/database/sql/#DB.SetConnMaxLifetime as well
|
||||
|
@ -159,6 +168,7 @@ func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
|
|||
ConnConfig: config,
|
||||
BeforeConnect: func(context.Context, *pgx.ConnConfig) error { return nil }, // noop before connect by default
|
||||
AfterConnect: func(context.Context, *pgx.Conn) error { return nil }, // noop after connect by default
|
||||
ResetSession: func(context.Context, *pgx.Conn) error { return nil }, // noop reset session by default
|
||||
driver: pgxDriver,
|
||||
}
|
||||
|
||||
|
@ -173,6 +183,7 @@ type connector struct {
|
|||
pgx.ConnConfig
|
||||
BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection
|
||||
AfterConnect func(context.Context, *pgx.Conn) error // function to call after creation of every new connection
|
||||
ResetSession func(context.Context, *pgx.Conn) error // function is called before a connection is reused
|
||||
driver *Driver
|
||||
}
|
||||
|
||||
|
@ -197,7 +208,7 @@ func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return &Conn{conn: conn, driver: c.driver, connConfig: connConfig}, nil
|
||||
return &Conn{conn: conn, driver: c.driver, connConfig: connConfig, resetSessionFunc: c.ResetSession}, nil
|
||||
}
|
||||
|
||||
// Driver implement driver.Connector interface
|
||||
|
@ -272,7 +283,13 @@ func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
c := &Conn{conn: conn, driver: dc.driver, connConfig: *connConfig}
|
||||
c := &Conn{
|
||||
conn: conn,
|
||||
driver: dc.driver,
|
||||
connConfig: *connConfig,
|
||||
resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil },
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
|
@ -291,10 +308,11 @@ func UnregisterConnConfig(connStr string) {
|
|||
}
|
||||
|
||||
type Conn struct {
|
||||
conn *pgx.Conn
|
||||
psCount int64 // Counter used for creating unique prepared statement names
|
||||
driver *Driver
|
||||
connConfig pgx.ConnConfig
|
||||
conn *pgx.Conn
|
||||
psCount int64 // Counter used for creating unique prepared statement names
|
||||
driver *Driver
|
||||
connConfig pgx.ConnConfig
|
||||
resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused
|
||||
}
|
||||
|
||||
// Conn returns the underlying *pgx.Conn
|
||||
|
@ -436,7 +454,8 @@ func (c *Conn) ResetSession(ctx context.Context) error {
|
|||
if c.conn.IsClosed() {
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
return nil
|
||||
|
||||
return c.resetSessionFunc(ctx, c.conn)
|
||||
}
|
||||
|
||||
type Stmt struct {
|
||||
|
|
|
@ -1202,3 +1202,26 @@ func TestRandomizeHostOrderFunc(t *testing.T) {
|
|||
|
||||
require.Fail(t, "did not get all hosts as primaries after many randomizations")
|
||||
}
|
||||
|
||||
func TestResetSessionHookCalled(t *testing.T) {
|
||||
var mockCalled bool
|
||||
|
||||
connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
|
||||
db := stdlib.OpenDB(*connConfig, stdlib.OptionResetSession(func(ctx context.Context, conn *pgx.Conn) error {
|
||||
mockCalled = true
|
||||
|
||||
return nil
|
||||
}))
|
||||
|
||||
defer closeDB(t, db)
|
||||
|
||||
err = db.Ping()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Ping()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, mockCalled)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue