Failed connection attempts include all errors

A single Connect("connstring") may actually make multiple connection
requests due to TLS or HA configuration. Previously, when all attempts
failed only the last error was returned. This could be confusing.
Now details of all failed attempts are included.

For example, the following connection string:

host=localhost,127.0.0.1,foo.invalid port=1,2,3

Will now return an error like the following:

failed to connect to `user=postgres database=pgx_test`:
	lookup foo.invalid: no such host
	[::1]:1 (localhost): dial error: dial tcp [::1]:1: connect: connection refused
	127.0.0.1:1 (localhost): dial error: dial tcp 127.0.0.1:1: connect: connection refused
	127.0.0.1:2 (127.0.0.1): dial error: dial tcp 127.0.0.1:2: connect: connection refused

https://github.com/jackc/pgx/issues/1929
pull/2011/head
Jack Christensen 2024-05-11 14:19:46 -05:00
parent 48cdd7bab0
commit 8db971660e
4 changed files with 189 additions and 125 deletions

View File

@ -118,6 +118,14 @@ type FallbackConfig struct {
TLSConfig *tls.Config // nil disables TLS
}
// connectOneConfig is the configuration for a single attempt to connect to a single host.
type connectOneConfig struct {
network string
address string
originalHostname string // original hostname before resolving
tlsConfig *tls.Config // nil disables TLS
}
// isAbsolutePath checks if the provided value is an absolute path either
// beginning with a forward slash (as on Linux-based systems) or with a capital
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).

View File

@ -62,23 +62,37 @@ func (pe *PgError) SQLState() string {
// ConnectError is the error returned when a connection attempt fails.
type ConnectError struct {
Config *Config // The configuration that was used in the connection attempt.
msg string
err error
}
func (e *ConnectError) Error() string {
sb := &strings.Builder{}
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.Config.Host, e.Config.User, e.Config.Database, e.msg)
if e.err != nil {
fmt.Fprintf(sb, " (%s)", e.err.Error())
prefix := fmt.Sprintf("failed to connect to `user=%s database=%s`:", e.Config.User, e.Config.Database)
details := e.err.Error()
if strings.Contains(details, "\n") {
return prefix + "\n\t" + strings.ReplaceAll(details, "\n", "\n\t")
} else {
return prefix + " " + details
}
return sb.String()
}
func (e *ConnectError) Unwrap() error {
return e.err
}
type perDialConnectError struct {
address string
originalHostname string
err error
}
func (e *perDialConnectError) Error() string {
return fmt.Sprintf("%s (%s): %s", e.address, e.originalHostname, e.err.Error())
}
func (e *perDialConnectError) Unwrap() error {
return e.err
}
type connLockError struct {
status string
}

View File

@ -133,15 +133,46 @@ func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptio
//
// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An
// authentication error will terminate the chain of attempts (like libpq:
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise,
// if all attempts fail the last error is returned.
func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err error) {
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error.
func ConnectConfig(ctx context.Context, config *Config) (*PgConn, error) {
// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
// zero values.
if !config.createdByParseConfig {
panic("config must be created by ParseConfig")
}
var allErrors []error
connectConfigs, errs := buildConnectOneConfigs(ctx, config)
if len(errs) > 0 {
allErrors = append(allErrors, errs...)
}
if len(connectConfigs) == 0 {
return nil, &ConnectError{Config: config, err: fmt.Errorf("hostname resolving error: %w", errors.Join(allErrors...))}
}
pgConn, errs := connectPreferred(ctx, config, connectConfigs)
if len(errs) > 0 {
allErrors = append(allErrors, errs...)
return nil, &ConnectError{Config: config, err: errors.Join(allErrors...)}
}
if config.AfterConnect != nil {
err := config.AfterConnect(ctx, pgConn)
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, err: fmt.Errorf("AfterConnect error: %w", err)}
}
}
return pgConn, nil
}
// buildConnectOneConfigs resolves hostnames and builds a list of connectOneConfigs to try connecting to. It returns a
// slice of successfully resolved connectOneConfigs and a slice of errors. It is possible for both slices to contain
// values if some hosts were successfully resolved and others were not.
func buildConnectOneConfigs(ctx context.Context, config *Config) ([]*connectOneConfig, []error) {
// Simplify usage by treating primary config and fallbacks the same.
fallbackConfigs := []*FallbackConfig{
{
@ -151,95 +182,28 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
},
}
fallbackConfigs = append(fallbackConfigs, config.Fallbacks...)
ctx := octx
fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs)
if err != nil {
return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: err}
}
if len(fallbackConfigs) == 0 {
return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
}
var configs []*connectOneConfig
foundBestServer := false
var fallbackConfig *FallbackConfig
for i, fc := range fallbackConfigs {
// ConnectTimeout restricts the whole connection process.
if config.ConnectTimeout != 0 {
// create new context first time or when previous host was different
if i == 0 || (fallbackConfigs[i].Host != fallbackConfigs[i-1].Host) {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout)
defer cancel()
}
} else {
ctx = octx
}
pgConn, err = connect(ctx, config, fc, false)
if err == nil {
foundBestServer = true
break
} else if pgerr, ok := err.(*PgError); ok {
err = &ConnectError{Config: config, msg: "server error", err: pgerr}
const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege
if pgerr.Code == ERRCODE_INVALID_PASSWORD ||
pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && fc.TLSConfig != nil ||
pgerr.Code == ERRCODE_INVALID_CATALOG_NAME ||
pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
break
}
} else if cerr, ok := err.(*ConnectError); ok {
if _, ok := cerr.err.(*NotPreferredError); ok {
fallbackConfig = fc
}
}
}
var allErrors []error
if !foundBestServer && fallbackConfig != nil {
pgConn, err = connect(ctx, config, fallbackConfig, true)
if pgerr, ok := err.(*PgError); ok {
err = &ConnectError{Config: config, msg: "server error", err: pgerr}
}
}
if err != nil {
return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError
}
if config.AfterConnect != nil {
err := config.AfterConnect(ctx, pgConn)
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "AfterConnect error", err: err}
}
}
return pgConn, nil
}
func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) {
var configs []*FallbackConfig
var lookupErrors []error
for _, fb := range fallbacks {
for _, fb := range fallbackConfigs {
// skip resolve for unix sockets
if isAbsolutePath(fb.Host) {
configs = append(configs, &FallbackConfig{
Host: fb.Host,
Port: fb.Port,
TLSConfig: fb.TLSConfig,
network, address := NetworkAddress(fb.Host, fb.Port)
configs = append(configs, &connectOneConfig{
network: network,
address: address,
originalHostname: fb.Host,
tlsConfig: fb.TLSConfig,
})
continue
}
ips, err := lookupFn(ctx, fb.Host)
ips, err := config.LookupFunc(ctx, fb.Host)
if err != nil {
lookupErrors = append(lookupErrors, err)
allErrors = append(allErrors, err)
continue
}
@ -248,33 +212,91 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
if err == nil {
port, err := strconv.ParseUint(splitPort, 10, 16)
if err != nil {
return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err)
return nil, []error{fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err)}
}
configs = append(configs, &FallbackConfig{
Host: splitIP,
Port: uint16(port),
TLSConfig: fb.TLSConfig,
network, address := NetworkAddress(splitIP, uint16(port))
configs = append(configs, &connectOneConfig{
network: network,
address: address,
originalHostname: fb.Host,
tlsConfig: fb.TLSConfig,
})
} else {
configs = append(configs, &FallbackConfig{
Host: ip,
Port: fb.Port,
TLSConfig: fb.TLSConfig,
network, address := NetworkAddress(ip, fb.Port)
configs = append(configs, &connectOneConfig{
network: network,
address: address,
originalHostname: fb.Host,
tlsConfig: fb.TLSConfig,
})
}
}
}
// See https://github.com/jackc/pgx/issues/1464. When Go 1.20 can be used in pgx consider using errors.Join so all
// errors are reported.
if len(configs) == 0 && len(lookupErrors) > 0 {
return nil, lookupErrors[0]
}
return configs, nil
return configs, allErrors
}
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig,
// connectPreferred attempts to connect to the preferred host from connectOneConfigs. The connections are attempted in
// order. If a connection is successful it is returned. If no connection is successful then all errors are returned. If
// a connection attempt returns a [NotPreferredError], then that host will be used if no other hosts are successful.
func connectPreferred(ctx context.Context, config *Config, connectOneConfigs []*connectOneConfig) (*PgConn, []error) {
octx := ctx
var allErrors []error
var fallbackConnectOneConfig *connectOneConfig
for i, c := range connectOneConfigs {
// ConnectTimeout restricts the whole connection process.
if config.ConnectTimeout != 0 {
// create new context first time or when previous host was different
if i == 0 || (connectOneConfigs[i].address != connectOneConfigs[i-1].address) {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout)
defer cancel()
}
} else {
ctx = octx
}
pgConn, err := connectOne(ctx, config, c, false)
if pgConn != nil {
return pgConn, nil
}
allErrors = append(allErrors, err)
var pgErr *PgError
if errors.As(err, &pgErr) {
const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege
if pgErr.Code == ERRCODE_INVALID_PASSWORD ||
pgErr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && c.tlsConfig != nil ||
pgErr.Code == ERRCODE_INVALID_CATALOG_NAME ||
pgErr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
return nil, allErrors
}
}
var npErr *NotPreferredError
if errors.As(err, &npErr) {
fallbackConnectOneConfig = c
}
}
if fallbackConnectOneConfig != nil {
pgConn, err := connectOne(ctx, config, fallbackConnectOneConfig, true)
if err == nil {
return pgConn, nil
}
allErrors = append(allErrors, err)
}
return nil, allErrors
}
// connectOne makes one connection attempt to a single host.
func connectOne(ctx context.Context, config *Config, connectConfig *connectOneConfig,
ignoreNotPreferredErr bool,
) (*PgConn, error) {
pgConn := new(PgConn)
@ -283,20 +305,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.customData = make(map[string]any)
var err error
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
pgConn.conn, err = config.DialFunc(ctx, network, address)
if err != nil {
return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
newPerDialConnectError := func(msg string, err error) *perDialConnectError {
err = normalizeTimeoutError(ctx, err)
e := &perDialConnectError{address: connectConfig.address, originalHostname: connectConfig.originalHostname, err: fmt.Errorf("%s: %w", msg, err)}
return e
}
if fallbackConfig.TLSConfig != nil {
pgConn.conn, err = config.DialFunc(ctx, connectConfig.network, connectConfig.address)
if err != nil {
return nil, newPerDialConnectError("dial error", err)
}
if connectConfig.tlsConfig != nil {
pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn})
pgConn.contextWatcher.Watch(ctx)
tlsConn, err := startTLS(pgConn.conn, fallbackConfig.TLSConfig)
tlsConn, err := startTLS(pgConn.conn, connectConfig.tlsConfig)
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
return nil, newPerDialConnectError("tls error", err)
}
pgConn.conn = tlsConn
@ -337,7 +365,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.frontend.Send(&startupMsg)
if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
return nil, newPerDialConnectError("failed to write startup message", err)
}
for {
@ -345,9 +373,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if err != nil {
pgConn.conn.Close()
if err, ok := err.(*PgError); ok {
return nil, err
return nil, newPerDialConnectError("server error", err)
}
return nil, &ConnectError{Config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
return nil, newPerDialConnectError("failed to receive message", err)
}
switch msg := msg.(type) {
@ -360,26 +388,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
err = pgConn.txPasswordMessage(pgConn.config.Password)
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
return nil, newPerDialConnectError("failed to write password message", err)
}
case *pgproto3.AuthenticationMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
err = pgConn.txPasswordMessage(digestedPassword)
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
return nil, newPerDialConnectError("failed to write password message", err)
}
case *pgproto3.AuthenticationSASL:
err = pgConn.scramAuth(msg.AuthMechanisms)
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "failed SASL auth", err: err}
return nil, newPerDialConnectError("failed SASL auth", err)
}
case *pgproto3.AuthenticationGSS:
err = pgConn.gssAuth()
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "failed GSS auth", err: err}
return nil, newPerDialConnectError("failed GSS auth", err)
}
case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle
@ -397,7 +425,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return pgConn, nil
}
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "ValidateConnect failed", err: err}
return nil, newPerDialConnectError("ValidateConnect failed", err)
}
}
return pgConn, nil
@ -405,10 +433,10 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
// handled by ReceiveMessage
case *pgproto3.ErrorResponse:
pgConn.conn.Close()
return nil, ErrorResponseToPgError(msg)
return nil, newPerDialConnectError("server error", ErrorResponseToPgError(msg))
default:
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "received unexpected message", err: err}
return nil, newPerDialConnectError("received unexpected message", err)
}
}
}

View File

@ -357,10 +357,8 @@ func TestConnectInvalidUser(t *testing.T) {
_, err = pgconn.ConnectConfig(ctx, config)
require.Error(t, err)
pgErr, ok := errors.Unwrap(err).(*pgconn.PgError)
if !ok {
t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err)
}
var pgErr *pgconn.PgError
require.ErrorAs(t, err, &pgErr)
if pgErr.Code != "28000" && pgErr.Code != "28P01" {
t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr)
}
@ -529,6 +527,22 @@ func TestConnectWithFallback(t *testing.T) {
closeConn(t, conn)
}
func TestConnectFailsWithResolveFailureAndFailedConnectionAttempts(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
conn, err := pgconn.Connect(ctx, "host=localhost,127.0.0.1,foo.invalid port=1,2,3 sslmode=disable")
require.Error(t, err)
require.Nil(t, conn)
require.ErrorContains(t, err, "lookup foo.invalid: no such host")
// Not testing the entire string as depending on IPv4 or IPv6 support localhost may resolve to 127.0.0.1 or ::1.
require.ErrorContains(t, err, ":1 (localhost): dial error:")
require.ErrorContains(t, err, ":2 (127.0.0.1): dial error:")
}
func TestConnectWithValidateConnect(t *testing.T) {
t.Parallel()