Add support for libpq target_session_attrs

Generalize AcceptConnFunc into AfterConnectFunc.
pull/483/head
Jack Christensen 2018-12-31 12:14:41 -06:00
parent 28ee40f347
commit c552e2c028
4 changed files with 111 additions and 34 deletions

View File

@ -20,7 +20,7 @@ import (
"github.com/pkg/errors"
)
type AcceptConnFunc func(pgconn *PgConn) bool
type AfterConnectFunc func(pgconn *PgConn) error
// Config is the settings used to establish a connection to a PostgreSQL server.
type Config struct {
@ -35,10 +35,10 @@ type Config struct {
Fallbacks []*FallbackConfig
// AcceptConnFunc is called after successful connection allow custom logic for determining if the connection is
// acceptable. If AcceptConnFunc returns false the connection is closed and the next fallback config is tried. This
// AfterConnectFunc is called after successful connection. It can be used to set up the connection or to validate that
// server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This
// allows implementing high availability behavior such as libpq does with target_session_attrs.
AcceptConnFunc AcceptConnFunc
AfterConnectFunc AfterConnectFunc
}
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
@ -92,6 +92,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// PGSSLROOTCERT
// PGAPPNAME
// PGCONNECT_TIMEOUT
// PGTARGETSESSIONATTRS
//
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
//
@ -148,17 +149,18 @@ func ParseConfig(connString string) (*Config, error) {
}
notRuntimeParams := map[string]struct{}{
"host": struct{}{},
"port": struct{}{},
"database": struct{}{},
"user": struct{}{},
"password": struct{}{},
"passfile": struct{}{},
"connect_timeout": struct{}{},
"sslmode": struct{}{},
"sslkey": struct{}{},
"sslcert": struct{}{},
"sslrootcert": struct{}{},
"host": struct{}{},
"port": struct{}{},
"database": struct{}{},
"user": struct{}{},
"password": struct{}{},
"passfile": struct{}{},
"connect_timeout": struct{}{},
"sslmode": struct{}{},
"sslkey": struct{}{},
"sslcert": struct{}{},
"sslrootcert": struct{}{},
"target_session_attrs": struct{}{},
}
for k, v := range settings {
@ -225,6 +227,12 @@ func ParseConfig(connString string) (*Config, error) {
}
}
if settings["target_session_attrs"] == "read-write" {
config.AfterConnectFunc = AfterConnectTargetSessionAttrsReadWrite
} else if settings["target_session_attrs"] != "any" {
return nil, fmt.Errorf("unknown target_session_attrs value %v", settings["target_session_attrs"])
}
return config, nil
}
@ -243,6 +251,8 @@ func defaultSettings() map[string]string {
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
}
settings["target_session_attrs"] = "any"
return settings
}
@ -267,18 +277,19 @@ func defaultHost() string {
func addEnvSettings(settings map[string]string) {
nameMap := map[string]string{
"PGHOST": "host",
"PGPORT": "port",
"PGDATABASE": "database",
"PGUSER": "user",
"PGPASSWORD": "password",
"PGPASSFILE": "passfile",
"PGAPPNAME": "application_name",
"PGCONNECT_TIMEOUT": "connect_timeout",
"PGSSLMODE": "sslmode",
"PGSSLKEY": "sslkey",
"PGSSLCERT": "sslcert",
"PGSSLROOTCERT": "sslrootcert",
"PGHOST": "host",
"PGPORT": "port",
"PGDATABASE": "database",
"PGUSER": "user",
"PGPASSWORD": "password",
"PGPASSFILE": "passfile",
"PGAPPNAME": "application_name",
"PGCONNECT_TIMEOUT": "connect_timeout",
"PGSSLMODE": "sslmode",
"PGSSLKEY": "sslkey",
"PGSSLCERT": "sslcert",
"PGSSLROOTCERT": "sslrootcert",
"PGTARGETSESSIONATTRS": "target_session_attrs",
}
for envname, realname := range nameMap {
@ -452,3 +463,31 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
d.Timeout = time.Duration(timeout) * time.Second
return d.DialContext, nil
}
// AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible
// target_session_attrs=read-write.
func AfterConnectTargetSessionAttrsReadWrite(pgConn *PgConn) error {
pgConn.SendExec("show transaction_read_only")
err := pgConn.Flush()
if err != nil {
return err
}
result := pgConn.GetResult()
if err != nil {
return err
}
rowFound := result.NextRow()
if !rowFound {
return errors.New("show transaction_read_only failed")
}
if string(result.Value(0)) == "on" {
return errors.New("read only connection")
}
_, err = result.Close()
return err
}

View File

@ -374,6 +374,20 @@ func TestParseConfig(t *testing.T) {
},
},
},
{
name: "target_session_attrs",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
AfterConnectFunc: pgconn.AfterConnectTargetSessionAttrsReadWrite,
},
},
}
for i, tt := range tests {
@ -401,6 +415,9 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
// Can't test function equality, so just test that they are set or not.
assert.Equalf(t, expected.AfterConnectFunc == nil, actual.AfterConnectFunc == nil, "%s - AfterConnectFunc", testName)
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
if expected.TLSConfig != nil {
assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)

View File

@ -7,6 +7,7 @@ import (
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"strconv"
@ -183,11 +184,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return nil, err
}
case *pgproto3.ReadyForQuery:
if config.AcceptConnFunc == nil || config.AcceptConnFunc(pgConn) {
return pgConn, nil
if config.AfterConnectFunc != nil {
err := config.AfterConnectFunc(pgConn)
if err != nil {
pgConn.NetConn.Close()
return nil, fmt.Errorf("AfterConnectFunc: %v", err)
}
}
pgConn.NetConn.Close()
return nil, errors.New("AcceptConnFunc rejected connection")
return pgConn, nil
case *pgproto3.ParameterStatus:
// handled by ReceiveMessage
case *pgproto3.ErrorResponse:

View File

@ -9,6 +9,7 @@ import (
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgconn"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -188,7 +189,7 @@ func TestConnectWithFallback(t *testing.T) {
closeConn(t, conn)
}
func TestConnectWithAcceptConnFunc(t *testing.T) {
func TestConnectWithAfterConnectFunc(t *testing.T) {
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
@ -199,9 +200,12 @@ func TestConnectWithAcceptConnFunc(t *testing.T) {
}
acceptConnCount := 0
config.AcceptConnFunc = func(conn *pgconn.PgConn) bool {
config.AfterConnectFunc = func(conn *pgconn.PgConn) error {
acceptConnCount += 1
return acceptConnCount > 1
if acceptConnCount < 2 {
return errors.New("reject first conn")
}
return nil
}
// Append current primary config to fallbacks
@ -222,6 +226,19 @@ func TestConnectWithAcceptConnFunc(t *testing.T) {
assert.True(t, acceptConnCount > 1)
}
func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) {
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
config.AfterConnectFunc = pgconn.AfterConnectTargetSessionAttrsReadWrite
config.RuntimeParams["default_transaction_read_only"] = "on"
conn, err := pgconn.ConnectConfig(context.Background(), config)
if !assert.NotNil(t, err) {
conn.Close()
}
}
func TestSimple(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)