diff --git a/pgconn/config.go b/pgconn/config.go index a07fa533..38144be7 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -20,7 +20,7 @@ import ( "github.com/pkg/errors" ) -type AcceptConnFunc func(pgconn *PgConn) bool +type AfterConnectFunc func(pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. type Config struct { @@ -35,10 +35,10 @@ type Config struct { Fallbacks []*FallbackConfig - // AcceptConnFunc is called after successful connection allow custom logic for determining if the connection is - // acceptable. If AcceptConnFunc returns false the connection is closed and the next fallback config is tried. This + // AfterConnectFunc is called after successful connection. It can be used to set up the connection or to validate that + // server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This // allows implementing high availability behavior such as libpq does with target_session_attrs. - AcceptConnFunc AcceptConnFunc + AfterConnectFunc AfterConnectFunc } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a @@ -92,6 +92,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // PGSSLROOTCERT // PGAPPNAME // PGCONNECT_TIMEOUT +// PGTARGETSESSIONATTRS // // See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. // @@ -148,17 +149,18 @@ func ParseConfig(connString string) (*Config, error) { } notRuntimeParams := map[string]struct{}{ - "host": struct{}{}, - "port": struct{}{}, - "database": struct{}{}, - "user": struct{}{}, - "password": struct{}{}, - "passfile": struct{}{}, - "connect_timeout": struct{}{}, - "sslmode": struct{}{}, - "sslkey": struct{}{}, - "sslcert": struct{}{}, - "sslrootcert": struct{}{}, + "host": struct{}{}, + "port": struct{}{}, + "database": struct{}{}, + "user": struct{}{}, + "password": struct{}{}, + "passfile": struct{}{}, + "connect_timeout": struct{}{}, + "sslmode": struct{}{}, + "sslkey": struct{}{}, + "sslcert": struct{}{}, + "sslrootcert": struct{}{}, + "target_session_attrs": struct{}{}, } for k, v := range settings { @@ -225,6 +227,12 @@ func ParseConfig(connString string) (*Config, error) { } } + if settings["target_session_attrs"] == "read-write" { + config.AfterConnectFunc = AfterConnectTargetSessionAttrsReadWrite + } else if settings["target_session_attrs"] != "any" { + return nil, fmt.Errorf("unknown target_session_attrs value %v", settings["target_session_attrs"]) + } + return config, nil } @@ -243,6 +251,8 @@ func defaultSettings() map[string]string { settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") } + settings["target_session_attrs"] = "any" + return settings } @@ -267,18 +277,19 @@ func defaultHost() string { func addEnvSettings(settings map[string]string) { nameMap := map[string]string{ - "PGHOST": "host", - "PGPORT": "port", - "PGDATABASE": "database", - "PGUSER": "user", - "PGPASSWORD": "password", - "PGPASSFILE": "passfile", - "PGAPPNAME": "application_name", - "PGCONNECT_TIMEOUT": "connect_timeout", - "PGSSLMODE": "sslmode", - "PGSSLKEY": "sslkey", - "PGSSLCERT": "sslcert", - "PGSSLROOTCERT": "sslrootcert", + "PGHOST": "host", + "PGPORT": "port", + "PGDATABASE": "database", + "PGUSER": "user", + "PGPASSWORD": "password", + "PGPASSFILE": "passfile", + "PGAPPNAME": "application_name", + "PGCONNECT_TIMEOUT": "connect_timeout", + "PGSSLMODE": "sslmode", + "PGSSLKEY": "sslkey", + "PGSSLCERT": "sslcert", + "PGSSLROOTCERT": "sslrootcert", + "PGTARGETSESSIONATTRS": "target_session_attrs", } for envname, realname := range nameMap { @@ -452,3 +463,31 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { d.Timeout = time.Duration(timeout) * time.Second return d.DialContext, nil } + +// AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible +// target_session_attrs=read-write. +func AfterConnectTargetSessionAttrsReadWrite(pgConn *PgConn) error { + pgConn.SendExec("show transaction_read_only") + err := pgConn.Flush() + if err != nil { + return err + } + + result := pgConn.GetResult() + if err != nil { + return err + } + + rowFound := result.NextRow() + if !rowFound { + return errors.New("show transaction_read_only failed") + } + + if string(result.Value(0)) == "on" { + return errors.New("read only connection") + } + + _, err = result.Close() + + return err +} diff --git a/pgconn/config_test.go b/pgconn/config_test.go index 566a44f0..36f3fee2 100644 --- a/pgconn/config_test.go +++ b/pgconn/config_test.go @@ -374,6 +374,20 @@ func TestParseConfig(t *testing.T) { }, }, }, + { + name: "target_session_attrs", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + AfterConnectFunc: pgconn.AfterConnectTargetSessionAttrsReadWrite, + }, + }, } for i, tt := range tests { @@ -401,6 +415,9 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) + // Can't test function equality, so just test that they are set or not. + assert.Equalf(t, expected.AfterConnectFunc == nil, actual.AfterConnectFunc == nil, "%s - AfterConnectFunc", testName) + if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { if expected.TLSConfig != nil { assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index ac48f870..94397759 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "encoding/hex" "errors" + "fmt" "io" "net" "strconv" @@ -183,11 +184,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, err } case *pgproto3.ReadyForQuery: - if config.AcceptConnFunc == nil || config.AcceptConnFunc(pgConn) { - return pgConn, nil + if config.AfterConnectFunc != nil { + err := config.AfterConnectFunc(pgConn) + if err != nil { + pgConn.NetConn.Close() + return nil, fmt.Errorf("AfterConnectFunc: %v", err) + } } - pgConn.NetConn.Close() - return nil, errors.New("AcceptConnFunc rejected connection") + return pgConn, nil case *pgproto3.ParameterStatus: // handled by ReceiveMessage case *pgproto3.ErrorResponse: diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index ad06ae7b..0dccc99f 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -9,6 +9,7 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgconn" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -188,7 +189,7 @@ func TestConnectWithFallback(t *testing.T) { closeConn(t, conn) } -func TestConnectWithAcceptConnFunc(t *testing.T) { +func TestConnectWithAfterConnectFunc(t *testing.T) { config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) @@ -199,9 +200,12 @@ func TestConnectWithAcceptConnFunc(t *testing.T) { } acceptConnCount := 0 - config.AcceptConnFunc = func(conn *pgconn.PgConn) bool { + config.AfterConnectFunc = func(conn *pgconn.PgConn) error { acceptConnCount += 1 - return acceptConnCount > 1 + if acceptConnCount < 2 { + return errors.New("reject first conn") + } + return nil } // Append current primary config to fallbacks @@ -222,6 +226,19 @@ func TestConnectWithAcceptConnFunc(t *testing.T) { assert.True(t, acceptConnCount > 1) } +func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + config.AfterConnectFunc = pgconn.AfterConnectTargetSessionAttrsReadWrite + config.RuntimeParams["default_transaction_read_only"] = "on" + + conn, err := pgconn.ConnectConfig(context.Background(), config) + if !assert.NotNil(t, err) { + conn.Close() + } +} + func TestSimple(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err)