Make pgconn.ConnectError and pgconn.ParseConfigError public

fixes #1773
pull/1875/head
Jack Christensen 2024-01-12 17:52:25 -06:00
parent 44768b5a01
commit 5d26bbefd8
4 changed files with 37 additions and 35 deletions

View File

@ -237,12 +237,12 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
connStringSettings, err = parseURLSettings(connString) connStringSettings, err = parseURLSettings(connString)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err}
} }
} else { } else {
connStringSettings, err = parseDSNSettings(connString) connStringSettings, err = parseDSNSettings(connString)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as DSN", err: err}
} }
} }
} }
@ -251,7 +251,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
if service, present := settings["service"]; present { if service, present := settings["service"]; present {
serviceSettings, err := parseServiceSettings(settings["servicefile"], service) serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err}
} }
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings) settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
@ -278,7 +278,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
if connectTimeoutSetting, present := settings["connect_timeout"]; present { if connectTimeoutSetting, present := settings["connect_timeout"]; present {
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting) connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err}
} }
config.ConnectTimeout = connectTimeout config.ConnectTimeout = connectTimeout
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout) config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
@ -340,7 +340,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
port, err := parsePort(portStr) port, err := parsePort(portStr)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err}
} }
var tlsConfigs []*tls.Config var tlsConfigs []*tls.Config
@ -352,7 +352,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
var err error var err error
tlsConfigs, err = configTLS(settings, host, options) tlsConfigs, err = configTLS(settings, host, options)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err}
} }
} }
@ -396,7 +396,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
case "any": case "any":
// do nothing // do nothing
default: default:
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
} }
return config, nil return config, nil

View File

@ -57,22 +57,23 @@ func (pe *PgError) SQLState() string {
return pe.Code return pe.Code
} }
type connectError struct { // ConnectError is the error returned when a connection attempt fails.
config *Config type ConnectError struct {
Config *Config // The configuration that was used in the connection attempt.
msg string msg string
err error err error
} }
func (e *connectError) Error() string { func (e *ConnectError) Error() string {
sb := &strings.Builder{} sb := &strings.Builder{}
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg) fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.Config.Host, e.Config.User, e.Config.Database, e.msg)
if e.err != nil { if e.err != nil {
fmt.Fprintf(sb, " (%s)", e.err.Error()) fmt.Fprintf(sb, " (%s)", e.err.Error())
} }
return sb.String() return sb.String()
} }
func (e *connectError) Unwrap() error { func (e *ConnectError) Unwrap() error {
return e.err return e.err
} }
@ -88,21 +89,22 @@ func (e *connLockError) Error() string {
return e.status return e.status
} }
type parseConfigError struct { // ParseConfigError is the error returned when a connection string cannot be parsed.
connString string type ParseConfigError struct {
ConnString string // The connection string that could not be parsed.
msg string msg string
err error err error
} }
func (e *parseConfigError) Error() string { func (e *ParseConfigError) Error() string {
connString := redactPW(e.connString) connString := redactPW(e.ConnString)
if e.err == nil { if e.err == nil {
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg) return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
} }
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error()) return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
} }
func (e *parseConfigError) Unwrap() error { func (e *ParseConfigError) Unwrap() error {
return e.err return e.err
} }

View File

@ -3,8 +3,8 @@
package pgconn package pgconn
func NewParseConfigError(conn, msg string, err error) error { func NewParseConfigError(conn, msg string, err error) error {
return &parseConfigError{ return &ParseConfigError{
connString: conn, ConnString: conn,
msg: msg, msg: msg,
err: err, err: err,
} }

View File

@ -152,11 +152,11 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
ctx := octx ctx := octx
fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs)
if err != nil { if err != nil {
return nil, &connectError{config: config, msg: "hostname resolving error", err: err} return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: err}
} }
if len(fallbackConfigs) == 0 { if len(fallbackConfigs) == 0 {
return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
} }
foundBestServer := false foundBestServer := false
@ -178,7 +178,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
foundBestServer = true foundBestServer = true
break break
} else if pgerr, ok := err.(*PgError); ok { } else if pgerr, ok := err.(*PgError); ok {
err = &connectError{config: config, msg: "server error", err: pgerr} err = &ConnectError{Config: config, msg: "server error", err: pgerr}
const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
@ -189,7 +189,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
break break
} }
} else if cerr, ok := err.(*connectError); ok { } else if cerr, ok := err.(*ConnectError); ok {
if _, ok := cerr.err.(*NotPreferredError); ok { if _, ok := cerr.err.(*NotPreferredError); ok {
fallbackConfig = fc fallbackConfig = fc
} }
@ -199,7 +199,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
if !foundBestServer && fallbackConfig != nil { if !foundBestServer && fallbackConfig != nil {
pgConn, err = connect(ctx, config, fallbackConfig, true) pgConn, err = connect(ctx, config, fallbackConfig, true)
if pgerr, ok := err.(*PgError); ok { if pgerr, ok := err.(*PgError); ok {
err = &connectError{config: config, msg: "server error", err: pgerr} err = &ConnectError{Config: config, msg: "server error", err: pgerr}
} }
} }
@ -211,7 +211,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
err := config.AfterConnect(ctx, pgConn) err := config.AfterConnect(ctx, pgConn)
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "AfterConnect error", err: err} return nil, &ConnectError{Config: config, msg: "AfterConnect error", err: err}
} }
} }
@ -283,7 +283,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
netConn, err := config.DialFunc(ctx, network, address) netConn, err := config.DialFunc(ctx, network, address)
if err != nil { if err != nil {
return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
} }
pgConn.conn = netConn pgConn.conn = netConn
@ -295,7 +295,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
if err != nil { if err != nil {
netConn.Close() netConn.Close()
return nil, &connectError{config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)} return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
} }
pgConn.conn = nbTLSConn pgConn.conn = nbTLSConn
@ -336,7 +336,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.frontend.Send(&startupMsg) pgConn.frontend.Send(&startupMsg)
if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil { if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)} return nil, &ConnectError{Config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
} }
for { for {
@ -346,7 +346,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if err, ok := err.(*PgError); ok { if err, ok := err.(*PgError); ok {
return nil, err return nil, err
} }
return nil, &connectError{config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)} return nil, &ConnectError{Config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@ -359,26 +359,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
err = pgConn.txPasswordMessage(pgConn.config.Password) err = pgConn.txPasswordMessage(pgConn.config.Password)
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write password message", err: err} return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
} }
case *pgproto3.AuthenticationMD5Password: case *pgproto3.AuthenticationMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
err = pgConn.txPasswordMessage(digestedPassword) err = pgConn.txPasswordMessage(digestedPassword)
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write password message", err: err} return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
} }
case *pgproto3.AuthenticationSASL: case *pgproto3.AuthenticationSASL:
err = pgConn.scramAuth(msg.AuthMechanisms) err = pgConn.scramAuth(msg.AuthMechanisms)
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed SASL auth", err: err} return nil, &ConnectError{Config: config, msg: "failed SASL auth", err: err}
} }
case *pgproto3.AuthenticationGSS: case *pgproto3.AuthenticationGSS:
err = pgConn.gssAuth() err = pgConn.gssAuth()
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed GSS auth", err: err} return nil, &ConnectError{Config: config, msg: "failed GSS auth", err: err}
} }
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle pgConn.status = connStatusIdle
@ -396,7 +396,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return pgConn, nil return pgConn, nil
} }
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} return nil, &ConnectError{Config: config, msg: "ValidateConnect failed", err: err}
} }
} }
return pgConn, nil return pgConn, nil
@ -407,7 +407,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return nil, ErrorResponseToPgError(msg) return nil, ErrorResponseToPgError(msg)
default: default:
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "received unexpected message", err: err} return nil, &ConnectError{Config: config, msg: "received unexpected message", err: err}
} }
} }
} }