mirror of https://github.com/jackc/pgx.git
Merge remote-tracking branch 'pgconn/master' into v5-dev
commit
0f7b95c3a4
|
@ -80,12 +80,14 @@ func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue)
|
switch m := msg.(type) {
|
||||||
if ok {
|
case *pgproto3.AuthenticationSASLContinue:
|
||||||
return saslContinue, nil
|
return m, nil
|
||||||
|
case *pgproto3.ErrorResponse:
|
||||||
|
return nil, ErrorResponseToPgError(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message")
|
return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
|
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
|
||||||
|
@ -93,12 +95,14 @@ func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal)
|
switch m := msg.(type) {
|
||||||
if ok {
|
case *pgproto3.AuthenticationSASLFinal:
|
||||||
return saslFinal, nil
|
return m, nil
|
||||||
|
case *pgproto3.ErrorResponse:
|
||||||
|
return nil, ErrorResponseToPgError(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message")
|
return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
type scramClient struct {
|
type scramClient struct {
|
||||||
|
|
|
@ -99,10 +99,29 @@ type FallbackConfig struct {
|
||||||
TLSConfig *tls.Config // nil disables TLS
|
TLSConfig *tls.Config // nil disables TLS
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isAbsolutePath checks if the provided value is an absolute path either
|
||||||
|
// beginning with a forward slash (as on Linux-based systems) or with a capital
|
||||||
|
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
|
||||||
|
func isAbsolutePath(path string) bool {
|
||||||
|
isWindowsPath := func(p string) bool {
|
||||||
|
if len(p) < 3 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
drive := p[0]
|
||||||
|
colon := p[1]
|
||||||
|
backslash := p[2]
|
||||||
|
if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.HasPrefix(path, "/") || isWindowsPath(path)
|
||||||
|
}
|
||||||
|
|
||||||
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
|
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
|
||||||
// net.Dial.
|
// net.Dial.
|
||||||
func NetworkAddress(host string, port uint16) (network, address string) {
|
func NetworkAddress(host string, port uint16) (network, address string) {
|
||||||
if strings.HasPrefix(host, "/") {
|
if isAbsolutePath(host) {
|
||||||
network = "unix"
|
network = "unix"
|
||||||
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
|
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
|
||||||
} else {
|
} else {
|
||||||
|
@ -341,7 +360,9 @@ func ParseConfig(connString string) (*Config, error) {
|
||||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
|
||||||
case "standby":
|
case "standby":
|
||||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
|
||||||
case "any", "prefer-standby":
|
case "prefer-standby":
|
||||||
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby
|
||||||
|
case "any":
|
||||||
// do nothing
|
// do nothing
|
||||||
default:
|
default:
|
||||||
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
|
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
|
||||||
|
@ -772,3 +793,18 @@ func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgCon
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible
|
||||||
|
// target_session_attrs=prefer-standby.
|
||||||
|
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
|
||||||
|
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||||
|
if result.Err != nil {
|
||||||
|
return result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Rows[0][0]) != "t" {
|
||||||
|
return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -231,6 +231,18 @@ func TestParseConfig(t *testing.T) {
|
||||||
RuntimeParams: map[string]string{},
|
RuntimeParams: map[string]string{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "database url unix domain socket host on windows",
|
||||||
|
connString: "postgres:///foo?host=C:\\tmp",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: osUserName,
|
||||||
|
Host: "C:\\tmp",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "foo",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "database url dbname",
|
name: "database url dbname",
|
||||||
connString: "postgres://localhost/?dbname=foo&sslmode=disable",
|
connString: "postgres://localhost/?dbname=foo&sslmode=disable",
|
||||||
|
@ -600,13 +612,14 @@ func TestParseConfig(t *testing.T) {
|
||||||
name: "target_session_attrs prefer-standby",
|
name: "target_session_attrs prefer-standby",
|
||||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=prefer-standby",
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=prefer-standby",
|
||||||
config: &pgconn.Config{
|
config: &pgconn.Config{
|
||||||
User: "jack",
|
User: "jack",
|
||||||
Password: "secret",
|
Password: "secret",
|
||||||
Host: "localhost",
|
Host: "localhost",
|
||||||
Port: 5432,
|
Port: 5432,
|
||||||
Database: "mydb",
|
Database: "mydb",
|
||||||
TLSConfig: nil,
|
TLSConfig: nil,
|
||||||
RuntimeParams: map[string]string{},
|
RuntimeParams: map[string]string{},
|
||||||
|
ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPreferStandby,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -703,6 +716,55 @@ func TestConfigCopyCanBeUsedToConnect(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNetworkAddress(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
host string
|
||||||
|
wantNet string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Default Unix socket address",
|
||||||
|
host: "/var/run/postgresql",
|
||||||
|
wantNet: "unix",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Windows Unix socket address (standard drive name)",
|
||||||
|
host: "C:\\tmp",
|
||||||
|
wantNet: "unix",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Windows Unix socket address (first drive name)",
|
||||||
|
host: "A:\\tmp",
|
||||||
|
wantNet: "unix",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Windows Unix socket address (last drive name)",
|
||||||
|
host: "Z:\\tmp",
|
||||||
|
wantNet: "unix",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Assume TCP for unknown formats",
|
||||||
|
host: "a/tmp",
|
||||||
|
wantNet: "tcp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "loopback interface",
|
||||||
|
host: "localhost",
|
||||||
|
wantNet: "tcp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IP address",
|
||||||
|
host: "127.0.0.1",
|
||||||
|
wantNet: "tcp",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for i, tt := range tests {
|
||||||
|
gotNet, _ := pgconn.NetworkAddress(tt.host, 5432)
|
||||||
|
|
||||||
|
assert.Equalf(t, tt.wantNet, gotNet, "Test %d (%s)", i, tt.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) {
|
func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) {
|
||||||
if !assert.NotNil(t, expected) {
|
if !assert.NotNil(t, expected) {
|
||||||
return
|
return
|
||||||
|
|
|
@ -202,3 +202,20 @@ func redactURL(u *url.URL) string {
|
||||||
}
|
}
|
||||||
return u.String()
|
return u.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NotPreferredError struct {
|
||||||
|
err error
|
||||||
|
safeToRetry bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NotPreferredError) Error() string {
|
||||||
|
return fmt.Sprintf("standby server not found: %s", e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NotPreferredError) SafeToRetry() bool {
|
||||||
|
return e.safeToRetry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NotPreferredError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package pgconn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5/pgproto3"
|
"github.com/jackc/pgx/v5/pgproto3"
|
||||||
)
|
)
|
||||||
|
@ -87,10 +88,13 @@ func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
gssContinue, ok := msg.(*pgproto3.AuthenticationGSSContinue)
|
|
||||||
if ok {
|
switch m := msg.(type) {
|
||||||
return gssContinue, nil
|
case *pgproto3.AuthenticationGSSContinue:
|
||||||
|
return m, nil
|
||||||
|
case *pgproto3.ErrorResponse:
|
||||||
|
return nil, ErrorResponseToPgError(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, errors.New("expected AuthenticationGSSContinue message but received unexpected message")
|
return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
|
||||||
}
|
}
|
||||||
|
|
|
@ -137,9 +137,12 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
|
||||||
return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
|
return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
foundBestServer := false
|
||||||
|
var fallbackConfig *FallbackConfig
|
||||||
for _, fc := range fallbackConfigs {
|
for _, fc := range fallbackConfigs {
|
||||||
pgConn, err = connect(ctx, config, fc)
|
pgConn, err = connect(ctx, config, fc, false)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
foundBestServer = true
|
||||||
break
|
break
|
||||||
} else if pgerr, ok := err.(*PgError); ok {
|
} else if pgerr, ok := err.(*PgError); ok {
|
||||||
err = &connectError{config: config, msg: "server error", err: pgerr}
|
err = &connectError{config: config, msg: "server error", err: pgerr}
|
||||||
|
@ -153,6 +156,17 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
|
||||||
pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
|
pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
} else if cerr, ok := err.(*connectError); ok {
|
||||||
|
if _, ok := cerr.err.(*NotPreferredError); ok {
|
||||||
|
fallbackConfig = fc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundBestServer && fallbackConfig != nil {
|
||||||
|
pgConn, err = connect(ctx, config, fallbackConfig, true)
|
||||||
|
if pgerr, ok := err.(*PgError); ok {
|
||||||
|
err = &connectError{config: config, msg: "server error", err: pgerr}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,7 +190,7 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
|
||||||
|
|
||||||
for _, fb := range fallbacks {
|
for _, fb := range fallbacks {
|
||||||
// skip resolve for unix sockets
|
// skip resolve for unix sockets
|
||||||
if strings.HasPrefix(fb.Host, "/") {
|
if isAbsolutePath(fb.Host) {
|
||||||
configs = append(configs, &FallbackConfig{
|
configs = append(configs, &FallbackConfig{
|
||||||
Host: fb.Host,
|
Host: fb.Host,
|
||||||
Port: fb.Port,
|
Port: fb.Port,
|
||||||
|
@ -216,7 +230,8 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
|
||||||
return configs, nil
|
return configs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
|
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig,
|
||||||
|
ignoreNotPreferredErr bool) (*PgConn, error) {
|
||||||
pgConn := new(PgConn)
|
pgConn := new(PgConn)
|
||||||
pgConn.config = config
|
pgConn.config = config
|
||||||
pgConn.cleanupDone = make(chan struct{})
|
pgConn.cleanupDone = make(chan struct{})
|
||||||
|
@ -330,6 +345,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
||||||
|
|
||||||
err := config.ValidateConnect(ctx, pgConn)
|
err := config.ValidateConnect(ctx, pgConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if _, ok := err.(*NotPreferredError); ignoreNotPreferredErr && ok {
|
||||||
|
return pgConn, nil
|
||||||
|
}
|
||||||
pgConn.conn.Close()
|
pgConn.conn.Close()
|
||||||
return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
|
return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue