diff --git a/conn.go b/conn.go index bcbf88ad..6fe4ba72 100644 --- a/conn.go +++ b/conn.go @@ -72,6 +72,7 @@ type ConnConfig struct { Logger Logger LogLevel int Dial DialFunc + Timeout time.Duration RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) OnNotice NoticeHandler // Callback function called when a notice response is received. CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc. @@ -259,7 +260,7 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) network, address := c.config.networkAddress() if c.config.Dial == nil { - c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial + c.config.Dial = (&net.Dialer{Timeout: c.config.Timeout, KeepAlive: 5 * time.Minute}).Dial } if c.shouldLog(LogLevelInfo) { @@ -686,13 +687,22 @@ func ParseURI(uri string) (ConnConfig, error) { } cp.Database = strings.TrimLeft(url.Path, "/") + if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" { + timeout, err := strconv.ParseInt(pgtimeout, 10, 64) + if err != nil { + return cp, err + } + cp.Timeout = time.Duration(timeout) * time.Second + } + err = configSSL(url.Query().Get("sslmode"), &cp) if err != nil { return cp, err } ignoreKeys := map[string]struct{}{ - "sslmode": {}, + "sslmode": {}, + "connect_timeout": {}, } cp.RuntimeParams = make(map[string]string) @@ -750,6 +760,12 @@ func ParseDSN(s string) (ConnConfig, error) { cp.Database = b[2] case "sslmode": sslmode = b[2] + case "connect_timeout": + t, err := strconv.ParseInt(b[2], 10, 64) + if err != nil { + return cp, err + } + cp.Timeout = time.Duration(t) * time.Second default: cp.RuntimeParams[b[1]] = b[2] } @@ -787,6 +803,7 @@ func ParseConnectionString(s string) (ConnConfig, error) { // PGPASSWORD // PGSSLMODE // PGAPPNAME +// PGCONNECT_TIMEOUT // // Important TLS Security Notes: // ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This @@ -822,6 +839,14 @@ func ParseEnvLibpq() (ConnConfig, error) { cc.User = os.Getenv("PGUSER") cc.Password = os.Getenv("PGPASSWORD") + if pgtimeout := os.Getenv("PGCONNECT_TIMEOUT"); pgtimeout != "" { + if timeout, err := strconv.ParseInt(pgtimeout, 10, 64); err == nil { + cc.Timeout = time.Duration(timeout) * time.Second + } else { + return cc, err + } + } + sslmode := os.Getenv("PGSSLMODE") err := configSSL(sslmode, &cc) diff --git a/conn_test.go b/conn_test.go index c1cb4ebe..696a4003 100644 --- a/conn_test.go +++ b/conn_test.go @@ -567,6 +567,21 @@ func TestParseDSN(t *testing.T) { }, }, }, + { + url: "user=jack host=localhost dbname=mydb connect_timeout=10", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + Timeout: 10 * time.Second, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, } for i, tt := range tests { @@ -697,6 +712,21 @@ func TestParseConnectionString(t *testing.T) { }, }, }, + { + url: "postgres://jack@localhost/mydb?connect_timeout=10", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + Timeout: 10 * time.Second, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable", connParams: pgx.ConnConfig{ @@ -802,7 +832,7 @@ func TestParseConnectionString(t *testing.T) { } func TestParseEnvLibpq(t *testing.T) { - pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE"} + pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"} savedEnv := make(map[string]string) for _, n := range pgEnvvars { @@ -835,11 +865,12 @@ func TestParseEnvLibpq(t *testing.T) { { name: "Normal PG vars", envvars: map[string]string{ - "PGHOST": "123.123.123.123", - "PGPORT": "7777", - "PGDATABASE": "foo", - "PGUSER": "bar", - "PGPASSWORD": "baz", + "PGHOST": "123.123.123.123", + "PGPORT": "7777", + "PGDATABASE": "foo", + "PGUSER": "bar", + "PGPASSWORD": "baz", + "PGCONNECT_TIMEOUT": "10", }, config: pgx.ConnConfig{ Host: "123.123.123.123", @@ -850,6 +881,7 @@ func TestParseEnvLibpq(t *testing.T) { TLSConfig: &tls.Config{InsecureSkipVerify: true}, UseFallbackTLS: true, FallbackTLSConfig: nil, + Timeout: 10 * time.Second, RuntimeParams: map[string]string{}, }, },