diff --git a/pgconn/auth_scram.go b/pgconn/auth_scram.go index 37b602d6..6ca9e337 100644 --- a/pgconn/auth_scram.go +++ b/pgconn/auth_scram.go @@ -80,12 +80,14 @@ func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) if err != nil { return nil, err } - saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue) - if ok { - return saslContinue, nil + switch m := msg.(type) { + case *pgproto3.AuthenticationSASLContinue: + return m, nil + case *pgproto3.ErrorResponse: + return nil, ErrorResponseToPgError(m) } - return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message") + return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg) } func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { @@ -93,12 +95,14 @@ func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { if err != nil { return nil, err } - saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal) - if ok { - return saslFinal, nil + switch m := msg.(type) { + case *pgproto3.AuthenticationSASLFinal: + return m, nil + case *pgproto3.ErrorResponse: + return nil, ErrorResponseToPgError(m) } - return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message") + return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg) } type scramClient struct { diff --git a/pgconn/config.go b/pgconn/config.go index bfec11d4..e4a25cb8 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -99,10 +99,29 @@ type FallbackConfig struct { TLSConfig *tls.Config // nil disables TLS } +// isAbsolutePath checks if the provided value is an absolute path either +// beginning with a forward slash (as on Linux-based systems) or with a capital +// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows). +func isAbsolutePath(path string) bool { + isWindowsPath := func(p string) bool { + if len(p) < 3 { + return false + } + drive := p[0] + colon := p[1] + backslash := p[2] + if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' { + return true + } + return false + } + return strings.HasPrefix(path, "/") || isWindowsPath(path) +} + // NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with // net.Dial. func NetworkAddress(host string, port uint16) (network, address string) { - if strings.HasPrefix(host, "/") { + if isAbsolutePath(host) { network = "unix" address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) } else { @@ -341,7 +360,9 @@ func ParseConfig(connString string) (*Config, error) { config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary case "standby": config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby - case "any", "prefer-standby": + case "prefer-standby": + config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby + case "any": // do nothing default: return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} @@ -772,3 +793,18 @@ func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgCon return nil } + +// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible +// target_session_attrs=prefer-standby. +func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error { + result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() + if result.Err != nil { + return result.Err + } + + if string(result.Rows[0][0]) != "t" { + return &NotPreferredError{err: errors.New("server is not in hot standby mode")} + } + + return nil +} diff --git a/pgconn/config_test.go b/pgconn/config_test.go index 40db7bc2..b9201d72 100644 --- a/pgconn/config_test.go +++ b/pgconn/config_test.go @@ -231,6 +231,18 @@ func TestParseConfig(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + name: "database url unix domain socket host on windows", + connString: "postgres:///foo?host=C:\\tmp", + config: &pgconn.Config{ + User: osUserName, + Host: "C:\\tmp", + Port: 5432, + Database: "foo", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { name: "database url dbname", connString: "postgres://localhost/?dbname=foo&sslmode=disable", @@ -600,13 +612,14 @@ func TestParseConfig(t *testing.T) { name: "target_session_attrs prefer-standby", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=prefer-standby", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPreferStandby, }, }, { @@ -703,6 +716,55 @@ func TestConfigCopyCanBeUsedToConnect(t *testing.T) { assert.NoError(t, err) } +func TestNetworkAddress(t *testing.T) { + tests := []struct { + name string + host string + wantNet string + }{ + { + name: "Default Unix socket address", + host: "/var/run/postgresql", + wantNet: "unix", + }, + { + name: "Windows Unix socket address (standard drive name)", + host: "C:\\tmp", + wantNet: "unix", + }, + { + name: "Windows Unix socket address (first drive name)", + host: "A:\\tmp", + wantNet: "unix", + }, + { + name: "Windows Unix socket address (last drive name)", + host: "Z:\\tmp", + wantNet: "unix", + }, + { + name: "Assume TCP for unknown formats", + host: "a/tmp", + wantNet: "tcp", + }, + { + name: "loopback interface", + host: "localhost", + wantNet: "tcp", + }, + { + name: "IP address", + host: "127.0.0.1", + wantNet: "tcp", + }, + } + for i, tt := range tests { + gotNet, _ := pgconn.NetworkAddress(tt.host, 5432) + + assert.Equalf(t, tt.wantNet, gotNet, "Test %d (%s)", i, tt.name) + } +} + func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { if !assert.NotNil(t, expected) { return diff --git a/pgconn/errors.go b/pgconn/errors.go index 030f7e0a..4254535e 100644 --- a/pgconn/errors.go +++ b/pgconn/errors.go @@ -202,3 +202,20 @@ func redactURL(u *url.URL) string { } return u.String() } + +type NotPreferredError struct { + err error + safeToRetry bool +} + +func (e *NotPreferredError) Error() string { + return fmt.Sprintf("standby server not found: %s", e.err.Error()) +} + +func (e *NotPreferredError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *NotPreferredError) Unwrap() error { + return e.err +} diff --git a/pgconn/krb5.go b/pgconn/krb5.go index a4bca01f..969675fd 100644 --- a/pgconn/krb5.go +++ b/pgconn/krb5.go @@ -2,6 +2,7 @@ package pgconn import ( "errors" + "fmt" "github.com/jackc/pgx/v5/pgproto3" ) @@ -87,10 +88,13 @@ func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) { if err != nil { return nil, err } - gssContinue, ok := msg.(*pgproto3.AuthenticationGSSContinue) - if ok { - return gssContinue, nil + + switch m := msg.(type) { + case *pgproto3.AuthenticationGSSContinue: + return m, nil + case *pgproto3.ErrorResponse: + return nil, ErrorResponseToPgError(m) } - return nil, errors.New("expected AuthenticationGSSContinue message but received unexpected message") + return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg) } diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index ab2e2df9..dcd9dc85 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -137,9 +137,12 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} } + foundBestServer := false + var fallbackConfig *FallbackConfig for _, fc := range fallbackConfigs { - pgConn, err = connect(ctx, config, fc) + pgConn, err = connect(ctx, config, fc, false) if err == nil { + foundBestServer = true break } else if pgerr, ok := err.(*PgError); ok { err = &connectError{config: config, msg: "server error", err: pgerr} @@ -153,6 +156,17 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { break } + } else if cerr, ok := err.(*connectError); ok { + if _, ok := cerr.err.(*NotPreferredError); ok { + fallbackConfig = fc + } + } + } + + if !foundBestServer && fallbackConfig != nil { + pgConn, err = connect(ctx, config, fallbackConfig, true) + if pgerr, ok := err.(*PgError); ok { + err = &connectError{config: config, msg: "server error", err: pgerr} } } @@ -176,7 +190,7 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba for _, fb := range fallbacks { // skip resolve for unix sockets - if strings.HasPrefix(fb.Host, "/") { + if isAbsolutePath(fb.Host) { configs = append(configs, &FallbackConfig{ Host: fb.Host, Port: fb.Port, @@ -216,7 +230,8 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba return configs, nil } -func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { +func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig, + ignoreNotPreferredErr bool) (*PgConn, error) { pgConn := new(PgConn) pgConn.config = config pgConn.cleanupDone = make(chan struct{}) @@ -330,6 +345,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig err := config.ValidateConnect(ctx, pgConn) if err != nil { + if _, ok := err.(*NotPreferredError); ignoreNotPreferredErr && ok { + return pgConn, nil + } pgConn.conn.Close() return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} }