mirror of https://github.com/jackc/pgx.git
Add support for libpq target_session_attrs
Generalize AcceptConnFunc into AfterConnectFunc.pull/483/head
parent
28ee40f347
commit
c552e2c028
|
@ -20,7 +20,7 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type AcceptConnFunc func(pgconn *PgConn) bool
|
||||
type AfterConnectFunc func(pgconn *PgConn) error
|
||||
|
||||
// Config is the settings used to establish a connection to a PostgreSQL server.
|
||||
type Config struct {
|
||||
|
@ -35,10 +35,10 @@ type Config struct {
|
|||
|
||||
Fallbacks []*FallbackConfig
|
||||
|
||||
// AcceptConnFunc is called after successful connection allow custom logic for determining if the connection is
|
||||
// acceptable. If AcceptConnFunc returns false the connection is closed and the next fallback config is tried. This
|
||||
// AfterConnectFunc is called after successful connection. It can be used to set up the connection or to validate that
|
||||
// server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This
|
||||
// allows implementing high availability behavior such as libpq does with target_session_attrs.
|
||||
AcceptConnFunc AcceptConnFunc
|
||||
AfterConnectFunc AfterConnectFunc
|
||||
}
|
||||
|
||||
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
|
||||
|
@ -92,6 +92,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
|||
// PGSSLROOTCERT
|
||||
// PGAPPNAME
|
||||
// PGCONNECT_TIMEOUT
|
||||
// PGTARGETSESSIONATTRS
|
||||
//
|
||||
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
|
||||
//
|
||||
|
@ -148,17 +149,18 @@ func ParseConfig(connString string) (*Config, error) {
|
|||
}
|
||||
|
||||
notRuntimeParams := map[string]struct{}{
|
||||
"host": struct{}{},
|
||||
"port": struct{}{},
|
||||
"database": struct{}{},
|
||||
"user": struct{}{},
|
||||
"password": struct{}{},
|
||||
"passfile": struct{}{},
|
||||
"connect_timeout": struct{}{},
|
||||
"sslmode": struct{}{},
|
||||
"sslkey": struct{}{},
|
||||
"sslcert": struct{}{},
|
||||
"sslrootcert": struct{}{},
|
||||
"host": struct{}{},
|
||||
"port": struct{}{},
|
||||
"database": struct{}{},
|
||||
"user": struct{}{},
|
||||
"password": struct{}{},
|
||||
"passfile": struct{}{},
|
||||
"connect_timeout": struct{}{},
|
||||
"sslmode": struct{}{},
|
||||
"sslkey": struct{}{},
|
||||
"sslcert": struct{}{},
|
||||
"sslrootcert": struct{}{},
|
||||
"target_session_attrs": struct{}{},
|
||||
}
|
||||
|
||||
for k, v := range settings {
|
||||
|
@ -225,6 +227,12 @@ func ParseConfig(connString string) (*Config, error) {
|
|||
}
|
||||
}
|
||||
|
||||
if settings["target_session_attrs"] == "read-write" {
|
||||
config.AfterConnectFunc = AfterConnectTargetSessionAttrsReadWrite
|
||||
} else if settings["target_session_attrs"] != "any" {
|
||||
return nil, fmt.Errorf("unknown target_session_attrs value %v", settings["target_session_attrs"])
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
|
@ -243,6 +251,8 @@ func defaultSettings() map[string]string {
|
|||
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
|
||||
}
|
||||
|
||||
settings["target_session_attrs"] = "any"
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
|
@ -267,18 +277,19 @@ func defaultHost() string {
|
|||
|
||||
func addEnvSettings(settings map[string]string) {
|
||||
nameMap := map[string]string{
|
||||
"PGHOST": "host",
|
||||
"PGPORT": "port",
|
||||
"PGDATABASE": "database",
|
||||
"PGUSER": "user",
|
||||
"PGPASSWORD": "password",
|
||||
"PGPASSFILE": "passfile",
|
||||
"PGAPPNAME": "application_name",
|
||||
"PGCONNECT_TIMEOUT": "connect_timeout",
|
||||
"PGSSLMODE": "sslmode",
|
||||
"PGSSLKEY": "sslkey",
|
||||
"PGSSLCERT": "sslcert",
|
||||
"PGSSLROOTCERT": "sslrootcert",
|
||||
"PGHOST": "host",
|
||||
"PGPORT": "port",
|
||||
"PGDATABASE": "database",
|
||||
"PGUSER": "user",
|
||||
"PGPASSWORD": "password",
|
||||
"PGPASSFILE": "passfile",
|
||||
"PGAPPNAME": "application_name",
|
||||
"PGCONNECT_TIMEOUT": "connect_timeout",
|
||||
"PGSSLMODE": "sslmode",
|
||||
"PGSSLKEY": "sslkey",
|
||||
"PGSSLCERT": "sslcert",
|
||||
"PGSSLROOTCERT": "sslrootcert",
|
||||
"PGTARGETSESSIONATTRS": "target_session_attrs",
|
||||
}
|
||||
|
||||
for envname, realname := range nameMap {
|
||||
|
@ -452,3 +463,31 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
|
|||
d.Timeout = time.Duration(timeout) * time.Second
|
||||
return d.DialContext, nil
|
||||
}
|
||||
|
||||
// AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=read-write.
|
||||
func AfterConnectTargetSessionAttrsReadWrite(pgConn *PgConn) error {
|
||||
pgConn.SendExec("show transaction_read_only")
|
||||
err := pgConn.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result := pgConn.GetResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rowFound := result.NextRow()
|
||||
if !rowFound {
|
||||
return errors.New("show transaction_read_only failed")
|
||||
}
|
||||
|
||||
if string(result.Value(0)) == "on" {
|
||||
return errors.New("read only connection")
|
||||
}
|
||||
|
||||
_, err = result.Close()
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -374,6 +374,20 @@ func TestParseConfig(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "target_session_attrs",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
AfterConnectFunc: pgconn.AfterConnectTargetSessionAttrsReadWrite,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
|
@ -401,6 +415,9 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName
|
|||
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
|
||||
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
|
||||
|
||||
// Can't test function equality, so just test that they are set or not.
|
||||
assert.Equalf(t, expected.AfterConnectFunc == nil, actual.AfterConnectFunc == nil, "%s - AfterConnectFunc", testName)
|
||||
|
||||
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
|
||||
if expected.TLSConfig != nil {
|
||||
assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
|
@ -183,11 +184,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
return nil, err
|
||||
}
|
||||
case *pgproto3.ReadyForQuery:
|
||||
if config.AcceptConnFunc == nil || config.AcceptConnFunc(pgConn) {
|
||||
return pgConn, nil
|
||||
if config.AfterConnectFunc != nil {
|
||||
err := config.AfterConnectFunc(pgConn)
|
||||
if err != nil {
|
||||
pgConn.NetConn.Close()
|
||||
return nil, fmt.Errorf("AfterConnectFunc: %v", err)
|
||||
}
|
||||
}
|
||||
pgConn.NetConn.Close()
|
||||
return nil, errors.New("AcceptConnFunc rejected connection")
|
||||
return pgConn, nil
|
||||
case *pgproto3.ParameterStatus:
|
||||
// handled by ReceiveMessage
|
||||
case *pgproto3.ErrorResponse:
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -188,7 +189,7 @@ func TestConnectWithFallback(t *testing.T) {
|
|||
closeConn(t, conn)
|
||||
}
|
||||
|
||||
func TestConnectWithAcceptConnFunc(t *testing.T) {
|
||||
func TestConnectWithAfterConnectFunc(t *testing.T) {
|
||||
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(t, err)
|
||||
|
||||
|
@ -199,9 +200,12 @@ func TestConnectWithAcceptConnFunc(t *testing.T) {
|
|||
}
|
||||
|
||||
acceptConnCount := 0
|
||||
config.AcceptConnFunc = func(conn *pgconn.PgConn) bool {
|
||||
config.AfterConnectFunc = func(conn *pgconn.PgConn) error {
|
||||
acceptConnCount += 1
|
||||
return acceptConnCount > 1
|
||||
if acceptConnCount < 2 {
|
||||
return errors.New("reject first conn")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Append current primary config to fallbacks
|
||||
|
@ -222,6 +226,19 @@ func TestConnectWithAcceptConnFunc(t *testing.T) {
|
|||
assert.True(t, acceptConnCount > 1)
|
||||
}
|
||||
|
||||
func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) {
|
||||
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(t, err)
|
||||
|
||||
config.AfterConnectFunc = pgconn.AfterConnectTargetSessionAttrsReadWrite
|
||||
config.RuntimeParams["default_transaction_read_only"] = "on"
|
||||
|
||||
conn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||
if !assert.NotNil(t, err) {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestSimple(t *testing.T) {
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(t, err)
|
||||
|
|
Loading…
Reference in New Issue