mirror of https://github.com/jackc/pgx.git
Merge remote-tracking branch 'pgconn/master' into v5-dev
commit
0f7b95c3a4
|
@ -80,12 +80,14 @@ func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error)
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue)
|
||||
if ok {
|
||||
return saslContinue, nil
|
||||
switch m := msg.(type) {
|
||||
case *pgproto3.AuthenticationSASLContinue:
|
||||
return m, nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return nil, ErrorResponseToPgError(m)
|
||||
}
|
||||
|
||||
return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message")
|
||||
return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
|
||||
}
|
||||
|
||||
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
|
||||
|
@ -93,12 +95,14 @@ func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal)
|
||||
if ok {
|
||||
return saslFinal, nil
|
||||
switch m := msg.(type) {
|
||||
case *pgproto3.AuthenticationSASLFinal:
|
||||
return m, nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return nil, ErrorResponseToPgError(m)
|
||||
}
|
||||
|
||||
return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message")
|
||||
return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
|
||||
}
|
||||
|
||||
type scramClient struct {
|
||||
|
|
|
@ -99,10 +99,29 @@ type FallbackConfig struct {
|
|||
TLSConfig *tls.Config // nil disables TLS
|
||||
}
|
||||
|
||||
// isAbsolutePath checks if the provided value is an absolute path either
|
||||
// beginning with a forward slash (as on Linux-based systems) or with a capital
|
||||
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
|
||||
func isAbsolutePath(path string) bool {
|
||||
isWindowsPath := func(p string) bool {
|
||||
if len(p) < 3 {
|
||||
return false
|
||||
}
|
||||
drive := p[0]
|
||||
colon := p[1]
|
||||
backslash := p[2]
|
||||
if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(path, "/") || isWindowsPath(path)
|
||||
}
|
||||
|
||||
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
|
||||
// net.Dial.
|
||||
func NetworkAddress(host string, port uint16) (network, address string) {
|
||||
if strings.HasPrefix(host, "/") {
|
||||
if isAbsolutePath(host) {
|
||||
network = "unix"
|
||||
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
|
||||
} else {
|
||||
|
@ -341,7 +360,9 @@ func ParseConfig(connString string) (*Config, error) {
|
|||
config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
|
||||
case "standby":
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
|
||||
case "any", "prefer-standby":
|
||||
case "prefer-standby":
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby
|
||||
case "any":
|
||||
// do nothing
|
||||
default:
|
||||
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
|
||||
|
@ -772,3 +793,18 @@ func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgCon
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=prefer-standby.
|
||||
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
|
||||
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||
if result.Err != nil {
|
||||
return result.Err
|
||||
}
|
||||
|
||||
if string(result.Rows[0][0]) != "t" {
|
||||
return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -231,6 +231,18 @@ func TestParseConfig(t *testing.T) {
|
|||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "database url unix domain socket host on windows",
|
||||
connString: "postgres:///foo?host=C:\\tmp",
|
||||
config: &pgconn.Config{
|
||||
User: osUserName,
|
||||
Host: "C:\\tmp",
|
||||
Port: 5432,
|
||||
Database: "foo",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "database url dbname",
|
||||
connString: "postgres://localhost/?dbname=foo&sslmode=disable",
|
||||
|
@ -600,13 +612,14 @@ func TestParseConfig(t *testing.T) {
|
|||
name: "target_session_attrs prefer-standby",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=prefer-standby",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPreferStandby,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -703,6 +716,55 @@ func TestConfigCopyCanBeUsedToConnect(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestNetworkAddress(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
wantNet string
|
||||
}{
|
||||
{
|
||||
name: "Default Unix socket address",
|
||||
host: "/var/run/postgresql",
|
||||
wantNet: "unix",
|
||||
},
|
||||
{
|
||||
name: "Windows Unix socket address (standard drive name)",
|
||||
host: "C:\\tmp",
|
||||
wantNet: "unix",
|
||||
},
|
||||
{
|
||||
name: "Windows Unix socket address (first drive name)",
|
||||
host: "A:\\tmp",
|
||||
wantNet: "unix",
|
||||
},
|
||||
{
|
||||
name: "Windows Unix socket address (last drive name)",
|
||||
host: "Z:\\tmp",
|
||||
wantNet: "unix",
|
||||
},
|
||||
{
|
||||
name: "Assume TCP for unknown formats",
|
||||
host: "a/tmp",
|
||||
wantNet: "tcp",
|
||||
},
|
||||
{
|
||||
name: "loopback interface",
|
||||
host: "localhost",
|
||||
wantNet: "tcp",
|
||||
},
|
||||
{
|
||||
name: "IP address",
|
||||
host: "127.0.0.1",
|
||||
wantNet: "tcp",
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
gotNet, _ := pgconn.NetworkAddress(tt.host, 5432)
|
||||
|
||||
assert.Equalf(t, tt.wantNet, gotNet, "Test %d (%s)", i, tt.name)
|
||||
}
|
||||
}
|
||||
|
||||
func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) {
|
||||
if !assert.NotNil(t, expected) {
|
||||
return
|
||||
|
|
|
@ -202,3 +202,20 @@ func redactURL(u *url.URL) string {
|
|||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
type NotPreferredError struct {
|
||||
err error
|
||||
safeToRetry bool
|
||||
}
|
||||
|
||||
func (e *NotPreferredError) Error() string {
|
||||
return fmt.Sprintf("standby server not found: %s", e.err.Error())
|
||||
}
|
||||
|
||||
func (e *NotPreferredError) SafeToRetry() bool {
|
||||
return e.safeToRetry
|
||||
}
|
||||
|
||||
func (e *NotPreferredError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package pgconn
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
)
|
||||
|
@ -87,10 +88,13 @@ func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gssContinue, ok := msg.(*pgproto3.AuthenticationGSSContinue)
|
||||
if ok {
|
||||
return gssContinue, nil
|
||||
|
||||
switch m := msg.(type) {
|
||||
case *pgproto3.AuthenticationGSSContinue:
|
||||
return m, nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return nil, ErrorResponseToPgError(m)
|
||||
}
|
||||
|
||||
return nil, errors.New("expected AuthenticationGSSContinue message but received unexpected message")
|
||||
return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
|
||||
}
|
||||
|
|
|
@ -137,9 +137,12 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
|
|||
return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
|
||||
}
|
||||
|
||||
foundBestServer := false
|
||||
var fallbackConfig *FallbackConfig
|
||||
for _, fc := range fallbackConfigs {
|
||||
pgConn, err = connect(ctx, config, fc)
|
||||
pgConn, err = connect(ctx, config, fc, false)
|
||||
if err == nil {
|
||||
foundBestServer = true
|
||||
break
|
||||
} else if pgerr, ok := err.(*PgError); ok {
|
||||
err = &connectError{config: config, msg: "server error", err: pgerr}
|
||||
|
@ -153,6 +156,17 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
|
|||
pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
|
||||
break
|
||||
}
|
||||
} else if cerr, ok := err.(*connectError); ok {
|
||||
if _, ok := cerr.err.(*NotPreferredError); ok {
|
||||
fallbackConfig = fc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundBestServer && fallbackConfig != nil {
|
||||
pgConn, err = connect(ctx, config, fallbackConfig, true)
|
||||
if pgerr, ok := err.(*PgError); ok {
|
||||
err = &connectError{config: config, msg: "server error", err: pgerr}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -176,7 +190,7 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
|
|||
|
||||
for _, fb := range fallbacks {
|
||||
// skip resolve for unix sockets
|
||||
if strings.HasPrefix(fb.Host, "/") {
|
||||
if isAbsolutePath(fb.Host) {
|
||||
configs = append(configs, &FallbackConfig{
|
||||
Host: fb.Host,
|
||||
Port: fb.Port,
|
||||
|
@ -216,7 +230,8 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
|
|||
return configs, nil
|
||||
}
|
||||
|
||||
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
|
||||
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig,
|
||||
ignoreNotPreferredErr bool) (*PgConn, error) {
|
||||
pgConn := new(PgConn)
|
||||
pgConn.config = config
|
||||
pgConn.cleanupDone = make(chan struct{})
|
||||
|
@ -330,6 +345,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
|
||||
err := config.ValidateConnect(ctx, pgConn)
|
||||
if err != nil {
|
||||
if _, ok := err.(*NotPreferredError); ignoreNotPreferredErr && ok {
|
||||
return pgConn, nil
|
||||
}
|
||||
pgConn.conn.Close()
|
||||
return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue