Merge remote-tracking branch 'pgconn/master' into v5-dev

pull/1281/head
Jack Christensen 2022-07-12 06:45:54 -05:00
commit 0f7b95c3a4
6 changed files with 165 additions and 24 deletions

View File

@ -80,12 +80,14 @@ func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error)
if err != nil {
return nil, err
}
saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue)
if ok {
return saslContinue, nil
switch m := msg.(type) {
case *pgproto3.AuthenticationSASLContinue:
return m, nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError(m)
}
return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message")
return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
}
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
@ -93,12 +95,14 @@ func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
if err != nil {
return nil, err
}
saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal)
if ok {
return saslFinal, nil
switch m := msg.(type) {
case *pgproto3.AuthenticationSASLFinal:
return m, nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError(m)
}
return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message")
return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
}
type scramClient struct {

View File

@ -99,10 +99,29 @@ type FallbackConfig struct {
TLSConfig *tls.Config // nil disables TLS
}
// isAbsolutePath checks if the provided value is an absolute path either
// beginning with a forward slash (as on Linux-based systems) or with a capital
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
func isAbsolutePath(path string) bool {
isWindowsPath := func(p string) bool {
if len(p) < 3 {
return false
}
drive := p[0]
colon := p[1]
backslash := p[2]
if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' {
return true
}
return false
}
return strings.HasPrefix(path, "/") || isWindowsPath(path)
}
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
// net.Dial.
func NetworkAddress(host string, port uint16) (network, address string) {
if strings.HasPrefix(host, "/") {
if isAbsolutePath(host) {
network = "unix"
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
} else {
@ -341,7 +360,9 @@ func ParseConfig(connString string) (*Config, error) {
config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
case "standby":
config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
case "any", "prefer-standby":
case "prefer-standby":
config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby
case "any":
// do nothing
default:
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
@ -772,3 +793,18 @@ func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgCon
return nil
}
// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=prefer-standby.
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result.Rows[0][0]) != "t" {
return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
}
return nil
}

View File

@ -231,6 +231,18 @@ func TestParseConfig(t *testing.T) {
RuntimeParams: map[string]string{},
},
},
{
name: "database url unix domain socket host on windows",
connString: "postgres:///foo?host=C:\\tmp",
config: &pgconn.Config{
User: osUserName,
Host: "C:\\tmp",
Port: 5432,
Database: "foo",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "database url dbname",
connString: "postgres://localhost/?dbname=foo&sslmode=disable",
@ -600,13 +612,14 @@ func TestParseConfig(t *testing.T) {
name: "target_session_attrs prefer-standby",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=prefer-standby",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPreferStandby,
},
},
{
@ -703,6 +716,55 @@ func TestConfigCopyCanBeUsedToConnect(t *testing.T) {
assert.NoError(t, err)
}
func TestNetworkAddress(t *testing.T) {
tests := []struct {
name string
host string
wantNet string
}{
{
name: "Default Unix socket address",
host: "/var/run/postgresql",
wantNet: "unix",
},
{
name: "Windows Unix socket address (standard drive name)",
host: "C:\\tmp",
wantNet: "unix",
},
{
name: "Windows Unix socket address (first drive name)",
host: "A:\\tmp",
wantNet: "unix",
},
{
name: "Windows Unix socket address (last drive name)",
host: "Z:\\tmp",
wantNet: "unix",
},
{
name: "Assume TCP for unknown formats",
host: "a/tmp",
wantNet: "tcp",
},
{
name: "loopback interface",
host: "localhost",
wantNet: "tcp",
},
{
name: "IP address",
host: "127.0.0.1",
wantNet: "tcp",
},
}
for i, tt := range tests {
gotNet, _ := pgconn.NetworkAddress(tt.host, 5432)
assert.Equalf(t, tt.wantNet, gotNet, "Test %d (%s)", i, tt.name)
}
}
func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) {
if !assert.NotNil(t, expected) {
return

View File

@ -202,3 +202,20 @@ func redactURL(u *url.URL) string {
}
return u.String()
}
type NotPreferredError struct {
err error
safeToRetry bool
}
func (e *NotPreferredError) Error() string {
return fmt.Sprintf("standby server not found: %s", e.err.Error())
}
func (e *NotPreferredError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *NotPreferredError) Unwrap() error {
return e.err
}

View File

@ -2,6 +2,7 @@ package pgconn
import (
"errors"
"fmt"
"github.com/jackc/pgx/v5/pgproto3"
)
@ -87,10 +88,13 @@ func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
if err != nil {
return nil, err
}
gssContinue, ok := msg.(*pgproto3.AuthenticationGSSContinue)
if ok {
return gssContinue, nil
switch m := msg.(type) {
case *pgproto3.AuthenticationGSSContinue:
return m, nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError(m)
}
return nil, errors.New("expected AuthenticationGSSContinue message but received unexpected message")
return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
}

View File

@ -137,9 +137,12 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
}
foundBestServer := false
var fallbackConfig *FallbackConfig
for _, fc := range fallbackConfigs {
pgConn, err = connect(ctx, config, fc)
pgConn, err = connect(ctx, config, fc, false)
if err == nil {
foundBestServer = true
break
} else if pgerr, ok := err.(*PgError); ok {
err = &connectError{config: config, msg: "server error", err: pgerr}
@ -153,6 +156,17 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
break
}
} else if cerr, ok := err.(*connectError); ok {
if _, ok := cerr.err.(*NotPreferredError); ok {
fallbackConfig = fc
}
}
}
if !foundBestServer && fallbackConfig != nil {
pgConn, err = connect(ctx, config, fallbackConfig, true)
if pgerr, ok := err.(*PgError); ok {
err = &connectError{config: config, msg: "server error", err: pgerr}
}
}
@ -176,7 +190,7 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
for _, fb := range fallbacks {
// skip resolve for unix sockets
if strings.HasPrefix(fb.Host, "/") {
if isAbsolutePath(fb.Host) {
configs = append(configs, &FallbackConfig{
Host: fb.Host,
Port: fb.Port,
@ -216,7 +230,8 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
return configs, nil
}
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig,
ignoreNotPreferredErr bool) (*PgConn, error) {
pgConn := new(PgConn)
pgConn.config = config
pgConn.cleanupDone = make(chan struct{})
@ -330,6 +345,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
err := config.ValidateConnect(ctx, pgConn)
if err != nil {
if _, ok := err.(*NotPreferredError); ignoreNotPreferredErr && ok {
return pgConn, nil
}
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
}